diff --git a/subiquitycore/pubsub.py b/subiquitycore/pubsub.py index 3bbc406d..ff4e0075 100644 --- a/subiquitycore/pubsub.py +++ b/subiquitycore/pubsub.py @@ -26,17 +26,15 @@ class MessageHub: def __init__(self): self.subscriptions = {} - def subscribe(self, channel, method, *args): - self.subscriptions.setdefault(channel, []).append((method, args)) + def subscribe(self, channel, method): + self.subscriptions.setdefault(channel, []).append(method) - async def abroadcast(self, channel, data=None): - for m, args in self.subscriptions.get(channel, []): - if data: - args = [data] + list(args) - v = m(*args) + async def abroadcast(self, channel, *args, **kwargs): + for m in self.subscriptions.get(channel, []): + v = m(*args, **kwargs) if inspect.iscoroutine(v): await v - def broadcast(self, channel, data=None): + def broadcast(self, channel, *args, **kwargs): loop = asyncio.get_event_loop() - return loop.create_task(self.abroadcast(channel, data)) + return loop.create_task(self.abroadcast(channel, *args, **kwargs)) diff --git a/subiquitycore/tests/test_pubsub.py b/subiquitycore/tests/test_pubsub.py index 3084cced..ecff7b61 100644 --- a/subiquitycore/tests/test_pubsub.py +++ b/subiquitycore/tests/test_pubsub.py @@ -13,42 +13,38 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from unittest.mock import MagicMock + from subiquitycore.tests import SubiTestCase from subiquitycore.pubsub import MessageHub from subiquitycore.tests.util import run_coro class TestMessageHub(SubiTestCase): - def test_basic(self): - def cb(actual_private): - self.assertEqual(private_data, actual_private) - nonlocal actual_calls - actual_calls += 1 + def setUp(self): + self.hub = MessageHub() - actual_calls = 0 + def test_multicall(self): + cb = MagicMock() expected_calls = 3 channel_id = 1234 - private_data = 42 - hub = MessageHub() for _ in range(expected_calls): - hub.subscribe(channel_id, cb, private_data) - run_coro(hub.abroadcast(channel_id)) - self.assertEqual(expected_calls, actual_calls) + self.hub.subscribe(channel_id, cb) + run_coro(self.hub.abroadcast(channel_id)) + self.assertEqual(expected_calls, cb.call_count) + + def test_multisubscriber(self): + cbs = [MagicMock() for _ in range(4)] + channel_id = 2345 + for cb in cbs: + self.hub.subscribe(channel_id, cb) + run_coro(self.hub.abroadcast(channel_id)) + for cb in cbs: + cb.assert_called_once_with() def test_message_arg(self): - def cb(zero, one, two, three, *args): - self.assertEqual(broadcast_data, zero) - self.assertEqual(1, one) - self.assertEqual('two', two) - self.assertEqual([3], three) - self.assertEqual(0, len(args)) - nonlocal called - called = True - - called = False + cb = MagicMock() channel_id = 'test-message-arg' - broadcast_data = 'broadcast-data' - hub = MessageHub() - hub.subscribe(channel_id, cb, 1, 'two', [3]) - run_coro(hub.abroadcast(channel_id, broadcast_data)) - self.assertTrue(called) + self.hub.subscribe(channel_id, cb) + run_coro(self.hub.abroadcast(channel_id, '0', 1, 'two', [3], four=4)) + cb.assert_called_once_with('0', 1, 'two', [3], four=4)