diff --git a/subiquitycore/pubsub.py b/subiquitycore/pubsub.py index 4c915c0d..3bbc406d 100644 --- a/subiquitycore/pubsub.py +++ b/subiquitycore/pubsub.py @@ -29,11 +29,14 @@ class MessageHub: def subscribe(self, channel, method, *args): self.subscriptions.setdefault(channel, []).append((method, args)) - async def abroadcast(self, channel): + async def abroadcast(self, channel, data=None): for m, args in self.subscriptions.get(channel, []): + if data: + args = [data] + list(args) v = m(*args) if inspect.iscoroutine(v): await v - def broadcast(self, channel): - return asyncio.get_event_loop().create_task(self.abroadcast(channel)) + def broadcast(self, channel, data=None): + loop = asyncio.get_event_loop() + return loop.create_task(self.abroadcast(channel, data)) diff --git a/subiquitycore/tests/test_pubsub.py b/subiquitycore/tests/test_pubsub.py index 92aebf37..3084cced 100644 --- a/subiquitycore/tests/test_pubsub.py +++ b/subiquitycore/tests/test_pubsub.py @@ -20,19 +20,35 @@ from subiquitycore.tests.util import run_coro class TestMessageHub(SubiTestCase): def test_basic(self): - def cb(mydata): - self.assertEqual(private_data, mydata) - self.called += 1 + def cb(actual_private): + self.assertEqual(private_data, actual_private) + nonlocal actual_calls + actual_calls += 1 - async def fn(): - calls_expected = 3 - for _ in range(calls_expected): - self.hub.subscribe(channel_id, cb, private_data) - await self.hub.broadcast(channel_id) - self.assertEqual(calls_expected, self.called) - - self.called = 0 + actual_calls = 0 + expected_calls = 3 channel_id = 1234 private_data = 42 - self.hub = MessageHub() - run_coro(fn()) + hub = MessageHub() + for _ in range(expected_calls): + hub.subscribe(channel_id, cb, private_data) + run_coro(hub.abroadcast(channel_id)) + self.assertEqual(expected_calls, actual_calls) + + def test_message_arg(self): + def cb(zero, one, two, three, *args): + self.assertEqual(broadcast_data, zero) + self.assertEqual(1, one) + self.assertEqual('two', two) + self.assertEqual([3], three) + self.assertEqual(0, len(args)) + nonlocal called + called = True + + called = False + channel_id = 'test-message-arg' + broadcast_data = 'broadcast-data' + hub = MessageHub() + hub.subscribe(channel_id, cb, 1, 'two', [3]) + run_coro(hub.abroadcast(channel_id, broadcast_data)) + self.assertTrue(called)