diff --git a/console_conf/controllers/tests/test_chooser.py b/console_conf/controllers/tests/test_chooser.py
index f43fc789..601515f7 100644
--- a/console_conf/controllers/tests/test_chooser.py
+++ b/console_conf/controllers/tests/test_chooser.py
@@ -15,7 +15,7 @@
import unittest
from unittest import mock
-from subiquitycore.context import Context
+from subiquitycore.tests.mocks import make_app
from console_conf.controllers.chooser import (
RecoveryChooserController,
RecoveryChooserConfirmController,
@@ -26,29 +26,6 @@ from console_conf.models.systems import (
)
-class MockedApplication:
- signal = loop = None
- project = "mini"
- autoinstall_config = {}
- answers = {}
- opts = None
-
-
-def make_app(model=None):
- app = MockedApplication()
- app.ui = mock.Mock()
- if model is not None:
- app.base_model = model
- else:
- app.base_model = mock.Mock()
- app.context = Context.new(app)
- app.exit = mock.Mock()
- app.respond = mock.Mock()
- app.next_screen = mock.Mock()
- app.prev_screen = mock.Mock()
- return app
-
-
model1_non_current = {
"current": False,
"label": "1234",
diff --git a/subiquity/common/api/tests/test_endtoend.py b/subiquity/common/api/tests/test_endtoend.py
index eb3d7faa..9e844edc 100644
--- a/subiquity/common/api/tests/test_endtoend.py
+++ b/subiquity/common/api/tests/test_endtoend.py
@@ -21,10 +21,10 @@ import aiohttp
from aiohttp import web
from subiquitycore import contextlib38
+from subiquitycore.tests.util import run_coro
from subiquity.common.api.client import make_client
from subiquity.common.api.defs import api, Payload
-from subiquity.tests.util import run_coro
from .test_server import (
makeTestClient,
diff --git a/subiquity/common/api/tests/test_server.py b/subiquity/common/api/tests/test_server.py
index b6b5c327..983cf2b4 100644
--- a/subiquity/common/api/tests/test_server.py
+++ b/subiquity/common/api/tests/test_server.py
@@ -19,6 +19,7 @@ from aiohttp.test_utils import TestClient, TestServer
from aiohttp import web
from subiquitycore.context import Context
+from subiquitycore.tests.util import run_coro
from subiquitycore import contextlib38
from subiquity.common.api.defs import api, Payload
@@ -28,7 +29,6 @@ from subiquity.common.api.server import (
MissingImplementationError,
SignatureMisatchError,
)
-from subiquity.tests.util import run_coro
class TestApp:
diff --git a/subiquity/server/controllers/mirror.py b/subiquity/server/controllers/mirror.py
index 25283cfc..15e3610a 100644
--- a/subiquity/server/controllers/mirror.py
+++ b/subiquity/server/controllers/mirror.py
@@ -14,28 +14,18 @@
# along with this program. If not, see .
import asyncio
-import enum
import logging
from curtin.config import merge_config
-from subiquitycore.async_helpers import SingleInstanceTask
from subiquitycore.context import with_context
from subiquity.common.apidef import API
-from subiquity.common.geoip import GeoIP
from subiquity.server.controller import SubiquityController
log = logging.getLogger('subiquity.server.controllers.mirror')
-class CheckState(enum.IntEnum):
- NOT_STARTED = enum.auto()
- CHECKING = enum.auto()
- FAILED = enum.auto()
- DONE = enum.auto()
-
-
class MirrorController(SubiquityController):
endpoint = API.mirror
@@ -55,11 +45,8 @@ class MirrorController(SubiquityController):
def __init__(self, app):
super().__init__(app)
self.geoip_enabled = True
- self.check_state = CheckState.NOT_STARTED
- self.lookup_task = SingleInstanceTask(self.lookup)
- self.geoip = GeoIP()
- self.app.hub.subscribe('network-up', self.maybe_start_check)
- self.app.hub.subscribe('network-proxy-set', self.maybe_start_check)
+ self.app.geoip.on_countrycode.subscribe(self.on_countrycode)
+ self.cc_event = asyncio.Event()
def load_autoinstall_data(self, data):
if data is None:
@@ -72,30 +59,16 @@ class MirrorController(SubiquityController):
async def apply_autoinstall_config(self, context):
if not self.geoip_enabled:
return
- if self.lookup_task.task is None:
- return
try:
with context.child('waiting'):
- await asyncio.wait_for(self.lookup_task.wait(), 10)
+ await asyncio.wait_for(self.cc_event.wait(), 10)
except asyncio.TimeoutError:
pass
- def maybe_start_check(self):
- if not self.geoip_enabled:
- return
- if self.check_state != CheckState.DONE:
- self.check_state = CheckState.CHECKING
- self.lookup_task.start_sync()
-
- @with_context()
- async def lookup(self, context):
- if await self.geoip.lookup():
- cc = self.geoip.countrycode
- if cc:
- self.model.set_country(cc)
- self.check_state = CheckState.DONE
- return
- self.check_state = CheckState.FAILED
+ def on_countrycode(self, cc):
+ if self.geoip_enabled:
+ self.model.set_country(cc)
+ self.cc_event.set()
def serialize(self):
return self.model.get_mirror()
diff --git a/subiquity/common/geoip.py b/subiquity/server/geoip.py
similarity index 57%
rename from subiquity/common/geoip.py
rename to subiquity/server/geoip.py
index 011752d7..667f5799 100644
--- a/subiquity/common/geoip.py
+++ b/subiquity/server/geoip.py
@@ -14,19 +14,53 @@
# along with this program. If not, see .
import logging
+import enum
import requests
from xml.etree import ElementTree
-from subiquitycore.async_helpers import run_in_thread
+from subiquitycore.async_helpers import (
+ run_in_thread,
+ SingleInstanceTask,
+)
+from subiquitycore.pubsub import EventCallback
log = logging.getLogger('subiquity.common.geoip')
+class CheckState(enum.IntEnum):
+ NOT_STARTED = enum.auto()
+ CHECKING = enum.auto()
+ FAILED = enum.auto()
+ DONE = enum.auto()
+
+
class GeoIP:
- def __init__(self):
+ def __init__(self, app):
+ self.app = app
self.element = None
+ self.cc = None
+ self.tz = None
+ self.check_state = CheckState.NOT_STARTED
+ self.on_countrycode = EventCallback()
+ self.on_timezone = EventCallback()
+ self.lookup_task = SingleInstanceTask(self.lookup)
+ self.app.hub.subscribe('network-up', self.maybe_start_check)
+ self.app.hub.subscribe('network-proxy-set', self.maybe_start_check)
+
+ def maybe_start_check(self):
+ if self.check_state != CheckState.DONE:
+ self.check_state = CheckState.CHECKING
+ self.lookup_task.start_sync()
async def lookup(self):
+ rv = await self._lookup()
+ if rv:
+ self.check_state = CheckState.DONE
+ else:
+ self.check_state = CheckState.FAILED
+ return rv
+
+ async def _lookup(self):
try:
response = await run_in_thread(
requests.get, "https://geoip.ubuntu.com/lookup")
@@ -40,28 +74,33 @@ class GeoIP:
except ElementTree.ParseError:
log.exception("parsing %r failed", self.response_text)
return False
+
+ cc = self.element.find("CountryCode")
+ if cc is None or cc.text is None:
+ log.debug("no CountryCode found in %r", self.response_text)
+ return False
+ cc = cc.text.lower()
+ if len(cc) != 2:
+ log.debug("bogus CountryCode found in %r", self.response_text)
+ return False
+ if cc != self.cc:
+ self.on_countrycode.broadcast(cc)
+ self.cc = cc
+
+ tz = self.element.find("TimeZone")
+ if tz is None or not tz.text:
+ log.debug("no TimeZone found in %r", self.response_text)
+ return False
+ if tz != self.tz:
+ self.on_timezone.broadcast(tz)
+ self.tz = tz.text
+
return True
@property
def countrycode(self):
- if not self.element:
- return None
- cc = self.element.find("CountryCode")
- if cc is None or cc.text is None:
- log.debug("no CountryCode found in %r", self.response_text)
- return None
- cc = cc.text.lower()
- if len(cc) != 2:
- log.debug("bogus CountryCode found in %r", self.response_text)
- return None
- return cc
+ return self.cc
@property
def timezone(self):
- if not self.element:
- return None
- tz = self.element.find("TimeZone")
- if tz is None or not tz.text:
- log.debug("no TimeZone found in %r", self.response_text)
- return None
- return tz.text
+ return self.tz
diff --git a/subiquity/server/server.py b/subiquity/server/server.py
index 70ece7c5..50b8f843 100644
--- a/subiquity/server/server.py
+++ b/subiquity/server/server.py
@@ -61,8 +61,9 @@ from subiquity.common.types import (
LiveSessionSSHInfo,
PasswordKind,
)
-from subiquity.server.controller import SubiquityController
from subiquity.models.subiquity import SubiquityModel
+from subiquity.server.controller import SubiquityController
+from subiquity.server.geoip import GeoIP
from subiquity.server.errors import ErrorController
from subiquitycore.snapd import (
AsyncSnapd,
@@ -246,6 +247,7 @@ class SubiquityServer(Application):
self.autoinstall_config = None
self.hub.subscribe('network-up', self._network_change)
self.hub.subscribe('network-proxy-set', self._proxy_set)
+ self.geoip = GeoIP(self)
def load_serialized_state(self):
for controller in self.controllers.instances:
diff --git a/subiquity/server/tests/__init__.py b/subiquity/server/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/subiquity/common/tests/test_geoip.py b/subiquity/server/tests/test_geoip.py
similarity index 88%
rename from subiquity/common/tests/test_geoip.py
rename to subiquity/server/tests/test_geoip.py
index ca9d2217..8cf71291 100644
--- a/subiquity/common/tests/test_geoip.py
+++ b/subiquity/server/tests/test_geoip.py
@@ -16,9 +16,9 @@
import mock
from subiquitycore.tests import SubiTestCase
-from subiquity.common.geoip import GeoIP
-from subiquity.tests.util import run_coro
-
+from subiquitycore.tests.mocks import make_app
+from subiquitycore.tests.util import run_coro
+from subiquity.server.geoip import GeoIP
xml = '''
@@ -62,7 +62,7 @@ def requests_get_factory(text):
class TestGeoIP(SubiTestCase):
@mock.patch('requests.get', new=requests_get_factory(xml))
def setUp(self):
- self.geoip = GeoIP()
+ self.geoip = GeoIP(make_app())
async def fn():
self.assertTrue(await self.geoip.lookup())
@@ -77,7 +77,7 @@ class TestGeoIP(SubiTestCase):
class TestGeoIPBadData(SubiTestCase):
def setUp(self):
- self.geoip = GeoIP()
+ self.geoip = GeoIP(make_app())
@mock.patch('requests.get', new=requests_get_factory(partial))
def test_partial_reponse(self):
@@ -88,7 +88,7 @@ class TestGeoIPBadData(SubiTestCase):
@mock.patch('requests.get', new=requests_get_factory(incomplete))
def test_incomplete(self):
async def fn():
- self.assertTrue(await self.geoip.lookup())
+ self.assertFalse(await self.geoip.lookup())
run_coro(fn())
self.assertIsNone(self.geoip.countrycode)
self.assertIsNone(self.geoip.timezone)
@@ -96,20 +96,20 @@ class TestGeoIPBadData(SubiTestCase):
@mock.patch('requests.get', new=requests_get_factory(long_cc))
def test_long_cc(self):
async def fn():
- self.assertTrue(await self.geoip.lookup())
+ self.assertFalse(await self.geoip.lookup())
run_coro(fn())
self.assertIsNone(self.geoip.countrycode)
@mock.patch('requests.get', new=requests_get_factory(empty_cc))
def test_empty_cc(self):
async def fn():
- self.assertTrue(await self.geoip.lookup())
+ self.assertFalse(await self.geoip.lookup())
run_coro(fn())
self.assertIsNone(self.geoip.countrycode)
@mock.patch('requests.get', new=requests_get_factory(empty_tz))
def test_empty_tz(self):
async def fn():
- self.assertTrue(await self.geoip.lookup())
+ self.assertFalse(await self.geoip.lookup())
run_coro(fn())
self.assertIsNone(self.geoip.timezone)
diff --git a/subiquitycore/pubsub.py b/subiquitycore/pubsub.py
index de29bdb6..7b3f81cd 100644
--- a/subiquitycore/pubsub.py
+++ b/subiquitycore/pubsub.py
@@ -33,3 +33,21 @@ class MessageHub:
def broadcast(self, channel):
return asyncio.get_event_loop().create_task(self.abroadcast(channel))
+
+
+class EventCallback:
+
+ def __init__(self):
+ self.subscriptions = []
+
+ def subscribe(self, method, *args):
+ self.subscriptions.append((method, args))
+
+ async def abroadcast(self, cbdata):
+ for m, args in self.subscriptions:
+ v = m(cbdata, *args)
+ if inspect.iscoroutine(v):
+ await v
+
+ def broadcast(self, cbdata):
+ return asyncio.get_event_loop().create_task(self.abroadcast(cbdata))
diff --git a/subiquitycore/tests/mocks.py b/subiquitycore/tests/mocks.py
new file mode 100644
index 00000000..e27b0873
--- /dev/null
+++ b/subiquitycore/tests/mocks.py
@@ -0,0 +1,42 @@
+# Copyright 2020-2021 Canonical, Ltd.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+from unittest import mock
+
+from subiquitycore.context import Context
+from subiquitycore.pubsub import MessageHub
+
+
+class MockedApplication:
+ signal = loop = None
+ project = "mini"
+ autoinstall_config = {}
+ answers = {}
+ opts = None
+
+
+def make_app(model=None):
+ app = MockedApplication()
+ app.ui = mock.Mock()
+ if model is not None:
+ app.base_model = model
+ else:
+ app.base_model = mock.Mock()
+ app.context = Context.new(app)
+ app.exit = mock.Mock()
+ app.respond = mock.Mock()
+ app.next_screen = mock.Mock()
+ app.prev_screen = mock.Mock()
+ app.hub = MessageHub()
+ return app
diff --git a/subiquitycore/tests/test_pubsub.py b/subiquitycore/tests/test_pubsub.py
new file mode 100644
index 00000000..ff732a34
--- /dev/null
+++ b/subiquitycore/tests/test_pubsub.py
@@ -0,0 +1,53 @@
+# Copyright 2021 Canonical, Ltd.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import asyncio
+
+from subiquitycore.tests import SubiTestCase
+from subiquitycore.pubsub import EventCallback
+from subiquitycore.tests.util import run_coro
+
+
+async def wait_other_tasks():
+ if hasattr(asyncio, 'all_tasks'):
+ tasks = asyncio.all_tasks() # py 3.7+
+ tasks.remove(asyncio.current_task())
+ else:
+ tasks = asyncio.Task.all_tasks() # py 3.6
+ tasks.remove(asyncio.Task.current_task())
+ await asyncio.wait(tasks)
+
+
+class TestEventCallback(SubiTestCase):
+ def test_basic(self):
+ def job():
+ self.thething.broadcast(42)
+
+ def cb(val, mydata):
+ self.assertEqual(42, val)
+ self.assertEqual('bacon', mydata)
+ self.called += 1
+
+ async def fn():
+ self.called = 0
+ self.thething = EventCallback()
+ calls_expected = 3
+ for _ in range(calls_expected):
+ self.thething.subscribe(cb, 'bacon')
+ job()
+ await wait_other_tasks()
+ self.assertEqual(calls_expected, self.called)
+
+ run_coro(fn())
diff --git a/subiquity/tests/util.py b/subiquitycore/tests/util.py
similarity index 100%
rename from subiquity/tests/util.py
rename to subiquitycore/tests/util.py