Add EventCallback, and move mirror task things to GeoIP (#983)
Move mock_app to common location. Move run_coro to subiquitycore so that subiquitycore doesn't have to reference things in subiquity, even for test. Move task tracking things from mirror to geoip. Server app owns the geoip instance. Create EventCallback as an alternative to MessageHub that should hopefully express clearer intermodule dependencies.
This commit is contained in:
parent
92bc06b5c0
commit
6a189dd598
|
@ -15,7 +15,7 @@
|
|||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
from subiquitycore.context import Context
|
||||
from subiquitycore.tests.mocks import make_app
|
||||
from console_conf.controllers.chooser import (
|
||||
RecoveryChooserController,
|
||||
RecoveryChooserConfirmController,
|
||||
|
@ -26,29 +26,6 @@ from console_conf.models.systems import (
|
|||
)
|
||||
|
||||
|
||||
class MockedApplication:
|
||||
signal = loop = None
|
||||
project = "mini"
|
||||
autoinstall_config = {}
|
||||
answers = {}
|
||||
opts = None
|
||||
|
||||
|
||||
def make_app(model=None):
|
||||
app = MockedApplication()
|
||||
app.ui = mock.Mock()
|
||||
if model is not None:
|
||||
app.base_model = model
|
||||
else:
|
||||
app.base_model = mock.Mock()
|
||||
app.context = Context.new(app)
|
||||
app.exit = mock.Mock()
|
||||
app.respond = mock.Mock()
|
||||
app.next_screen = mock.Mock()
|
||||
app.prev_screen = mock.Mock()
|
||||
return app
|
||||
|
||||
|
||||
model1_non_current = {
|
||||
"current": False,
|
||||
"label": "1234",
|
||||
|
|
|
@ -21,10 +21,10 @@ import aiohttp
|
|||
from aiohttp import web
|
||||
|
||||
from subiquitycore import contextlib38
|
||||
from subiquitycore.tests.util import run_coro
|
||||
|
||||
from subiquity.common.api.client import make_client
|
||||
from subiquity.common.api.defs import api, Payload
|
||||
from subiquity.tests.util import run_coro
|
||||
|
||||
from .test_server import (
|
||||
makeTestClient,
|
||||
|
|
|
@ -19,6 +19,7 @@ from aiohttp.test_utils import TestClient, TestServer
|
|||
from aiohttp import web
|
||||
|
||||
from subiquitycore.context import Context
|
||||
from subiquitycore.tests.util import run_coro
|
||||
from subiquitycore import contextlib38
|
||||
|
||||
from subiquity.common.api.defs import api, Payload
|
||||
|
@ -28,7 +29,6 @@ from subiquity.common.api.server import (
|
|||
MissingImplementationError,
|
||||
SignatureMisatchError,
|
||||
)
|
||||
from subiquity.tests.util import run_coro
|
||||
|
||||
|
||||
class TestApp:
|
||||
|
|
|
@ -14,28 +14,18 @@
|
|||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import enum
|
||||
import logging
|
||||
|
||||
from curtin.config import merge_config
|
||||
|
||||
from subiquitycore.async_helpers import SingleInstanceTask
|
||||
from subiquitycore.context import with_context
|
||||
|
||||
from subiquity.common.apidef import API
|
||||
from subiquity.common.geoip import GeoIP
|
||||
from subiquity.server.controller import SubiquityController
|
||||
|
||||
log = logging.getLogger('subiquity.server.controllers.mirror')
|
||||
|
||||
|
||||
class CheckState(enum.IntEnum):
|
||||
NOT_STARTED = enum.auto()
|
||||
CHECKING = enum.auto()
|
||||
FAILED = enum.auto()
|
||||
DONE = enum.auto()
|
||||
|
||||
|
||||
class MirrorController(SubiquityController):
|
||||
|
||||
endpoint = API.mirror
|
||||
|
@ -55,11 +45,8 @@ class MirrorController(SubiquityController):
|
|||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
self.geoip_enabled = True
|
||||
self.check_state = CheckState.NOT_STARTED
|
||||
self.lookup_task = SingleInstanceTask(self.lookup)
|
||||
self.geoip = GeoIP()
|
||||
self.app.hub.subscribe('network-up', self.maybe_start_check)
|
||||
self.app.hub.subscribe('network-proxy-set', self.maybe_start_check)
|
||||
self.app.geoip.on_countrycode.subscribe(self.on_countrycode)
|
||||
self.cc_event = asyncio.Event()
|
||||
|
||||
def load_autoinstall_data(self, data):
|
||||
if data is None:
|
||||
|
@ -72,30 +59,16 @@ class MirrorController(SubiquityController):
|
|||
async def apply_autoinstall_config(self, context):
|
||||
if not self.geoip_enabled:
|
||||
return
|
||||
if self.lookup_task.task is None:
|
||||
return
|
||||
try:
|
||||
with context.child('waiting'):
|
||||
await asyncio.wait_for(self.lookup_task.wait(), 10)
|
||||
await asyncio.wait_for(self.cc_event.wait(), 10)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
def maybe_start_check(self):
|
||||
if not self.geoip_enabled:
|
||||
return
|
||||
if self.check_state != CheckState.DONE:
|
||||
self.check_state = CheckState.CHECKING
|
||||
self.lookup_task.start_sync()
|
||||
|
||||
@with_context()
|
||||
async def lookup(self, context):
|
||||
if await self.geoip.lookup():
|
||||
cc = self.geoip.countrycode
|
||||
if cc:
|
||||
self.model.set_country(cc)
|
||||
self.check_state = CheckState.DONE
|
||||
return
|
||||
self.check_state = CheckState.FAILED
|
||||
def on_countrycode(self, cc):
|
||||
if self.geoip_enabled:
|
||||
self.model.set_country(cc)
|
||||
self.cc_event.set()
|
||||
|
||||
def serialize(self):
|
||||
return self.model.get_mirror()
|
||||
|
|
|
@ -14,19 +14,53 @@
|
|||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import logging
|
||||
import enum
|
||||
import requests
|
||||
from xml.etree import ElementTree
|
||||
|
||||
from subiquitycore.async_helpers import run_in_thread
|
||||
from subiquitycore.async_helpers import (
|
||||
run_in_thread,
|
||||
SingleInstanceTask,
|
||||
)
|
||||
from subiquitycore.pubsub import EventCallback
|
||||
|
||||
log = logging.getLogger('subiquity.common.geoip')
|
||||
|
||||
|
||||
class CheckState(enum.IntEnum):
|
||||
NOT_STARTED = enum.auto()
|
||||
CHECKING = enum.auto()
|
||||
FAILED = enum.auto()
|
||||
DONE = enum.auto()
|
||||
|
||||
|
||||
class GeoIP:
|
||||
def __init__(self):
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
self.element = None
|
||||
self.cc = None
|
||||
self.tz = None
|
||||
self.check_state = CheckState.NOT_STARTED
|
||||
self.on_countrycode = EventCallback()
|
||||
self.on_timezone = EventCallback()
|
||||
self.lookup_task = SingleInstanceTask(self.lookup)
|
||||
self.app.hub.subscribe('network-up', self.maybe_start_check)
|
||||
self.app.hub.subscribe('network-proxy-set', self.maybe_start_check)
|
||||
|
||||
def maybe_start_check(self):
|
||||
if self.check_state != CheckState.DONE:
|
||||
self.check_state = CheckState.CHECKING
|
||||
self.lookup_task.start_sync()
|
||||
|
||||
async def lookup(self):
|
||||
rv = await self._lookup()
|
||||
if rv:
|
||||
self.check_state = CheckState.DONE
|
||||
else:
|
||||
self.check_state = CheckState.FAILED
|
||||
return rv
|
||||
|
||||
async def _lookup(self):
|
||||
try:
|
||||
response = await run_in_thread(
|
||||
requests.get, "https://geoip.ubuntu.com/lookup")
|
||||
|
@ -40,28 +74,33 @@ class GeoIP:
|
|||
except ElementTree.ParseError:
|
||||
log.exception("parsing %r failed", self.response_text)
|
||||
return False
|
||||
|
||||
cc = self.element.find("CountryCode")
|
||||
if cc is None or cc.text is None:
|
||||
log.debug("no CountryCode found in %r", self.response_text)
|
||||
return False
|
||||
cc = cc.text.lower()
|
||||
if len(cc) != 2:
|
||||
log.debug("bogus CountryCode found in %r", self.response_text)
|
||||
return False
|
||||
if cc != self.cc:
|
||||
self.on_countrycode.broadcast(cc)
|
||||
self.cc = cc
|
||||
|
||||
tz = self.element.find("TimeZone")
|
||||
if tz is None or not tz.text:
|
||||
log.debug("no TimeZone found in %r", self.response_text)
|
||||
return False
|
||||
if tz != self.tz:
|
||||
self.on_timezone.broadcast(tz)
|
||||
self.tz = tz.text
|
||||
|
||||
return True
|
||||
|
||||
@property
|
||||
def countrycode(self):
|
||||
if not self.element:
|
||||
return None
|
||||
cc = self.element.find("CountryCode")
|
||||
if cc is None or cc.text is None:
|
||||
log.debug("no CountryCode found in %r", self.response_text)
|
||||
return None
|
||||
cc = cc.text.lower()
|
||||
if len(cc) != 2:
|
||||
log.debug("bogus CountryCode found in %r", self.response_text)
|
||||
return None
|
||||
return cc
|
||||
return self.cc
|
||||
|
||||
@property
|
||||
def timezone(self):
|
||||
if not self.element:
|
||||
return None
|
||||
tz = self.element.find("TimeZone")
|
||||
if tz is None or not tz.text:
|
||||
log.debug("no TimeZone found in %r", self.response_text)
|
||||
return None
|
||||
return tz.text
|
||||
return self.tz
|
|
@ -61,8 +61,9 @@ from subiquity.common.types import (
|
|||
LiveSessionSSHInfo,
|
||||
PasswordKind,
|
||||
)
|
||||
from subiquity.server.controller import SubiquityController
|
||||
from subiquity.models.subiquity import SubiquityModel
|
||||
from subiquity.server.controller import SubiquityController
|
||||
from subiquity.server.geoip import GeoIP
|
||||
from subiquity.server.errors import ErrorController
|
||||
from subiquitycore.snapd import (
|
||||
AsyncSnapd,
|
||||
|
@ -246,6 +247,7 @@ class SubiquityServer(Application):
|
|||
self.autoinstall_config = None
|
||||
self.hub.subscribe('network-up', self._network_change)
|
||||
self.hub.subscribe('network-proxy-set', self._proxy_set)
|
||||
self.geoip = GeoIP(self)
|
||||
|
||||
def load_serialized_state(self):
|
||||
for controller in self.controllers.instances:
|
||||
|
|
|
@ -16,9 +16,9 @@
|
|||
import mock
|
||||
|
||||
from subiquitycore.tests import SubiTestCase
|
||||
from subiquity.common.geoip import GeoIP
|
||||
from subiquity.tests.util import run_coro
|
||||
|
||||
from subiquitycore.tests.mocks import make_app
|
||||
from subiquitycore.tests.util import run_coro
|
||||
from subiquity.server.geoip import GeoIP
|
||||
|
||||
xml = '''
|
||||
<Response>
|
||||
|
@ -62,7 +62,7 @@ def requests_get_factory(text):
|
|||
class TestGeoIP(SubiTestCase):
|
||||
@mock.patch('requests.get', new=requests_get_factory(xml))
|
||||
def setUp(self):
|
||||
self.geoip = GeoIP()
|
||||
self.geoip = GeoIP(make_app())
|
||||
|
||||
async def fn():
|
||||
self.assertTrue(await self.geoip.lookup())
|
||||
|
@ -77,7 +77,7 @@ class TestGeoIP(SubiTestCase):
|
|||
|
||||
class TestGeoIPBadData(SubiTestCase):
|
||||
def setUp(self):
|
||||
self.geoip = GeoIP()
|
||||
self.geoip = GeoIP(make_app())
|
||||
|
||||
@mock.patch('requests.get', new=requests_get_factory(partial))
|
||||
def test_partial_reponse(self):
|
||||
|
@ -88,7 +88,7 @@ class TestGeoIPBadData(SubiTestCase):
|
|||
@mock.patch('requests.get', new=requests_get_factory(incomplete))
|
||||
def test_incomplete(self):
|
||||
async def fn():
|
||||
self.assertTrue(await self.geoip.lookup())
|
||||
self.assertFalse(await self.geoip.lookup())
|
||||
run_coro(fn())
|
||||
self.assertIsNone(self.geoip.countrycode)
|
||||
self.assertIsNone(self.geoip.timezone)
|
||||
|
@ -96,20 +96,20 @@ class TestGeoIPBadData(SubiTestCase):
|
|||
@mock.patch('requests.get', new=requests_get_factory(long_cc))
|
||||
def test_long_cc(self):
|
||||
async def fn():
|
||||
self.assertTrue(await self.geoip.lookup())
|
||||
self.assertFalse(await self.geoip.lookup())
|
||||
run_coro(fn())
|
||||
self.assertIsNone(self.geoip.countrycode)
|
||||
|
||||
@mock.patch('requests.get', new=requests_get_factory(empty_cc))
|
||||
def test_empty_cc(self):
|
||||
async def fn():
|
||||
self.assertTrue(await self.geoip.lookup())
|
||||
self.assertFalse(await self.geoip.lookup())
|
||||
run_coro(fn())
|
||||
self.assertIsNone(self.geoip.countrycode)
|
||||
|
||||
@mock.patch('requests.get', new=requests_get_factory(empty_tz))
|
||||
def test_empty_tz(self):
|
||||
async def fn():
|
||||
self.assertTrue(await self.geoip.lookup())
|
||||
self.assertFalse(await self.geoip.lookup())
|
||||
run_coro(fn())
|
||||
self.assertIsNone(self.geoip.timezone)
|
|
@ -33,3 +33,21 @@ class MessageHub:
|
|||
|
||||
def broadcast(self, channel):
|
||||
return asyncio.get_event_loop().create_task(self.abroadcast(channel))
|
||||
|
||||
|
||||
class EventCallback:
|
||||
|
||||
def __init__(self):
|
||||
self.subscriptions = []
|
||||
|
||||
def subscribe(self, method, *args):
|
||||
self.subscriptions.append((method, args))
|
||||
|
||||
async def abroadcast(self, cbdata):
|
||||
for m, args in self.subscriptions:
|
||||
v = m(cbdata, *args)
|
||||
if inspect.iscoroutine(v):
|
||||
await v
|
||||
|
||||
def broadcast(self, cbdata):
|
||||
return asyncio.get_event_loop().create_task(self.abroadcast(cbdata))
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2020-2021 Canonical, Ltd.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
from unittest import mock
|
||||
|
||||
from subiquitycore.context import Context
|
||||
from subiquitycore.pubsub import MessageHub
|
||||
|
||||
|
||||
class MockedApplication:
|
||||
signal = loop = None
|
||||
project = "mini"
|
||||
autoinstall_config = {}
|
||||
answers = {}
|
||||
opts = None
|
||||
|
||||
|
||||
def make_app(model=None):
|
||||
app = MockedApplication()
|
||||
app.ui = mock.Mock()
|
||||
if model is not None:
|
||||
app.base_model = model
|
||||
else:
|
||||
app.base_model = mock.Mock()
|
||||
app.context = Context.new(app)
|
||||
app.exit = mock.Mock()
|
||||
app.respond = mock.Mock()
|
||||
app.next_screen = mock.Mock()
|
||||
app.prev_screen = mock.Mock()
|
||||
app.hub = MessageHub()
|
||||
return app
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2021 Canonical, Ltd.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
|
||||
from subiquitycore.tests import SubiTestCase
|
||||
from subiquitycore.pubsub import EventCallback
|
||||
from subiquitycore.tests.util import run_coro
|
||||
|
||||
|
||||
async def wait_other_tasks():
|
||||
if hasattr(asyncio, 'all_tasks'):
|
||||
tasks = asyncio.all_tasks() # py 3.7+
|
||||
tasks.remove(asyncio.current_task())
|
||||
else:
|
||||
tasks = asyncio.Task.all_tasks() # py 3.6
|
||||
tasks.remove(asyncio.Task.current_task())
|
||||
await asyncio.wait(tasks)
|
||||
|
||||
|
||||
class TestEventCallback(SubiTestCase):
|
||||
def test_basic(self):
|
||||
def job():
|
||||
self.thething.broadcast(42)
|
||||
|
||||
def cb(val, mydata):
|
||||
self.assertEqual(42, val)
|
||||
self.assertEqual('bacon', mydata)
|
||||
self.called += 1
|
||||
|
||||
async def fn():
|
||||
self.called = 0
|
||||
self.thething = EventCallback()
|
||||
calls_expected = 3
|
||||
for _ in range(calls_expected):
|
||||
self.thething.subscribe(cb, 'bacon')
|
||||
job()
|
||||
await wait_other_tasks()
|
||||
self.assertEqual(calls_expected, self.called)
|
||||
|
||||
run_coro(fn())
|
Loading…
Reference in New Issue