pubsub: remove subscribe args, more broadcast args

Remove the subscribe time args, as they weren't used.
With that flexibility we can add many args at broadcast time.
This commit is contained in:
Dan Bungert 2021-09-24 13:06:41 -06:00
parent 350ce11dd9
commit af8dcfe6c3
2 changed files with 29 additions and 35 deletions

View File

@ -26,17 +26,15 @@ class MessageHub:
def __init__(self):
self.subscriptions = {}
def subscribe(self, channel, method, *args):
self.subscriptions.setdefault(channel, []).append((method, args))
def subscribe(self, channel, method):
self.subscriptions.setdefault(channel, []).append(method)
async def abroadcast(self, channel, data=None):
for m, args in self.subscriptions.get(channel, []):
if data:
args = [data] + list(args)
v = m(*args)
async def abroadcast(self, channel, *args, **kwargs):
for m in self.subscriptions.get(channel, []):
v = m(*args, **kwargs)
if inspect.iscoroutine(v):
await v
def broadcast(self, channel, data=None):
def broadcast(self, channel, *args, **kwargs):
loop = asyncio.get_event_loop()
return loop.create_task(self.abroadcast(channel, data))
return loop.create_task(self.abroadcast(channel, *args, **kwargs))

View File

@ -13,42 +13,38 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from unittest.mock import MagicMock
from subiquitycore.tests import SubiTestCase
from subiquitycore.pubsub import MessageHub
from subiquitycore.tests.util import run_coro
class TestMessageHub(SubiTestCase):
def test_basic(self):
def cb(actual_private):
self.assertEqual(private_data, actual_private)
nonlocal actual_calls
actual_calls += 1
def setUp(self):
self.hub = MessageHub()
actual_calls = 0
def test_multicall(self):
cb = MagicMock()
expected_calls = 3
channel_id = 1234
private_data = 42
hub = MessageHub()
for _ in range(expected_calls):
hub.subscribe(channel_id, cb, private_data)
run_coro(hub.abroadcast(channel_id))
self.assertEqual(expected_calls, actual_calls)
self.hub.subscribe(channel_id, cb)
run_coro(self.hub.abroadcast(channel_id))
self.assertEqual(expected_calls, cb.call_count)
def test_multisubscriber(self):
cbs = [MagicMock() for _ in range(4)]
channel_id = 2345
for cb in cbs:
self.hub.subscribe(channel_id, cb)
run_coro(self.hub.abroadcast(channel_id))
for cb in cbs:
cb.assert_called_once_with()
def test_message_arg(self):
def cb(zero, one, two, three, *args):
self.assertEqual(broadcast_data, zero)
self.assertEqual(1, one)
self.assertEqual('two', two)
self.assertEqual([3], three)
self.assertEqual(0, len(args))
nonlocal called
called = True
called = False
cb = MagicMock()
channel_id = 'test-message-arg'
broadcast_data = 'broadcast-data'
hub = MessageHub()
hub.subscribe(channel_id, cb, 1, 'two', [3])
run_coro(hub.abroadcast(channel_id, broadcast_data))
self.assertTrue(called)
self.hub.subscribe(channel_id, cb)
run_coro(self.hub.abroadcast(channel_id, '0', 1, 'two', [3], four=4))
cb.assert_called_once_with('0', 1, 'two', [3], four=4)