diff --git a/console_conf/controllers/tests/test_chooser.py b/console_conf/controllers/tests/test_chooser.py index f43fc789..601515f7 100644 --- a/console_conf/controllers/tests/test_chooser.py +++ b/console_conf/controllers/tests/test_chooser.py @@ -15,7 +15,7 @@ import unittest from unittest import mock -from subiquitycore.context import Context +from subiquitycore.tests.mocks import make_app from console_conf.controllers.chooser import ( RecoveryChooserController, RecoveryChooserConfirmController, @@ -26,29 +26,6 @@ from console_conf.models.systems import ( ) -class MockedApplication: - signal = loop = None - project = "mini" - autoinstall_config = {} - answers = {} - opts = None - - -def make_app(model=None): - app = MockedApplication() - app.ui = mock.Mock() - if model is not None: - app.base_model = model - else: - app.base_model = mock.Mock() - app.context = Context.new(app) - app.exit = mock.Mock() - app.respond = mock.Mock() - app.next_screen = mock.Mock() - app.prev_screen = mock.Mock() - return app - - model1_non_current = { "current": False, "label": "1234", diff --git a/subiquity/common/api/tests/test_endtoend.py b/subiquity/common/api/tests/test_endtoend.py index eb3d7faa..9e844edc 100644 --- a/subiquity/common/api/tests/test_endtoend.py +++ b/subiquity/common/api/tests/test_endtoend.py @@ -21,10 +21,10 @@ import aiohttp from aiohttp import web from subiquitycore import contextlib38 +from subiquitycore.tests.util import run_coro from subiquity.common.api.client import make_client from subiquity.common.api.defs import api, Payload -from subiquity.tests.util import run_coro from .test_server import ( makeTestClient, diff --git a/subiquity/common/api/tests/test_server.py b/subiquity/common/api/tests/test_server.py index b6b5c327..983cf2b4 100644 --- a/subiquity/common/api/tests/test_server.py +++ b/subiquity/common/api/tests/test_server.py @@ -19,6 +19,7 @@ from aiohttp.test_utils import TestClient, TestServer from aiohttp import web from subiquitycore.context import Context +from subiquitycore.tests.util import run_coro from subiquitycore import contextlib38 from subiquity.common.api.defs import api, Payload @@ -28,7 +29,6 @@ from subiquity.common.api.server import ( MissingImplementationError, SignatureMisatchError, ) -from subiquity.tests.util import run_coro class TestApp: diff --git a/subiquity/server/controllers/mirror.py b/subiquity/server/controllers/mirror.py index 25283cfc..15e3610a 100644 --- a/subiquity/server/controllers/mirror.py +++ b/subiquity/server/controllers/mirror.py @@ -14,28 +14,18 @@ # along with this program. If not, see . import asyncio -import enum import logging from curtin.config import merge_config -from subiquitycore.async_helpers import SingleInstanceTask from subiquitycore.context import with_context from subiquity.common.apidef import API -from subiquity.common.geoip import GeoIP from subiquity.server.controller import SubiquityController log = logging.getLogger('subiquity.server.controllers.mirror') -class CheckState(enum.IntEnum): - NOT_STARTED = enum.auto() - CHECKING = enum.auto() - FAILED = enum.auto() - DONE = enum.auto() - - class MirrorController(SubiquityController): endpoint = API.mirror @@ -55,11 +45,8 @@ class MirrorController(SubiquityController): def __init__(self, app): super().__init__(app) self.geoip_enabled = True - self.check_state = CheckState.NOT_STARTED - self.lookup_task = SingleInstanceTask(self.lookup) - self.geoip = GeoIP() - self.app.hub.subscribe('network-up', self.maybe_start_check) - self.app.hub.subscribe('network-proxy-set', self.maybe_start_check) + self.app.geoip.on_countrycode.subscribe(self.on_countrycode) + self.cc_event = asyncio.Event() def load_autoinstall_data(self, data): if data is None: @@ -72,30 +59,16 @@ class MirrorController(SubiquityController): async def apply_autoinstall_config(self, context): if not self.geoip_enabled: return - if self.lookup_task.task is None: - return try: with context.child('waiting'): - await asyncio.wait_for(self.lookup_task.wait(), 10) + await asyncio.wait_for(self.cc_event.wait(), 10) except asyncio.TimeoutError: pass - def maybe_start_check(self): - if not self.geoip_enabled: - return - if self.check_state != CheckState.DONE: - self.check_state = CheckState.CHECKING - self.lookup_task.start_sync() - - @with_context() - async def lookup(self, context): - if await self.geoip.lookup(): - cc = self.geoip.countrycode - if cc: - self.model.set_country(cc) - self.check_state = CheckState.DONE - return - self.check_state = CheckState.FAILED + def on_countrycode(self, cc): + if self.geoip_enabled: + self.model.set_country(cc) + self.cc_event.set() def serialize(self): return self.model.get_mirror() diff --git a/subiquity/common/geoip.py b/subiquity/server/geoip.py similarity index 57% rename from subiquity/common/geoip.py rename to subiquity/server/geoip.py index 011752d7..667f5799 100644 --- a/subiquity/common/geoip.py +++ b/subiquity/server/geoip.py @@ -14,19 +14,53 @@ # along with this program. If not, see . import logging +import enum import requests from xml.etree import ElementTree -from subiquitycore.async_helpers import run_in_thread +from subiquitycore.async_helpers import ( + run_in_thread, + SingleInstanceTask, +) +from subiquitycore.pubsub import EventCallback log = logging.getLogger('subiquity.common.geoip') +class CheckState(enum.IntEnum): + NOT_STARTED = enum.auto() + CHECKING = enum.auto() + FAILED = enum.auto() + DONE = enum.auto() + + class GeoIP: - def __init__(self): + def __init__(self, app): + self.app = app self.element = None + self.cc = None + self.tz = None + self.check_state = CheckState.NOT_STARTED + self.on_countrycode = EventCallback() + self.on_timezone = EventCallback() + self.lookup_task = SingleInstanceTask(self.lookup) + self.app.hub.subscribe('network-up', self.maybe_start_check) + self.app.hub.subscribe('network-proxy-set', self.maybe_start_check) + + def maybe_start_check(self): + if self.check_state != CheckState.DONE: + self.check_state = CheckState.CHECKING + self.lookup_task.start_sync() async def lookup(self): + rv = await self._lookup() + if rv: + self.check_state = CheckState.DONE + else: + self.check_state = CheckState.FAILED + return rv + + async def _lookup(self): try: response = await run_in_thread( requests.get, "https://geoip.ubuntu.com/lookup") @@ -40,28 +74,33 @@ class GeoIP: except ElementTree.ParseError: log.exception("parsing %r failed", self.response_text) return False + + cc = self.element.find("CountryCode") + if cc is None or cc.text is None: + log.debug("no CountryCode found in %r", self.response_text) + return False + cc = cc.text.lower() + if len(cc) != 2: + log.debug("bogus CountryCode found in %r", self.response_text) + return False + if cc != self.cc: + self.on_countrycode.broadcast(cc) + self.cc = cc + + tz = self.element.find("TimeZone") + if tz is None or not tz.text: + log.debug("no TimeZone found in %r", self.response_text) + return False + if tz != self.tz: + self.on_timezone.broadcast(tz) + self.tz = tz.text + return True @property def countrycode(self): - if not self.element: - return None - cc = self.element.find("CountryCode") - if cc is None or cc.text is None: - log.debug("no CountryCode found in %r", self.response_text) - return None - cc = cc.text.lower() - if len(cc) != 2: - log.debug("bogus CountryCode found in %r", self.response_text) - return None - return cc + return self.cc @property def timezone(self): - if not self.element: - return None - tz = self.element.find("TimeZone") - if tz is None or not tz.text: - log.debug("no TimeZone found in %r", self.response_text) - return None - return tz.text + return self.tz diff --git a/subiquity/server/server.py b/subiquity/server/server.py index 70ece7c5..50b8f843 100644 --- a/subiquity/server/server.py +++ b/subiquity/server/server.py @@ -61,8 +61,9 @@ from subiquity.common.types import ( LiveSessionSSHInfo, PasswordKind, ) -from subiquity.server.controller import SubiquityController from subiquity.models.subiquity import SubiquityModel +from subiquity.server.controller import SubiquityController +from subiquity.server.geoip import GeoIP from subiquity.server.errors import ErrorController from subiquitycore.snapd import ( AsyncSnapd, @@ -246,6 +247,7 @@ class SubiquityServer(Application): self.autoinstall_config = None self.hub.subscribe('network-up', self._network_change) self.hub.subscribe('network-proxy-set', self._proxy_set) + self.geoip = GeoIP(self) def load_serialized_state(self): for controller in self.controllers.instances: diff --git a/subiquity/server/tests/__init__.py b/subiquity/server/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/subiquity/common/tests/test_geoip.py b/subiquity/server/tests/test_geoip.py similarity index 88% rename from subiquity/common/tests/test_geoip.py rename to subiquity/server/tests/test_geoip.py index ca9d2217..8cf71291 100644 --- a/subiquity/common/tests/test_geoip.py +++ b/subiquity/server/tests/test_geoip.py @@ -16,9 +16,9 @@ import mock from subiquitycore.tests import SubiTestCase -from subiquity.common.geoip import GeoIP -from subiquity.tests.util import run_coro - +from subiquitycore.tests.mocks import make_app +from subiquitycore.tests.util import run_coro +from subiquity.server.geoip import GeoIP xml = ''' @@ -62,7 +62,7 @@ def requests_get_factory(text): class TestGeoIP(SubiTestCase): @mock.patch('requests.get', new=requests_get_factory(xml)) def setUp(self): - self.geoip = GeoIP() + self.geoip = GeoIP(make_app()) async def fn(): self.assertTrue(await self.geoip.lookup()) @@ -77,7 +77,7 @@ class TestGeoIP(SubiTestCase): class TestGeoIPBadData(SubiTestCase): def setUp(self): - self.geoip = GeoIP() + self.geoip = GeoIP(make_app()) @mock.patch('requests.get', new=requests_get_factory(partial)) def test_partial_reponse(self): @@ -88,7 +88,7 @@ class TestGeoIPBadData(SubiTestCase): @mock.patch('requests.get', new=requests_get_factory(incomplete)) def test_incomplete(self): async def fn(): - self.assertTrue(await self.geoip.lookup()) + self.assertFalse(await self.geoip.lookup()) run_coro(fn()) self.assertIsNone(self.geoip.countrycode) self.assertIsNone(self.geoip.timezone) @@ -96,20 +96,20 @@ class TestGeoIPBadData(SubiTestCase): @mock.patch('requests.get', new=requests_get_factory(long_cc)) def test_long_cc(self): async def fn(): - self.assertTrue(await self.geoip.lookup()) + self.assertFalse(await self.geoip.lookup()) run_coro(fn()) self.assertIsNone(self.geoip.countrycode) @mock.patch('requests.get', new=requests_get_factory(empty_cc)) def test_empty_cc(self): async def fn(): - self.assertTrue(await self.geoip.lookup()) + self.assertFalse(await self.geoip.lookup()) run_coro(fn()) self.assertIsNone(self.geoip.countrycode) @mock.patch('requests.get', new=requests_get_factory(empty_tz)) def test_empty_tz(self): async def fn(): - self.assertTrue(await self.geoip.lookup()) + self.assertFalse(await self.geoip.lookup()) run_coro(fn()) self.assertIsNone(self.geoip.timezone) diff --git a/subiquitycore/pubsub.py b/subiquitycore/pubsub.py index de29bdb6..7b3f81cd 100644 --- a/subiquitycore/pubsub.py +++ b/subiquitycore/pubsub.py @@ -33,3 +33,21 @@ class MessageHub: def broadcast(self, channel): return asyncio.get_event_loop().create_task(self.abroadcast(channel)) + + +class EventCallback: + + def __init__(self): + self.subscriptions = [] + + def subscribe(self, method, *args): + self.subscriptions.append((method, args)) + + async def abroadcast(self, cbdata): + for m, args in self.subscriptions: + v = m(cbdata, *args) + if inspect.iscoroutine(v): + await v + + def broadcast(self, cbdata): + return asyncio.get_event_loop().create_task(self.abroadcast(cbdata)) diff --git a/subiquitycore/tests/mocks.py b/subiquitycore/tests/mocks.py new file mode 100644 index 00000000..e27b0873 --- /dev/null +++ b/subiquitycore/tests/mocks.py @@ -0,0 +1,42 @@ +# Copyright 2020-2021 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from unittest import mock + +from subiquitycore.context import Context +from subiquitycore.pubsub import MessageHub + + +class MockedApplication: + signal = loop = None + project = "mini" + autoinstall_config = {} + answers = {} + opts = None + + +def make_app(model=None): + app = MockedApplication() + app.ui = mock.Mock() + if model is not None: + app.base_model = model + else: + app.base_model = mock.Mock() + app.context = Context.new(app) + app.exit = mock.Mock() + app.respond = mock.Mock() + app.next_screen = mock.Mock() + app.prev_screen = mock.Mock() + app.hub = MessageHub() + return app diff --git a/subiquitycore/tests/test_pubsub.py b/subiquitycore/tests/test_pubsub.py new file mode 100644 index 00000000..ff732a34 --- /dev/null +++ b/subiquitycore/tests/test_pubsub.py @@ -0,0 +1,53 @@ +# Copyright 2021 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import asyncio + +from subiquitycore.tests import SubiTestCase +from subiquitycore.pubsub import EventCallback +from subiquitycore.tests.util import run_coro + + +async def wait_other_tasks(): + if hasattr(asyncio, 'all_tasks'): + tasks = asyncio.all_tasks() # py 3.7+ + tasks.remove(asyncio.current_task()) + else: + tasks = asyncio.Task.all_tasks() # py 3.6 + tasks.remove(asyncio.Task.current_task()) + await asyncio.wait(tasks) + + +class TestEventCallback(SubiTestCase): + def test_basic(self): + def job(): + self.thething.broadcast(42) + + def cb(val, mydata): + self.assertEqual(42, val) + self.assertEqual('bacon', mydata) + self.called += 1 + + async def fn(): + self.called = 0 + self.thething = EventCallback() + calls_expected = 3 + for _ in range(calls_expected): + self.thething.subscribe(cb, 'bacon') + job() + await wait_other_tasks() + self.assertEqual(calls_expected, self.called) + + run_coro(fn()) diff --git a/subiquity/tests/util.py b/subiquitycore/tests/util.py similarity index 100% rename from subiquity/tests/util.py rename to subiquitycore/tests/util.py