Add EventCallback, and move mirror task things to GeoIP (#983)

Move mock_app to common location.
Move run_coro to subiquitycore so that subiquitycore doesn't have to
reference things in subiquity, even for test.
Move task tracking things from mirror to geoip.
Server app owns the geoip instance.
Create EventCallback as an alternative to MessageHub that should
hopefully express clearer intermodule dependencies.
This commit is contained in:
Dan Bungert 2021-06-14 16:05:27 -06:00 committed by GitHub
parent 92bc06b5c0
commit 6a189dd598
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 194 additions and 90 deletions

View File

@ -15,7 +15,7 @@
import unittest
from unittest import mock
from subiquitycore.context import Context
from subiquitycore.tests.mocks import make_app
from console_conf.controllers.chooser import (
RecoveryChooserController,
RecoveryChooserConfirmController,
@ -26,29 +26,6 @@ from console_conf.models.systems import (
)
class MockedApplication:
signal = loop = None
project = "mini"
autoinstall_config = {}
answers = {}
opts = None
def make_app(model=None):
app = MockedApplication()
app.ui = mock.Mock()
if model is not None:
app.base_model = model
else:
app.base_model = mock.Mock()
app.context = Context.new(app)
app.exit = mock.Mock()
app.respond = mock.Mock()
app.next_screen = mock.Mock()
app.prev_screen = mock.Mock()
return app
model1_non_current = {
"current": False,
"label": "1234",

View File

@ -21,10 +21,10 @@ import aiohttp
from aiohttp import web
from subiquitycore import contextlib38
from subiquitycore.tests.util import run_coro
from subiquity.common.api.client import make_client
from subiquity.common.api.defs import api, Payload
from subiquity.tests.util import run_coro
from .test_server import (
makeTestClient,

View File

@ -19,6 +19,7 @@ from aiohttp.test_utils import TestClient, TestServer
from aiohttp import web
from subiquitycore.context import Context
from subiquitycore.tests.util import run_coro
from subiquitycore import contextlib38
from subiquity.common.api.defs import api, Payload
@ -28,7 +29,6 @@ from subiquity.common.api.server import (
MissingImplementationError,
SignatureMisatchError,
)
from subiquity.tests.util import run_coro
class TestApp:

View File

@ -14,28 +14,18 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import enum
import logging
from curtin.config import merge_config
from subiquitycore.async_helpers import SingleInstanceTask
from subiquitycore.context import with_context
from subiquity.common.apidef import API
from subiquity.common.geoip import GeoIP
from subiquity.server.controller import SubiquityController
log = logging.getLogger('subiquity.server.controllers.mirror')
class CheckState(enum.IntEnum):
NOT_STARTED = enum.auto()
CHECKING = enum.auto()
FAILED = enum.auto()
DONE = enum.auto()
class MirrorController(SubiquityController):
endpoint = API.mirror
@ -55,11 +45,8 @@ class MirrorController(SubiquityController):
def __init__(self, app):
super().__init__(app)
self.geoip_enabled = True
self.check_state = CheckState.NOT_STARTED
self.lookup_task = SingleInstanceTask(self.lookup)
self.geoip = GeoIP()
self.app.hub.subscribe('network-up', self.maybe_start_check)
self.app.hub.subscribe('network-proxy-set', self.maybe_start_check)
self.app.geoip.on_countrycode.subscribe(self.on_countrycode)
self.cc_event = asyncio.Event()
def load_autoinstall_data(self, data):
if data is None:
@ -72,30 +59,16 @@ class MirrorController(SubiquityController):
async def apply_autoinstall_config(self, context):
if not self.geoip_enabled:
return
if self.lookup_task.task is None:
return
try:
with context.child('waiting'):
await asyncio.wait_for(self.lookup_task.wait(), 10)
await asyncio.wait_for(self.cc_event.wait(), 10)
except asyncio.TimeoutError:
pass
def maybe_start_check(self):
if not self.geoip_enabled:
return
if self.check_state != CheckState.DONE:
self.check_state = CheckState.CHECKING
self.lookup_task.start_sync()
@with_context()
async def lookup(self, context):
if await self.geoip.lookup():
cc = self.geoip.countrycode
if cc:
self.model.set_country(cc)
self.check_state = CheckState.DONE
return
self.check_state = CheckState.FAILED
def on_countrycode(self, cc):
if self.geoip_enabled:
self.model.set_country(cc)
self.cc_event.set()
def serialize(self):
return self.model.get_mirror()

View File

@ -14,19 +14,53 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
import enum
import requests
from xml.etree import ElementTree
from subiquitycore.async_helpers import run_in_thread
from subiquitycore.async_helpers import (
run_in_thread,
SingleInstanceTask,
)
from subiquitycore.pubsub import EventCallback
log = logging.getLogger('subiquity.common.geoip')
class CheckState(enum.IntEnum):
NOT_STARTED = enum.auto()
CHECKING = enum.auto()
FAILED = enum.auto()
DONE = enum.auto()
class GeoIP:
def __init__(self):
def __init__(self, app):
self.app = app
self.element = None
self.cc = None
self.tz = None
self.check_state = CheckState.NOT_STARTED
self.on_countrycode = EventCallback()
self.on_timezone = EventCallback()
self.lookup_task = SingleInstanceTask(self.lookup)
self.app.hub.subscribe('network-up', self.maybe_start_check)
self.app.hub.subscribe('network-proxy-set', self.maybe_start_check)
def maybe_start_check(self):
if self.check_state != CheckState.DONE:
self.check_state = CheckState.CHECKING
self.lookup_task.start_sync()
async def lookup(self):
rv = await self._lookup()
if rv:
self.check_state = CheckState.DONE
else:
self.check_state = CheckState.FAILED
return rv
async def _lookup(self):
try:
response = await run_in_thread(
requests.get, "https://geoip.ubuntu.com/lookup")
@ -40,28 +74,33 @@ class GeoIP:
except ElementTree.ParseError:
log.exception("parsing %r failed", self.response_text)
return False
cc = self.element.find("CountryCode")
if cc is None or cc.text is None:
log.debug("no CountryCode found in %r", self.response_text)
return False
cc = cc.text.lower()
if len(cc) != 2:
log.debug("bogus CountryCode found in %r", self.response_text)
return False
if cc != self.cc:
self.on_countrycode.broadcast(cc)
self.cc = cc
tz = self.element.find("TimeZone")
if tz is None or not tz.text:
log.debug("no TimeZone found in %r", self.response_text)
return False
if tz != self.tz:
self.on_timezone.broadcast(tz)
self.tz = tz.text
return True
@property
def countrycode(self):
if not self.element:
return None
cc = self.element.find("CountryCode")
if cc is None or cc.text is None:
log.debug("no CountryCode found in %r", self.response_text)
return None
cc = cc.text.lower()
if len(cc) != 2:
log.debug("bogus CountryCode found in %r", self.response_text)
return None
return cc
return self.cc
@property
def timezone(self):
if not self.element:
return None
tz = self.element.find("TimeZone")
if tz is None or not tz.text:
log.debug("no TimeZone found in %r", self.response_text)
return None
return tz.text
return self.tz

View File

@ -61,8 +61,9 @@ from subiquity.common.types import (
LiveSessionSSHInfo,
PasswordKind,
)
from subiquity.server.controller import SubiquityController
from subiquity.models.subiquity import SubiquityModel
from subiquity.server.controller import SubiquityController
from subiquity.server.geoip import GeoIP
from subiquity.server.errors import ErrorController
from subiquitycore.snapd import (
AsyncSnapd,
@ -246,6 +247,7 @@ class SubiquityServer(Application):
self.autoinstall_config = None
self.hub.subscribe('network-up', self._network_change)
self.hub.subscribe('network-proxy-set', self._proxy_set)
self.geoip = GeoIP(self)
def load_serialized_state(self):
for controller in self.controllers.instances:

View File

View File

@ -16,9 +16,9 @@
import mock
from subiquitycore.tests import SubiTestCase
from subiquity.common.geoip import GeoIP
from subiquity.tests.util import run_coro
from subiquitycore.tests.mocks import make_app
from subiquitycore.tests.util import run_coro
from subiquity.server.geoip import GeoIP
xml = '''
<Response>
@ -62,7 +62,7 @@ def requests_get_factory(text):
class TestGeoIP(SubiTestCase):
@mock.patch('requests.get', new=requests_get_factory(xml))
def setUp(self):
self.geoip = GeoIP()
self.geoip = GeoIP(make_app())
async def fn():
self.assertTrue(await self.geoip.lookup())
@ -77,7 +77,7 @@ class TestGeoIP(SubiTestCase):
class TestGeoIPBadData(SubiTestCase):
def setUp(self):
self.geoip = GeoIP()
self.geoip = GeoIP(make_app())
@mock.patch('requests.get', new=requests_get_factory(partial))
def test_partial_reponse(self):
@ -88,7 +88,7 @@ class TestGeoIPBadData(SubiTestCase):
@mock.patch('requests.get', new=requests_get_factory(incomplete))
def test_incomplete(self):
async def fn():
self.assertTrue(await self.geoip.lookup())
self.assertFalse(await self.geoip.lookup())
run_coro(fn())
self.assertIsNone(self.geoip.countrycode)
self.assertIsNone(self.geoip.timezone)
@ -96,20 +96,20 @@ class TestGeoIPBadData(SubiTestCase):
@mock.patch('requests.get', new=requests_get_factory(long_cc))
def test_long_cc(self):
async def fn():
self.assertTrue(await self.geoip.lookup())
self.assertFalse(await self.geoip.lookup())
run_coro(fn())
self.assertIsNone(self.geoip.countrycode)
@mock.patch('requests.get', new=requests_get_factory(empty_cc))
def test_empty_cc(self):
async def fn():
self.assertTrue(await self.geoip.lookup())
self.assertFalse(await self.geoip.lookup())
run_coro(fn())
self.assertIsNone(self.geoip.countrycode)
@mock.patch('requests.get', new=requests_get_factory(empty_tz))
def test_empty_tz(self):
async def fn():
self.assertTrue(await self.geoip.lookup())
self.assertFalse(await self.geoip.lookup())
run_coro(fn())
self.assertIsNone(self.geoip.timezone)

View File

@ -33,3 +33,21 @@ class MessageHub:
def broadcast(self, channel):
return asyncio.get_event_loop().create_task(self.abroadcast(channel))
class EventCallback:
def __init__(self):
self.subscriptions = []
def subscribe(self, method, *args):
self.subscriptions.append((method, args))
async def abroadcast(self, cbdata):
for m, args in self.subscriptions:
v = m(cbdata, *args)
if inspect.iscoroutine(v):
await v
def broadcast(self, cbdata):
return asyncio.get_event_loop().create_task(self.abroadcast(cbdata))

View File

@ -0,0 +1,42 @@
# Copyright 2020-2021 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from unittest import mock
from subiquitycore.context import Context
from subiquitycore.pubsub import MessageHub
class MockedApplication:
signal = loop = None
project = "mini"
autoinstall_config = {}
answers = {}
opts = None
def make_app(model=None):
app = MockedApplication()
app.ui = mock.Mock()
if model is not None:
app.base_model = model
else:
app.base_model = mock.Mock()
app.context = Context.new(app)
app.exit = mock.Mock()
app.respond = mock.Mock()
app.next_screen = mock.Mock()
app.prev_screen = mock.Mock()
app.hub = MessageHub()
return app

View File

@ -0,0 +1,53 @@
# Copyright 2021 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
from subiquitycore.tests import SubiTestCase
from subiquitycore.pubsub import EventCallback
from subiquitycore.tests.util import run_coro
async def wait_other_tasks():
if hasattr(asyncio, 'all_tasks'):
tasks = asyncio.all_tasks() # py 3.7+
tasks.remove(asyncio.current_task())
else:
tasks = asyncio.Task.all_tasks() # py 3.6
tasks.remove(asyncio.Task.current_task())
await asyncio.wait(tasks)
class TestEventCallback(SubiTestCase):
def test_basic(self):
def job():
self.thething.broadcast(42)
def cb(val, mydata):
self.assertEqual(42, val)
self.assertEqual('bacon', mydata)
self.called += 1
async def fn():
self.called = 0
self.thething = EventCallback()
calls_expected = 3
for _ in range(calls_expected):
self.thething.subscribe(cb, 'bacon')
job()
await wait_other_tasks()
self.assertEqual(calls_expected, self.called)
run_coro(fn())