pubsub: remove subscribe args, more broadcast args

Remove the subscribe time args, as they weren't used.
With that flexibility we can add many args at broadcast time.
This commit is contained in:
Dan Bungert 2021-09-24 13:06:41 -06:00
parent 350ce11dd9
commit af8dcfe6c3
2 changed files with 29 additions and 35 deletions

View File

@ -26,17 +26,15 @@ class MessageHub:
def __init__(self): def __init__(self):
self.subscriptions = {} self.subscriptions = {}
def subscribe(self, channel, method, *args): def subscribe(self, channel, method):
self.subscriptions.setdefault(channel, []).append((method, args)) self.subscriptions.setdefault(channel, []).append(method)
async def abroadcast(self, channel, data=None): async def abroadcast(self, channel, *args, **kwargs):
for m, args in self.subscriptions.get(channel, []): for m in self.subscriptions.get(channel, []):
if data: v = m(*args, **kwargs)
args = [data] + list(args)
v = m(*args)
if inspect.iscoroutine(v): if inspect.iscoroutine(v):
await v await v
def broadcast(self, channel, data=None): def broadcast(self, channel, *args, **kwargs):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return loop.create_task(self.abroadcast(channel, data)) return loop.create_task(self.abroadcast(channel, *args, **kwargs))

View File

@ -13,42 +13,38 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from unittest.mock import MagicMock
from subiquitycore.tests import SubiTestCase from subiquitycore.tests import SubiTestCase
from subiquitycore.pubsub import MessageHub from subiquitycore.pubsub import MessageHub
from subiquitycore.tests.util import run_coro from subiquitycore.tests.util import run_coro
class TestMessageHub(SubiTestCase): class TestMessageHub(SubiTestCase):
def test_basic(self): def setUp(self):
def cb(actual_private): self.hub = MessageHub()
self.assertEqual(private_data, actual_private)
nonlocal actual_calls
actual_calls += 1
actual_calls = 0 def test_multicall(self):
cb = MagicMock()
expected_calls = 3 expected_calls = 3
channel_id = 1234 channel_id = 1234
private_data = 42
hub = MessageHub()
for _ in range(expected_calls): for _ in range(expected_calls):
hub.subscribe(channel_id, cb, private_data) self.hub.subscribe(channel_id, cb)
run_coro(hub.abroadcast(channel_id)) run_coro(self.hub.abroadcast(channel_id))
self.assertEqual(expected_calls, actual_calls) self.assertEqual(expected_calls, cb.call_count)
def test_multisubscriber(self):
cbs = [MagicMock() for _ in range(4)]
channel_id = 2345
for cb in cbs:
self.hub.subscribe(channel_id, cb)
run_coro(self.hub.abroadcast(channel_id))
for cb in cbs:
cb.assert_called_once_with()
def test_message_arg(self): def test_message_arg(self):
def cb(zero, one, two, three, *args): cb = MagicMock()
self.assertEqual(broadcast_data, zero)
self.assertEqual(1, one)
self.assertEqual('two', two)
self.assertEqual([3], three)
self.assertEqual(0, len(args))
nonlocal called
called = True
called = False
channel_id = 'test-message-arg' channel_id = 'test-message-arg'
broadcast_data = 'broadcast-data' self.hub.subscribe(channel_id, cb)
hub = MessageHub() run_coro(self.hub.abroadcast(channel_id, '0', 1, 'two', [3], four=4))
hub.subscribe(channel_id, cb, 1, 'two', [3]) cb.assert_called_once_with('0', 1, 'two', [3], four=4)
run_coro(hub.abroadcast(channel_id, broadcast_data))
self.assertTrue(called)