From 6a189dd59872ac62978e3966a5fea530f3d566a8 Mon Sep 17 00:00:00 2001 From: Dan Bungert Date: Mon, 14 Jun 2021 16:05:27 -0600 Subject: [PATCH] Add EventCallback, and move mirror task things to GeoIP (#983) Move mock_app to common location. Move run_coro to subiquitycore so that subiquitycore doesn't have to reference things in subiquity, even for test. Move task tracking things from mirror to geoip. Server app owns the geoip instance. Create EventCallback as an alternative to MessageHub that should hopefully express clearer intermodule dependencies. --- .../controllers/tests/test_chooser.py | 25 +----- subiquity/common/api/tests/test_endtoend.py | 2 +- subiquity/common/api/tests/test_server.py | 2 +- subiquity/server/controllers/mirror.py | 41 ++-------- subiquity/{common => server}/geoip.py | 79 ++++++++++++++----- subiquity/server/server.py | 4 +- subiquity/server/tests/__init__.py | 0 .../{common => server}/tests/test_geoip.py | 18 ++--- subiquitycore/pubsub.py | 18 +++++ subiquitycore/tests/mocks.py | 42 ++++++++++ subiquitycore/tests/test_pubsub.py | 53 +++++++++++++ {subiquity => subiquitycore}/tests/util.py | 0 12 files changed, 194 insertions(+), 90 deletions(-) rename subiquity/{common => server}/geoip.py (57%) create mode 100644 subiquity/server/tests/__init__.py rename subiquity/{common => server}/tests/test_geoip.py (88%) create mode 100644 subiquitycore/tests/mocks.py create mode 100644 subiquitycore/tests/test_pubsub.py rename {subiquity => subiquitycore}/tests/util.py (100%) diff --git a/console_conf/controllers/tests/test_chooser.py b/console_conf/controllers/tests/test_chooser.py index f43fc789..601515f7 100644 --- a/console_conf/controllers/tests/test_chooser.py +++ b/console_conf/controllers/tests/test_chooser.py @@ -15,7 +15,7 @@ import unittest from unittest import mock -from subiquitycore.context import Context +from subiquitycore.tests.mocks import make_app from console_conf.controllers.chooser import ( RecoveryChooserController, RecoveryChooserConfirmController, @@ -26,29 +26,6 @@ from console_conf.models.systems import ( ) -class MockedApplication: - signal = loop = None - project = "mini" - autoinstall_config = {} - answers = {} - opts = None - - -def make_app(model=None): - app = MockedApplication() - app.ui = mock.Mock() - if model is not None: - app.base_model = model - else: - app.base_model = mock.Mock() - app.context = Context.new(app) - app.exit = mock.Mock() - app.respond = mock.Mock() - app.next_screen = mock.Mock() - app.prev_screen = mock.Mock() - return app - - model1_non_current = { "current": False, "label": "1234", diff --git a/subiquity/common/api/tests/test_endtoend.py b/subiquity/common/api/tests/test_endtoend.py index eb3d7faa..9e844edc 100644 --- a/subiquity/common/api/tests/test_endtoend.py +++ b/subiquity/common/api/tests/test_endtoend.py @@ -21,10 +21,10 @@ import aiohttp from aiohttp import web from subiquitycore import contextlib38 +from subiquitycore.tests.util import run_coro from subiquity.common.api.client import make_client from subiquity.common.api.defs import api, Payload -from subiquity.tests.util import run_coro from .test_server import ( makeTestClient, diff --git a/subiquity/common/api/tests/test_server.py b/subiquity/common/api/tests/test_server.py index b6b5c327..983cf2b4 100644 --- a/subiquity/common/api/tests/test_server.py +++ b/subiquity/common/api/tests/test_server.py @@ -19,6 +19,7 @@ from aiohttp.test_utils import TestClient, TestServer from aiohttp import web from subiquitycore.context import Context +from subiquitycore.tests.util import run_coro from subiquitycore import contextlib38 from subiquity.common.api.defs import api, Payload @@ -28,7 +29,6 @@ from subiquity.common.api.server import ( MissingImplementationError, SignatureMisatchError, ) -from subiquity.tests.util import run_coro class TestApp: diff --git a/subiquity/server/controllers/mirror.py b/subiquity/server/controllers/mirror.py index 25283cfc..15e3610a 100644 --- a/subiquity/server/controllers/mirror.py +++ b/subiquity/server/controllers/mirror.py @@ -14,28 +14,18 @@ # along with this program. If not, see . import asyncio -import enum import logging from curtin.config import merge_config -from subiquitycore.async_helpers import SingleInstanceTask from subiquitycore.context import with_context from subiquity.common.apidef import API -from subiquity.common.geoip import GeoIP from subiquity.server.controller import SubiquityController log = logging.getLogger('subiquity.server.controllers.mirror') -class CheckState(enum.IntEnum): - NOT_STARTED = enum.auto() - CHECKING = enum.auto() - FAILED = enum.auto() - DONE = enum.auto() - - class MirrorController(SubiquityController): endpoint = API.mirror @@ -55,11 +45,8 @@ class MirrorController(SubiquityController): def __init__(self, app): super().__init__(app) self.geoip_enabled = True - self.check_state = CheckState.NOT_STARTED - self.lookup_task = SingleInstanceTask(self.lookup) - self.geoip = GeoIP() - self.app.hub.subscribe('network-up', self.maybe_start_check) - self.app.hub.subscribe('network-proxy-set', self.maybe_start_check) + self.app.geoip.on_countrycode.subscribe(self.on_countrycode) + self.cc_event = asyncio.Event() def load_autoinstall_data(self, data): if data is None: @@ -72,30 +59,16 @@ class MirrorController(SubiquityController): async def apply_autoinstall_config(self, context): if not self.geoip_enabled: return - if self.lookup_task.task is None: - return try: with context.child('waiting'): - await asyncio.wait_for(self.lookup_task.wait(), 10) + await asyncio.wait_for(self.cc_event.wait(), 10) except asyncio.TimeoutError: pass - def maybe_start_check(self): - if not self.geoip_enabled: - return - if self.check_state != CheckState.DONE: - self.check_state = CheckState.CHECKING - self.lookup_task.start_sync() - - @with_context() - async def lookup(self, context): - if await self.geoip.lookup(): - cc = self.geoip.countrycode - if cc: - self.model.set_country(cc) - self.check_state = CheckState.DONE - return - self.check_state = CheckState.FAILED + def on_countrycode(self, cc): + if self.geoip_enabled: + self.model.set_country(cc) + self.cc_event.set() def serialize(self): return self.model.get_mirror() diff --git a/subiquity/common/geoip.py b/subiquity/server/geoip.py similarity index 57% rename from subiquity/common/geoip.py rename to subiquity/server/geoip.py index 011752d7..667f5799 100644 --- a/subiquity/common/geoip.py +++ b/subiquity/server/geoip.py @@ -14,19 +14,53 @@ # along with this program. If not, see . import logging +import enum import requests from xml.etree import ElementTree -from subiquitycore.async_helpers import run_in_thread +from subiquitycore.async_helpers import ( + run_in_thread, + SingleInstanceTask, +) +from subiquitycore.pubsub import EventCallback log = logging.getLogger('subiquity.common.geoip') +class CheckState(enum.IntEnum): + NOT_STARTED = enum.auto() + CHECKING = enum.auto() + FAILED = enum.auto() + DONE = enum.auto() + + class GeoIP: - def __init__(self): + def __init__(self, app): + self.app = app self.element = None + self.cc = None + self.tz = None + self.check_state = CheckState.NOT_STARTED + self.on_countrycode = EventCallback() + self.on_timezone = EventCallback() + self.lookup_task = SingleInstanceTask(self.lookup) + self.app.hub.subscribe('network-up', self.maybe_start_check) + self.app.hub.subscribe('network-proxy-set', self.maybe_start_check) + + def maybe_start_check(self): + if self.check_state != CheckState.DONE: + self.check_state = CheckState.CHECKING + self.lookup_task.start_sync() async def lookup(self): + rv = await self._lookup() + if rv: + self.check_state = CheckState.DONE + else: + self.check_state = CheckState.FAILED + return rv + + async def _lookup(self): try: response = await run_in_thread( requests.get, "https://geoip.ubuntu.com/lookup") @@ -40,28 +74,33 @@ class GeoIP: except ElementTree.ParseError: log.exception("parsing %r failed", self.response_text) return False + + cc = self.element.find("CountryCode") + if cc is None or cc.text is None: + log.debug("no CountryCode found in %r", self.response_text) + return False + cc = cc.text.lower() + if len(cc) != 2: + log.debug("bogus CountryCode found in %r", self.response_text) + return False + if cc != self.cc: + self.on_countrycode.broadcast(cc) + self.cc = cc + + tz = self.element.find("TimeZone") + if tz is None or not tz.text: + log.debug("no TimeZone found in %r", self.response_text) + return False + if tz != self.tz: + self.on_timezone.broadcast(tz) + self.tz = tz.text + return True @property def countrycode(self): - if not self.element: - return None - cc = self.element.find("CountryCode") - if cc is None or cc.text is None: - log.debug("no CountryCode found in %r", self.response_text) - return None - cc = cc.text.lower() - if len(cc) != 2: - log.debug("bogus CountryCode found in %r", self.response_text) - return None - return cc + return self.cc @property def timezone(self): - if not self.element: - return None - tz = self.element.find("TimeZone") - if tz is None or not tz.text: - log.debug("no TimeZone found in %r", self.response_text) - return None - return tz.text + return self.tz diff --git a/subiquity/server/server.py b/subiquity/server/server.py index 70ece7c5..50b8f843 100644 --- a/subiquity/server/server.py +++ b/subiquity/server/server.py @@ -61,8 +61,9 @@ from subiquity.common.types import ( LiveSessionSSHInfo, PasswordKind, ) -from subiquity.server.controller import SubiquityController from subiquity.models.subiquity import SubiquityModel +from subiquity.server.controller import SubiquityController +from subiquity.server.geoip import GeoIP from subiquity.server.errors import ErrorController from subiquitycore.snapd import ( AsyncSnapd, @@ -246,6 +247,7 @@ class SubiquityServer(Application): self.autoinstall_config = None self.hub.subscribe('network-up', self._network_change) self.hub.subscribe('network-proxy-set', self._proxy_set) + self.geoip = GeoIP(self) def load_serialized_state(self): for controller in self.controllers.instances: diff --git a/subiquity/server/tests/__init__.py b/subiquity/server/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/subiquity/common/tests/test_geoip.py b/subiquity/server/tests/test_geoip.py similarity index 88% rename from subiquity/common/tests/test_geoip.py rename to subiquity/server/tests/test_geoip.py index ca9d2217..8cf71291 100644 --- a/subiquity/common/tests/test_geoip.py +++ b/subiquity/server/tests/test_geoip.py @@ -16,9 +16,9 @@ import mock from subiquitycore.tests import SubiTestCase -from subiquity.common.geoip import GeoIP -from subiquity.tests.util import run_coro - +from subiquitycore.tests.mocks import make_app +from subiquitycore.tests.util import run_coro +from subiquity.server.geoip import GeoIP xml = ''' @@ -62,7 +62,7 @@ def requests_get_factory(text): class TestGeoIP(SubiTestCase): @mock.patch('requests.get', new=requests_get_factory(xml)) def setUp(self): - self.geoip = GeoIP() + self.geoip = GeoIP(make_app()) async def fn(): self.assertTrue(await self.geoip.lookup()) @@ -77,7 +77,7 @@ class TestGeoIP(SubiTestCase): class TestGeoIPBadData(SubiTestCase): def setUp(self): - self.geoip = GeoIP() + self.geoip = GeoIP(make_app()) @mock.patch('requests.get', new=requests_get_factory(partial)) def test_partial_reponse(self): @@ -88,7 +88,7 @@ class TestGeoIPBadData(SubiTestCase): @mock.patch('requests.get', new=requests_get_factory(incomplete)) def test_incomplete(self): async def fn(): - self.assertTrue(await self.geoip.lookup()) + self.assertFalse(await self.geoip.lookup()) run_coro(fn()) self.assertIsNone(self.geoip.countrycode) self.assertIsNone(self.geoip.timezone) @@ -96,20 +96,20 @@ class TestGeoIPBadData(SubiTestCase): @mock.patch('requests.get', new=requests_get_factory(long_cc)) def test_long_cc(self): async def fn(): - self.assertTrue(await self.geoip.lookup()) + self.assertFalse(await self.geoip.lookup()) run_coro(fn()) self.assertIsNone(self.geoip.countrycode) @mock.patch('requests.get', new=requests_get_factory(empty_cc)) def test_empty_cc(self): async def fn(): - self.assertTrue(await self.geoip.lookup()) + self.assertFalse(await self.geoip.lookup()) run_coro(fn()) self.assertIsNone(self.geoip.countrycode) @mock.patch('requests.get', new=requests_get_factory(empty_tz)) def test_empty_tz(self): async def fn(): - self.assertTrue(await self.geoip.lookup()) + self.assertFalse(await self.geoip.lookup()) run_coro(fn()) self.assertIsNone(self.geoip.timezone) diff --git a/subiquitycore/pubsub.py b/subiquitycore/pubsub.py index de29bdb6..7b3f81cd 100644 --- a/subiquitycore/pubsub.py +++ b/subiquitycore/pubsub.py @@ -33,3 +33,21 @@ class MessageHub: def broadcast(self, channel): return asyncio.get_event_loop().create_task(self.abroadcast(channel)) + + +class EventCallback: + + def __init__(self): + self.subscriptions = [] + + def subscribe(self, method, *args): + self.subscriptions.append((method, args)) + + async def abroadcast(self, cbdata): + for m, args in self.subscriptions: + v = m(cbdata, *args) + if inspect.iscoroutine(v): + await v + + def broadcast(self, cbdata): + return asyncio.get_event_loop().create_task(self.abroadcast(cbdata)) diff --git a/subiquitycore/tests/mocks.py b/subiquitycore/tests/mocks.py new file mode 100644 index 00000000..e27b0873 --- /dev/null +++ b/subiquitycore/tests/mocks.py @@ -0,0 +1,42 @@ +# Copyright 2020-2021 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from unittest import mock + +from subiquitycore.context import Context +from subiquitycore.pubsub import MessageHub + + +class MockedApplication: + signal = loop = None + project = "mini" + autoinstall_config = {} + answers = {} + opts = None + + +def make_app(model=None): + app = MockedApplication() + app.ui = mock.Mock() + if model is not None: + app.base_model = model + else: + app.base_model = mock.Mock() + app.context = Context.new(app) + app.exit = mock.Mock() + app.respond = mock.Mock() + app.next_screen = mock.Mock() + app.prev_screen = mock.Mock() + app.hub = MessageHub() + return app diff --git a/subiquitycore/tests/test_pubsub.py b/subiquitycore/tests/test_pubsub.py new file mode 100644 index 00000000..ff732a34 --- /dev/null +++ b/subiquitycore/tests/test_pubsub.py @@ -0,0 +1,53 @@ +# Copyright 2021 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import asyncio + +from subiquitycore.tests import SubiTestCase +from subiquitycore.pubsub import EventCallback +from subiquitycore.tests.util import run_coro + + +async def wait_other_tasks(): + if hasattr(asyncio, 'all_tasks'): + tasks = asyncio.all_tasks() # py 3.7+ + tasks.remove(asyncio.current_task()) + else: + tasks = asyncio.Task.all_tasks() # py 3.6 + tasks.remove(asyncio.Task.current_task()) + await asyncio.wait(tasks) + + +class TestEventCallback(SubiTestCase): + def test_basic(self): + def job(): + self.thething.broadcast(42) + + def cb(val, mydata): + self.assertEqual(42, val) + self.assertEqual('bacon', mydata) + self.called += 1 + + async def fn(): + self.called = 0 + self.thething = EventCallback() + calls_expected = 3 + for _ in range(calls_expected): + self.thething.subscribe(cb, 'bacon') + job() + await wait_other_tasks() + self.assertEqual(calls_expected, self.called) + + run_coro(fn()) diff --git a/subiquity/tests/util.py b/subiquitycore/tests/util.py similarity index 100% rename from subiquity/tests/util.py rename to subiquitycore/tests/util.py