pubsub: optional argument at broadcast time

This commit is contained in:
Dan Bungert 2021-09-22 13:31:53 -06:00
parent 2f9c22961e
commit 350ce11dd9
2 changed files with 35 additions and 16 deletions

View File

@ -29,11 +29,14 @@ class MessageHub:
def subscribe(self, channel, method, *args):
self.subscriptions.setdefault(channel, []).append((method, args))
async def abroadcast(self, channel):
async def abroadcast(self, channel, data=None):
for m, args in self.subscriptions.get(channel, []):
if data:
args = [data] + list(args)
v = m(*args)
if inspect.iscoroutine(v):
await v
def broadcast(self, channel):
return asyncio.get_event_loop().create_task(self.abroadcast(channel))
def broadcast(self, channel, data=None):
loop = asyncio.get_event_loop()
return loop.create_task(self.abroadcast(channel, data))

View File

@ -20,19 +20,35 @@ from subiquitycore.tests.util import run_coro
class TestMessageHub(SubiTestCase):
def test_basic(self):
def cb(mydata):
self.assertEqual(private_data, mydata)
self.called += 1
def cb(actual_private):
self.assertEqual(private_data, actual_private)
nonlocal actual_calls
actual_calls += 1
async def fn():
calls_expected = 3
for _ in range(calls_expected):
self.hub.subscribe(channel_id, cb, private_data)
await self.hub.broadcast(channel_id)
self.assertEqual(calls_expected, self.called)
self.called = 0
actual_calls = 0
expected_calls = 3
channel_id = 1234
private_data = 42
self.hub = MessageHub()
run_coro(fn())
hub = MessageHub()
for _ in range(expected_calls):
hub.subscribe(channel_id, cb, private_data)
run_coro(hub.abroadcast(channel_id))
self.assertEqual(expected_calls, actual_calls)
def test_message_arg(self):
def cb(zero, one, two, three, *args):
self.assertEqual(broadcast_data, zero)
self.assertEqual(1, one)
self.assertEqual('two', two)
self.assertEqual([3], three)
self.assertEqual(0, len(args))
nonlocal called
called = True
called = False
channel_id = 'test-message-arg'
broadcast_data = 'broadcast-data'
hub = MessageHub()
hub.subscribe(channel_id, cb, 1, 'two', [3])
run_coro(hub.abroadcast(channel_id, broadcast_data))
self.assertTrue(called)