tests: s/run_coro/IsolatedAsyncioTestCase/

Tests that use run_coro are at risk of being broken by the introduction
of another test using IsolatedAsyncioTestCase.  Switch over to only use
IsolatedAsyncioTestCase.
This commit is contained in:
Dan Bungert 2022-06-23 17:00:45 -06:00
parent 924e527b7b
commit dffb155203
13 changed files with 181 additions and 269 deletions

View File

@ -21,8 +21,6 @@ import unittest
import aiohttp import aiohttp
from aiohttp import web from aiohttp import web
from subiquitycore.tests.util import run_coro
from subiquity.common.api.client import make_client from subiquity.common.api.client import make_client
from subiquity.common.api.defs import api, Payload from subiquity.common.api.defs import api, Payload
@ -46,9 +44,9 @@ async def makeE2EClient(api, impl,
yield make_client(api, mr) yield make_client(api, mr)
class TestEndToEnd(unittest.TestCase): class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
def test_simple(self): async def test_simple(self):
@api @api
class API: class API:
def GET() -> str: ... def GET() -> str: ...
@ -57,13 +55,10 @@ class TestEndToEnd(unittest.TestCase):
async def GET(self) -> str: async def GET(self) -> str:
return 'value' return 'value'
async def run(): async with makeE2EClient(API, Impl()) as client:
async with makeE2EClient(API, Impl()) as client: self.assertEqual(await client.GET(), 'value')
self.assertEqual(await client.GET(), 'value')
run_coro(run()) async def test_nested(self):
def test_nested(self):
@api @api
class API: class API:
class endpoint: class endpoint:
@ -74,13 +69,10 @@ class TestEndToEnd(unittest.TestCase):
async def endpoint_nested_GET(self) -> str: async def endpoint_nested_GET(self) -> str:
return 'value' return 'value'
async def run(): async with makeE2EClient(API, Impl()) as client:
async with makeE2EClient(API, Impl()) as client: self.assertEqual(await client.endpoint.nested.GET(), 'value')
self.assertEqual(await client.endpoint.nested.GET(), 'value')
run_coro(run()) async def test_args(self):
def test_args(self):
@api @api
class API: class API:
def GET(arg1: str, arg2: str) -> str: ... def GET(arg1: str, arg2: str) -> str: ...
@ -89,14 +81,11 @@ class TestEndToEnd(unittest.TestCase):
async def GET(self, arg1: str, arg2: str) -> str: async def GET(self, arg1: str, arg2: str) -> str:
return '{}+{}'.format(arg1, arg2) return '{}+{}'.format(arg1, arg2)
async def run(): async with makeE2EClient(API, Impl()) as client:
async with makeE2EClient(API, Impl()) as client: self.assertEqual(
self.assertEqual( await client.GET(arg1="A", arg2="B"), 'A+B')
await client.GET(arg1="A", arg2="B"), 'A+B')
run_coro(run()) async def test_defaults(self):
def test_defaults(self):
@api @api
class API: class API:
def GET(arg1: str, arg2: str = "arg2") -> str: ... def GET(arg1: str, arg2: str = "arg2") -> str: ...
@ -105,16 +94,13 @@ class TestEndToEnd(unittest.TestCase):
async def GET(self, arg1: str, arg2: str = "arg2") -> str: async def GET(self, arg1: str, arg2: str = "arg2") -> str:
return '{}+{}'.format(arg1, arg2) return '{}+{}'.format(arg1, arg2)
async def run(): async with makeE2EClient(API, Impl()) as client:
async with makeE2EClient(API, Impl()) as client: self.assertEqual(
self.assertEqual( await client.GET(arg1="A", arg2="B"), 'A+B')
await client.GET(arg1="A", arg2="B"), 'A+B') self.assertEqual(
self.assertEqual( await client.GET(arg1="A"), 'A+arg2')
await client.GET(arg1="A"), 'A+arg2')
run_coro(run()) async def test_post(self):
def test_post(self):
@api @api
class API: class API:
def POST(data: Payload[dict]) -> str: ... def POST(data: Payload[dict]) -> str: ...
@ -123,14 +109,11 @@ class TestEndToEnd(unittest.TestCase):
async def POST(self, data: dict) -> str: async def POST(self, data: dict) -> str:
return data['key'] return data['key']
async def run(): async with makeE2EClient(API, Impl()) as client:
async with makeE2EClient(API, Impl()) as client: self.assertEqual(
self.assertEqual( await client.POST({'key': 'value'}), 'value')
await client.POST({'key': 'value'}), 'value')
run_coro(run()) async def test_typed(self):
def test_typed(self):
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class In: class In:
@ -149,14 +132,11 @@ class TestEndToEnd(unittest.TestCase):
async def doubler_POST(self, data: In) -> Out: async def doubler_POST(self, data: In) -> Out:
return Out(doubled=data.val*2) return Out(doubled=data.val*2)
async def run(): async with makeE2EClient(API, Impl()) as client:
async with makeE2EClient(API, Impl()) as client: out = await client.doubler.POST(In(3))
out = await client.doubler.POST(In(3)) self.assertEqual(out.doubled, 6)
self.assertEqual(out.doubled, 6)
run_coro(run()) async def test_middleware(self):
def test_middleware(self):
@api @api
class API: class API:
def GET() -> int: ... def GET() -> int: ...
@ -182,17 +162,14 @@ class TestEndToEnd(unittest.TestCase):
raise Skip raise Skip
yield resp yield resp
async def run(): async with makeE2EClient(
async with makeE2EClient( API, Impl(),
API, Impl(), middlewares=[middleware],
middlewares=[middleware], make_request=custom_make_request) as client:
make_request=custom_make_request) as client: with self.assertRaises(Skip):
with self.assertRaises(Skip): await client.GET()
await client.GET()
run_coro(run()) async def test_error(self):
def test_error(self):
@api @api
class API: class API:
class good: class good:
@ -208,16 +185,13 @@ class TestEndToEnd(unittest.TestCase):
async def bad_GET(self, x: int) -> int: async def bad_GET(self, x: int) -> int:
raise Exception("baz") raise Exception("baz")
async def run(): async with makeE2EClient(API, Impl()) as client:
async with makeE2EClient(API, Impl()) as client: r = await client.good.GET(2)
r = await client.good.GET(2) self.assertEqual(r, 3)
self.assertEqual(r, 3) with self.assertRaises(aiohttp.ClientResponseError):
with self.assertRaises(aiohttp.ClientResponseError): await client.bad.GET(2)
await client.bad.GET(2)
run_coro(run()) async def test_error_middleware(self):
def test_error_middleware(self):
@api @api
class API: class API:
class good: class good:
@ -251,14 +225,11 @@ class TestEndToEnd(unittest.TestCase):
raise Abort raise Abort
yield resp yield resp
async def run(): async with makeE2EClient(
async with makeE2EClient( API, Impl(),
API, Impl(), middlewares=[middleware],
middlewares=[middleware], make_request=custom_make_request) as client:
make_request=custom_make_request) as client: r = await client.good.GET(2)
r = await client.good.GET(2) self.assertEqual(r, 3)
self.assertEqual(r, 3) with self.assertRaises(Abort):
with self.assertRaises(Abort): await client.bad.GET(2)
await client.bad.GET(2)
run_coro(run())

View File

@ -20,7 +20,6 @@ from aiohttp.test_utils import TestClient, TestServer
from aiohttp import web from aiohttp import web
from subiquitycore.context import Context from subiquitycore.context import Context
from subiquitycore.tests.util import run_coro
from subiquity.common.api.defs import api, Payload from subiquity.common.api.defs import api, Payload
from subiquity.common.api.server import ( from subiquity.common.api.server import (
@ -56,7 +55,7 @@ async def makeTestClient(api, impl, middlewares=()):
yield client yield client
class TestBind(unittest.TestCase): class TestBind(unittest.IsolatedAsyncioTestCase):
async def assertResponse(self, coro, value): async def assertResponse(self, coro, value):
resp = await coro resp = await coro
@ -67,7 +66,7 @@ class TestBind(unittest.TestCase):
self.assertEqual(resp.status, 200) self.assertEqual(resp.status, 200)
self.assertEqual(await resp.json(), value) self.assertEqual(await resp.json(), value)
def test_simple(self): async def test_simple(self):
@api @api
class API: class API:
def GET() -> str: ... def GET() -> str: ...
@ -76,14 +75,11 @@ class TestBind(unittest.TestCase):
async def GET(self) -> str: async def GET(self) -> str:
return 'value' return 'value'
async def run(): async with makeTestClient(API, Impl()) as client:
async with makeTestClient(API, Impl()) as client: await self.assertResponse(
await self.assertResponse( client.get("/"), 'value')
client.get("/"), 'value')
run_coro(run()) async def test_nested(self):
def test_nested(self):
@api @api
class API: class API:
class endpoint: class endpoint:
@ -94,14 +90,11 @@ class TestBind(unittest.TestCase):
async def nested_get(self, request, context): async def nested_get(self, request, context):
return 'nested' return 'nested'
async def run(): async with makeTestClient(API.endpoint, Impl()) as client:
async with makeTestClient(API.endpoint, Impl()) as client: await self.assertResponse(
await self.assertResponse( client.get("/endpoint/nested"), 'nested')
client.get("/endpoint/nested"), 'nested')
run_coro(run()) async def test_args(self):
def test_args(self):
@api @api
class API: class API:
def GET(arg: str): ... def GET(arg: str): ...
@ -110,14 +103,11 @@ class TestBind(unittest.TestCase):
async def GET(self, arg: str): async def GET(self, arg: str):
return arg return arg
async def run(): async with makeTestClient(API, Impl()) as client:
async with makeTestClient(API, Impl()) as client: await self.assertResponse(
await self.assertResponse( client.get('/?arg="whut"'), 'whut')
client.get('/?arg="whut"'), 'whut')
run_coro(run()) async def test_missing_argument(self):
def test_missing_argument(self):
@api @api
class API: class API:
def GET(arg: str): ... def GET(arg: str): ...
@ -126,19 +116,16 @@ class TestBind(unittest.TestCase):
async def GET(self, arg: str): async def GET(self, arg: str):
return arg return arg
async def run(): async with makeTestClient(API, Impl()) as client:
async with makeTestClient(API, Impl()) as client: resp = await client.get('/')
resp = await client.get('/') self.assertEqual(resp.status, 500)
self.assertEqual(resp.status, 500) self.assertEqual(resp.headers['x-status'], 'error')
self.assertEqual(resp.headers['x-status'], 'error') self.assertEqual(resp.headers['x-error-type'], 'TypeError')
self.assertEqual(resp.headers['x-error-type'], 'TypeError') self.assertEqual(
self.assertEqual( resp.headers['x-error-msg'],
resp.headers['x-error-msg'], 'missing required argument "arg"')
'missing required argument "arg"')
run_coro(run()) async def test_error(self):
def test_error(self):
@api @api
class API: class API:
def GET(): ... def GET(): ...
@ -147,17 +134,14 @@ class TestBind(unittest.TestCase):
async def GET(self): async def GET(self):
return 1/0 return 1/0
async def run(): async with makeTestClient(API, Impl()) as client:
async with makeTestClient(API, Impl()) as client: resp = await client.get('/')
resp = await client.get('/') self.assertEqual(resp.status, 500)
self.assertEqual(resp.status, 500) self.assertEqual(resp.headers['x-status'], 'error')
self.assertEqual(resp.headers['x-status'], 'error') self.assertEqual(
self.assertEqual( resp.headers['x-error-type'], 'ZeroDivisionError')
resp.headers['x-error-type'], 'ZeroDivisionError')
run_coro(run()) async def test_post(self):
def test_post(self):
@api @api
class API: class API:
def POST(data: Payload[str]) -> str: ... def POST(data: Payload[str]) -> str: ...
@ -166,13 +150,10 @@ class TestBind(unittest.TestCase):
async def POST(self, data: str) -> str: async def POST(self, data: str) -> str:
return data return data
async def run(): async with makeTestClient(API, Impl()) as client:
async with makeTestClient(API, Impl()) as client: await self.assertResponse(
await self.assertResponse( client.post("/", json='value'),
client.post("/", json='value'), 'value')
'value')
run_coro(run())
def test_missing_method(self): def test_missing_method(self):
@api @api
@ -201,7 +182,7 @@ class TestBind(unittest.TestCase):
bind(app.router, API, Impl()) bind(app.router, API, Impl())
self.assertEqual(cm.exception.methname, "API.GET") self.assertEqual(cm.exception.methname, "API.GET")
def test_middleware(self): async def test_middleware(self):
@web.middleware @web.middleware
async def middleware(request, handler): async def middleware(request, handler):
@ -219,16 +200,13 @@ class TestBind(unittest.TestCase):
async def GET(self) -> str: async def GET(self) -> str:
return 1/0 return 1/0
async def run(): async with makeTestClient(
async with makeTestClient( API, Impl(), middlewares=[middleware]) as client:
API, Impl(), middlewares=[middleware]) as client: resp = await client.get("/")
resp = await client.get("/") self.assertEqual(
self.assertEqual( resp.headers['x-error-type'], 'ZeroDivisionError')
resp.headers['x-error-type'], 'ZeroDivisionError')
run_coro(run()) async def test_controller_for_request(self):
def test_controller_for_request(self):
seen_controller = None seen_controller = None
@ -249,12 +227,9 @@ class TestBind(unittest.TestCase):
impl = Impl() impl = Impl()
async def run(): async with makeTestClient(
async with makeTestClient( API.meth, impl, middlewares=[middleware]) as client:
API.meth, impl, middlewares=[middleware]) as client: resp = await client.get("/meth")
resp = await client.get("/meth") self.assertEqual(await resp.json(), '')
self.assertEqual(await resp.json(), '')
run_coro(run())
self.assertIs(impl, seen_controller) self.assertIs(impl, seen_controller)

View File

@ -19,7 +19,6 @@ from unittest import mock
import yaml import yaml
from subiquitycore.pubsub import MessageHub from subiquitycore.pubsub import MessageHub
from subiquitycore.tests.util import run_coro
from subiquity.common.types import IdentityData from subiquity.common.types import IdentityData
from subiquity.models.subiquity import ModelNames, SubiquityModel from subiquity.models.subiquity import ModelNames, SubiquityModel
@ -63,7 +62,7 @@ class TestModelNames(unittest.TestCase):
self.assertEqual(model_names.all(), {'a', 'b', 'c'}) self.assertEqual(model_names.all(), {'a', 'b', 'c'})
class TestSubiquityModel(unittest.TestCase): class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
def writtenFiles(self, config): def writtenFiles(self, config):
for k, v in config.get('write_files', {}).items(): for k, v in config.get('write_files', {}).items():
@ -114,7 +113,7 @@ class TestSubiquityModel(unittest.TestCase):
cur = cur[component] cur = cur[component]
self.fail("config has value {} for {}".format(cur, path)) self.fail("config has value {} for {}".format(cur, path))
async def _test_configure(self): async def test_configure(self):
hub = MessageHub() hub = MessageHub()
model = SubiquityModel( model = SubiquityModel(
'test', hub, ModelNames({'a', 'b'}), ModelNames(set())) 'test', hub, ModelNames({'a', 'b'}), ModelNames(set()))
@ -124,9 +123,6 @@ class TestSubiquityModel(unittest.TestCase):
await hub.abroadcast((InstallerChannels.CONFIGURED, 'b')) await hub.abroadcast((InstallerChannels.CONFIGURED, 'b'))
self.assertTrue(model._install_event.is_set()) self.assertTrue(model._install_event.is_set())
def test_configure(self):
run_coro(self._test_configure())
def make_model(self): def make_model(self):
return SubiquityModel( return SubiquityModel(
'test', MessageHub(), INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES) 'test', MessageHub(), INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES)

View File

@ -13,13 +13,12 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from unittest import mock, TestCase from unittest import mock, TestCase, IsolatedAsyncioTestCase
from parameterized import parameterized from parameterized import parameterized
from subiquity.server.controllers.filesystem import FilesystemController from subiquity.server.controllers.filesystem import FilesystemController
from subiquitycore.tests.util import run_coro
from subiquitycore.tests.mocks import make_app from subiquitycore.tests.mocks import make_app
from subiquity.common.types import Bootloader from subiquity.common.types import Bootloader
from subiquity.models.tests.test_filesystem import ( from subiquity.models.tests.test_filesystem import (
@ -28,7 +27,7 @@ from subiquity.models.tests.test_filesystem import (
) )
class TestSubiquityControllerFilesystem(TestCase): class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase):
def setUp(self): def setUp(self):
self.app = make_app() self.app = make_app()
self.app.opts.bootloader = 'UEFI' self.app.opts.bootloader = 'UEFI'
@ -38,20 +37,20 @@ class TestSubiquityControllerFilesystem(TestCase):
self.fsc = FilesystemController(app=self.app) self.fsc = FilesystemController(app=self.app)
self.fsc._configured = True self.fsc._configured = True
def test_probe_restricted(self): async def test_probe_restricted(self):
run_coro(self.fsc._probe_once(context=None, restricted=True)) await self.fsc._probe_once(context=None, restricted=True)
self.app.prober.get_storage.assert_called_with({'blockdev'}) self.app.prober.get_storage.assert_called_with({'blockdev'})
def test_probe_os_prober_false(self): async def test_probe_os_prober_false(self):
self.app.opts.use_os_prober = False self.app.opts.use_os_prober = False
run_coro(self.fsc._probe_once(context=None, restricted=False)) await self.fsc._probe_once(context=None, restricted=False)
actual = self.app.prober.get_storage.call_args.args[0] actual = self.app.prober.get_storage.call_args.args[0]
self.assertTrue({'defaults'} <= actual) self.assertTrue({'defaults'} <= actual)
self.assertNotIn('os', actual) self.assertNotIn('os', actual)
def test_probe_os_prober_true(self): async def test_probe_os_prober_true(self):
self.app.opts.use_os_prober = True self.app.opts.use_os_prober = True
run_coro(self.fsc._probe_once(context=None, restricted=False)) await self.fsc._probe_once(context=None, restricted=False)
actual = self.app.prober.get_storage.call_args.args[0] actual = self.app.prober.get_storage.call_args.args[0]
self.assertTrue({'defaults', 'os'} <= actual) self.assertTrue({'defaults', 'os'} <= actual)

View File

@ -17,7 +17,6 @@ from unittest.mock import Mock, patch, AsyncMock
from subiquitycore.tests import SubiTestCase from subiquitycore.tests import SubiTestCase
from subiquitycore.tests.mocks import make_app from subiquitycore.tests.mocks import make_app
from subiquitycore.tests.util import run_coro
from subiquity.server.apt import ( from subiquity.server.apt import (
AptConfigurer, AptConfigurer,
Mountpoint, Mountpoint,
@ -51,14 +50,14 @@ class TestAptConfigurer(SubiTestCase):
expected['apt']['https_proxy'] = proxy expected['apt']['https_proxy'] = proxy
self.assertEqual(expected, self.configurer.apt_config()) self.assertEqual(expected, self.configurer.apt_config())
def test_mount_unmount(self): async def test_mount_unmount(self):
# Make sure we can unmount something that we mounted before. # Make sure we can unmount something that we mounted before.
with patch.object(self.app, "command_runner", with patch.object(self.app, "command_runner",
create=True, new_callable=AsyncMock): create=True, new_callable=AsyncMock):
m = run_coro(self.configurer.mount("/dev/cdrom", "/target")) m = await self.configurer.mount("/dev/cdrom", "/target")
run_coro(self.configurer.unmount(m)) await self.configurer.unmount(m)
def test_overlay(self): async def test_overlay(self):
self.configurer.install_tree = OverlayMountpoint( self.configurer.install_tree = OverlayMountpoint(
upperdir="upperdir-install-tree", upperdir="upperdir-install-tree",
lowers=["lowers1-install-tree"], lowers=["lowers1-install-tree"],
@ -71,13 +70,10 @@ class TestAptConfigurer(SubiTestCase):
) )
self.source = "source" self.source = "source"
async def coro():
async with self.configurer.overlay():
pass
with patch.object(self.app, "command_runner", with patch.object(self.app, "command_runner",
create=True, new_callable=AsyncMock): create=True, new_callable=AsyncMock):
run_coro(coro()) async with self.configurer.overlay():
pass
class TestLowerDirFor(SubiTestCase): class TestLowerDirFor(SubiTestCase):

View File

@ -17,7 +17,6 @@ from aioresponses import aioresponses
from subiquitycore.tests import SubiTestCase from subiquitycore.tests import SubiTestCase
from subiquitycore.tests.mocks import make_app from subiquitycore.tests.mocks import make_app
from subiquitycore.tests.util import run_coro
from subiquity.server.geoip import ( from subiquity.server.geoip import (
GeoIP, GeoIP,
HTTPGeoIPStrategy, HTTPGeoIPStrategy,
@ -48,16 +47,13 @@ empty_cc = '<Response><CountryCode></CountryCode></Response>'
class TestGeoIP(SubiTestCase): class TestGeoIP(SubiTestCase):
def setUp(self): async def asyncSetUp(self):
strategy = HTTPGeoIPStrategy() strategy = HTTPGeoIPStrategy()
self.geoip = GeoIP(make_app(), strategy) self.geoip = GeoIP(make_app(), strategy)
async def fn():
self.assertTrue(await self.geoip.lookup())
with aioresponses() as mocked: with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=xml) mocked.get("https://geoip.ubuntu.com/lookup", body=xml)
run_coro(fn()) self.assertTrue(await self.geoip.lookup())
def test_countrycode(self): def test_countrycode(self):
self.assertEqual("us", self.geoip.countrycode) self.assertEqual("us", self.geoip.countrycode)
@ -71,42 +67,32 @@ class TestGeoIPBadData(SubiTestCase):
strategy = HTTPGeoIPStrategy() strategy = HTTPGeoIPStrategy()
self.geoip = GeoIP(make_app(), strategy) self.geoip = GeoIP(make_app(), strategy)
def test_partial_reponse(self): async def test_partial_reponse(self):
async def fn():
self.assertFalse(await self.geoip.lookup())
with aioresponses() as mocked: with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=partial) mocked.get("https://geoip.ubuntu.com/lookup", body=partial)
run_coro(fn())
def test_incomplete(self):
async def fn():
self.assertFalse(await self.geoip.lookup()) self.assertFalse(await self.geoip.lookup())
async def test_incomplete(self):
with aioresponses() as mocked: with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=incomplete) mocked.get("https://geoip.ubuntu.com/lookup", body=incomplete)
run_coro(fn()) self.assertFalse(await self.geoip.lookup())
self.assertIsNone(self.geoip.countrycode) self.assertIsNone(self.geoip.countrycode)
self.assertIsNone(self.geoip.timezone) self.assertIsNone(self.geoip.timezone)
def test_long_cc(self): async def test_long_cc(self):
async def fn():
self.assertFalse(await self.geoip.lookup())
with aioresponses() as mocked: with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=long_cc) mocked.get("https://geoip.ubuntu.com/lookup", body=long_cc)
run_coro(fn()) self.assertFalse(await self.geoip.lookup())
self.assertIsNone(self.geoip.countrycode) self.assertIsNone(self.geoip.countrycode)
def test_empty_cc(self): async def test_empty_cc(self):
async def fn():
self.assertFalse(await self.geoip.lookup())
with aioresponses() as mocked: with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=empty_cc) mocked.get("https://geoip.ubuntu.com/lookup", body=empty_cc)
run_coro(fn()) self.assertFalse(await self.geoip.lookup())
self.assertIsNone(self.geoip.countrycode) self.assertIsNone(self.geoip.countrycode)
def test_empty_tz(self): async def test_empty_tz(self):
async def fn():
self.assertFalse(await self.geoip.lookup())
with aioresponses() as mocked: with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=empty_tz) mocked.get("https://geoip.ubuntu.com/lookup", body=empty_tz)
run_coro(fn()) self.assertFalse(await self.geoip.lookup())
self.assertIsNone(self.geoip.timezone) self.assertIsNone(self.geoip.timezone)

View File

@ -26,37 +26,36 @@ from subiquity.server.ubuntu_advantage import (
MockedUAInterfaceStrategy, MockedUAInterfaceStrategy,
UAClientUAInterfaceStrategy, UAClientUAInterfaceStrategy,
) )
from subiquitycore.tests.util import run_coro
class TestMockedUAInterfaceStrategy(unittest.TestCase): class TestMockedUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase):
def setUp(self): def setUp(self):
self.strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000) self.strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000)
def test_query_info_invalid(self): async def test_query_info_invalid(self):
# Tokens starting with "i" in dry-run mode cause the token to be # Tokens starting with "i" in dry-run mode cause the token to be
# reported as invalid. # reported as invalid.
with self.assertRaises(InvalidTokenError): with self.assertRaises(InvalidTokenError):
run_coro(self.strategy.query_info(token="invalidToken")) await self.strategy.query_info(token="invalidToken")
def test_query_info_failure(self): async def test_query_info_failure(self):
# Tokens starting with "f" in dry-run mode simulate an "internal" # Tokens starting with "f" in dry-run mode simulate an "internal"
# error. # error.
with self.assertRaises(CheckSubscriptionError): with self.assertRaises(CheckSubscriptionError):
run_coro(self.strategy.query_info(token="failure")) await self.strategy.query_info(token="failure")
def test_query_info_expired(self): async def test_query_info_expired(self):
# Tokens starting with "x" is dry-run mode simulate an expired token. # Tokens starting with "x" is dry-run mode simulate an expired token.
info = run_coro(self.strategy.query_info(token="xpiredToken")) info = await self.strategy.query_info(token="xpiredToken")
self.assertEqual(info["expires"], "2010-12-31T00:00:00+00:00") self.assertEqual(info["expires"], "2010-12-31T00:00:00+00:00")
def test_query_info_valid(self): async def test_query_info_valid(self):
# Other tokens are considered valid in dry-run mode. # Other tokens are considered valid in dry-run mode.
info = run_coro(self.strategy.query_info(token="validToken")) info = await self.strategy.query_info(token="validToken")
self.assertEqual(info["expires"], "2035-12-31T00:00:00+00:00") self.assertEqual(info["expires"], "2035-12-31T00:00:00+00:00")
class TestUAClientUAInterfaceStrategy(unittest.TestCase): class TestUAClientUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase):
arun_command_sym = "subiquity.server.ubuntu_advantage.utils.arun_command" arun_command_sym = "subiquity.server.ubuntu_advantage.utils.arun_command"
def test_init(self): def test_init(self):
@ -75,7 +74,7 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
self.assertEqual(strategy.executable, self.assertEqual(strategy.executable,
["python3", "/usr/bin/ubuntu-advantage"]) ["python3", "/usr/bin/ubuntu-advantage"])
def test_query_info_succeeded(self): async def test_query_info_succeeded(self):
strategy = UAClientUAInterfaceStrategy() strategy = UAClientUAInterfaceStrategy()
command = ( command = (
"ubuntu-advantage", "ubuntu-advantage",
@ -87,10 +86,10 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
with patch(self.arun_command_sym) as mock_arun: with patch(self.arun_command_sym) as mock_arun:
mock_arun.return_value = CompletedProcess([], 0) mock_arun.return_value = CompletedProcess([], 0)
mock_arun.return_value.stdout = "{}" mock_arun.return_value.stdout = "{}"
run_coro(strategy.query_info(token="123456789")) await strategy.query_info(token="123456789")
mock_arun.assert_called_once_with(command, check=True) mock_arun.assert_called_once_with(command, check=True)
def test_query_info_failed(self): async def test_query_info_failed(self):
strategy = UAClientUAInterfaceStrategy() strategy = UAClientUAInterfaceStrategy()
command = ( command = (
"ubuntu-advantage", "ubuntu-advantage",
@ -104,10 +103,10 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
cmd=command) cmd=command)
mock_arun.return_value.stdout = "{}" mock_arun.return_value.stdout = "{}"
with self.assertRaises(CheckSubscriptionError): with self.assertRaises(CheckSubscriptionError):
run_coro(strategy.query_info(token="123456789")) await strategy.query_info(token="123456789")
mock_arun.assert_called_once_with(command, check=True) mock_arun.assert_called_once_with(command, check=True)
def test_query_info_invalid_json(self): async def test_query_info_invalid_json(self):
strategy = UAClientUAInterfaceStrategy() strategy = UAClientUAInterfaceStrategy()
command = ( command = (
"ubuntu-advantage", "ubuntu-advantage",
@ -120,31 +119,31 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
mock_arun.return_value = CompletedProcess([], 0) mock_arun.return_value = CompletedProcess([], 0)
mock_arun.return_value.stdout = "invalid-json" mock_arun.return_value.stdout = "invalid-json"
with self.assertRaises(CheckSubscriptionError): with self.assertRaises(CheckSubscriptionError):
run_coro(strategy.query_info(token="123456789")) await strategy.query_info(token="123456789")
mock_arun.assert_called_once_with(command, check=True) mock_arun.assert_called_once_with(command, check=True)
class TestUAInterface(unittest.TestCase): class TestUAInterface(unittest.IsolatedAsyncioTestCase):
def test_mocked_get_activable_services(self): async def test_mocked_get_activable_services(self):
strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000) strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000)
interface = UAInterface(strategy) interface = UAInterface(strategy)
with self.assertRaises(InvalidTokenError): with self.assertRaises(InvalidTokenError):
run_coro(interface.get_activable_services(token="invalidToken")) await interface.get_activable_services(token="invalidToken")
# Tokens starting with "f" in dry-run mode simulate an "internal" # Tokens starting with "f" in dry-run mode simulate an "internal"
# error. # error.
with self.assertRaises(CheckSubscriptionError): with self.assertRaises(CheckSubscriptionError):
run_coro(interface.get_activable_services(token="failure")) await interface.get_activable_services(token="failure")
# Tokens starting with "x" is dry-run mode simulate an expired token. # Tokens starting with "x" is dry-run mode simulate an expired token.
with self.assertRaises(ExpiredTokenError): with self.assertRaises(ExpiredTokenError):
run_coro(interface.get_activable_services(token="xpiredToken")) await interface.get_activable_services(token="xpiredToken")
# Other tokens are considered valid in dry-run mode. # Other tokens are considered valid in dry-run mode.
run_coro(interface.get_activable_services(token="validToken")) await interface.get_activable_services(token="validToken")
def test_get_activable_services(self): async def test_get_activable_services(self):
# We use the standard strategy but don't actually run it # We use the standard strategy but don't actually run it
strategy = UAClientUAInterfaceStrategy() strategy = UAClientUAInterfaceStrategy()
interface = UAInterface(strategy) interface = UAInterface(strategy)
@ -185,8 +184,7 @@ class TestUAInterface(unittest.TestCase):
] ]
} }
interface.get_subscription = AsyncMock(return_value=subscription) interface.get_subscription = AsyncMock(return_value=subscription)
services = run_coro( services = await interface.get_activable_services(token="XXX")
interface.get_activable_services(token="XXX"))
self.assertIn(UbuntuProService( self.assertIn(UbuntuProService(
name="esm-infra", name="esm-infra",
@ -211,5 +209,4 @@ class TestUAInterface(unittest.TestCase):
# Test with "Z" suffix for the expiration date. # Test with "Z" suffix for the expiration date.
subscription["expires"] = "2035-12-31T00:00:00Z" subscription["expires"] = "2035-12-31T00:00:00Z"
services = run_coro( services = await interface.get_activable_services(token="XXX")
interface.get_activable_services(token="XXX"))

View File

@ -18,7 +18,6 @@ import unittest
from unittest.mock import patch, AsyncMock, Mock from unittest.mock import patch, AsyncMock, Mock
from subiquitycore.tests.mocks import make_app from subiquitycore.tests.mocks import make_app
from subiquitycore.tests.util import run_coro
from subiquity.server.ubuntu_drivers import ( from subiquity.server.ubuntu_drivers import (
UbuntuDriversInterface, UbuntuDriversInterface,
@ -28,7 +27,7 @@ from subiquity.server.ubuntu_drivers import (
) )
class TestUbuntuDriversInterface(unittest.TestCase): class TestUbuntuDriversInterface(unittest.IsolatedAsyncioTestCase):
def setUp(self): def setUp(self):
self.app = make_app() self.app = make_app()
@ -54,11 +53,11 @@ class TestUbuntuDriversInterface(unittest.TestCase):
@patch.multiple(UbuntuDriversInterface, __abstractmethods__=set()) @patch.multiple(UbuntuDriversInterface, __abstractmethods__=set())
@patch("subiquity.server.ubuntu_drivers.run_curtin_command") @patch("subiquity.server.ubuntu_drivers.run_curtin_command")
def test_install_drivers(self, mock_run_curtin_command): async def test_install_drivers(self, mock_run_curtin_command):
ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=False) ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=False)
run_coro(ubuntu_drivers.install_drivers( await ubuntu_drivers.install_drivers(
root_dir="/target", root_dir="/target",
context="installing third-party drivers")) context="installing third-party drivers")
mock_run_curtin_command.assert_called_once_with( mock_run_curtin_command.assert_called_once_with(
self.app, "installing third-party drivers", self.app, "installing third-party drivers",
"in-target", "-t", "/target", "in-target", "-t", "/target",
@ -91,18 +90,18 @@ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
["nvidia-driver-470", "nvidia-driver-510"]) ["nvidia-driver-470", "nvidia-driver-510"])
class TestUbuntuDriversClientInterface(unittest.TestCase): class TestUbuntuDriversClientInterface(unittest.IsolatedAsyncioTestCase):
def setUp(self): def setUp(self):
self.app = make_app() self.app = make_app()
self.ubuntu_drivers = UbuntuDriversClientInterface( self.ubuntu_drivers = UbuntuDriversClientInterface(
self.app, gpgpu=False) self.app, gpgpu=False)
def test_ensure_cmd_exists(self): async def test_ensure_cmd_exists(self):
with patch.object( with patch.object(
self.app, "command_runner", self.app, "command_runner",
create=True, new_callable=AsyncMock) as mock_runner: create=True, new_callable=AsyncMock) as mock_runner:
# On success # On success
run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target")) await self.ubuntu_drivers.ensure_cmd_exists("/target")
mock_runner.run.assert_called_once_with( mock_runner.run.assert_called_once_with(
[ [
"chroot", "/target", "chroot", "/target",
@ -115,17 +114,17 @@ class TestUbuntuDriversClientInterface(unittest.TestCase):
cmd=["sh", "-c", "command -v ubuntu-drivers"]) cmd=["sh", "-c", "command -v ubuntu-drivers"])
with self.assertRaises(CommandNotFoundError): with self.assertRaises(CommandNotFoundError):
run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target")) await self.ubuntu_drivers.ensure_cmd_exists("/target")
@patch("subiquity.server.ubuntu_drivers.run_curtin_command") @patch("subiquity.server.ubuntu_drivers.run_curtin_command")
def test_list_drivers(self, mock_run_curtin_command): async def test_list_drivers(self, mock_run_curtin_command):
# Make sure this gets decoded as utf-8. # Make sure this gets decoded as utf-8.
mock_run_curtin_command.return_value = Mock(stdout=b"""\ mock_run_curtin_command.return_value = Mock(stdout=b"""\
nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04 nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
""") """)
drivers = run_coro(self.ubuntu_drivers.list_drivers( drivers = await self.ubuntu_drivers.list_drivers(
root_dir="/target", root_dir="/target",
context="listing third-party drivers")) context="listing third-party drivers")
mock_run_curtin_command.assert_called_once_with( mock_run_curtin_command.assert_called_once_with(
self.app, "listing third-party drivers", self.app, "listing third-party drivers",
@ -137,15 +136,15 @@ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
self.assertEqual(drivers, ["nvidia-driver-510"]) self.assertEqual(drivers, ["nvidia-driver-510"])
class TestUbuntuDriversRunDriversInterface(unittest.TestCase): class TestUbuntuDriversRunDriversInterface(unittest.IsolatedAsyncioTestCase):
def setUp(self): def setUp(self):
self.app = make_app() self.app = make_app()
self.ubuntu_drivers = UbuntuDriversRunDriversInterface( self.ubuntu_drivers = UbuntuDriversRunDriversInterface(
self.app, gpgpu=False) self.app, gpgpu=False)
@patch("subiquity.server.ubuntu_drivers.arun_command") @patch("subiquity.server.ubuntu_drivers.arun_command")
def test_ensure_cmd_exists(self, mock_arun_command): async def test_ensure_cmd_exists(self, mock_arun_command):
run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target")) await self.ubuntu_drivers.ensure_cmd_exists("/target")
mock_arun_command.assert_called_once_with( mock_arun_command.assert_called_once_with(
["sh", "-c", "command -v ubuntu-drivers"], ["sh", "-c", "command -v ubuntu-drivers"],
check=True) check=True)

View File

@ -7,7 +7,6 @@ from functools import wraps
import json import json
import os import os
import tempfile import tempfile
import unittest
from unittest.mock import patch from unittest.mock import patch
from urllib.parse import unquote from urllib.parse import unquote
@ -134,7 +133,7 @@ class Server(Client):
pass pass
class TestAPI(unittest.IsolatedAsyncioTestCase, SubiTestCase): class TestAPI(SubiTestCase):
class _MachineConfig(os.PathLike): class _MachineConfig(os.PathLike):
def __init__(self, outer, path): def __init__(self, outer, path):
self.outer = outer self.outer = outer

View File

@ -20,7 +20,6 @@ from subiquity.models.timezone import TimeZoneModel
from subiquity.server.controllers.timezone import TimeZoneController from subiquity.server.controllers.timezone import TimeZoneController
from subiquitycore.tests import SubiTestCase from subiquitycore.tests import SubiTestCase
from subiquitycore.tests.mocks import make_app from subiquitycore.tests.mocks import make_app
from subiquitycore.tests.util import run_coro
class MockGeoIP: class MockGeoIP:
@ -47,7 +46,7 @@ class TestTimeZoneController(SubiTestCase):
@mock.patch('subiquity.server.controllers.timezone.timedatectl_settz') @mock.patch('subiquity.server.controllers.timezone.timedatectl_settz')
@mock.patch('subiquity.server.controllers.timezone.timedatectl_gettz') @mock.patch('subiquity.server.controllers.timezone.timedatectl_gettz')
def test_good_tzs(self, tdc_gettz, tdc_settz): async def test_good_tzs(self, tdc_gettz, tdc_settz):
tdc_gettz.return_value = tz_utc tdc_gettz.return_value = tz_utc
goods = [ goods = [
# val - autoinstall value # val - autoinstall value
@ -77,7 +76,7 @@ class TestTimeZoneController(SubiTestCase):
self.tzc.model) self.tzc.model)
self.assertEqual(geoip, self.tzc.model.detect_with_geoip, self.assertEqual(geoip, self.tzc.model.detect_with_geoip,
self.tzc.model) self.tzc.model)
self.assertEqual(tz, run_coro(self.tzc.GET()), self.tzc.model) self.assertEqual(tz, await self.tzc.GET(), self.tzc.model)
cloudconfig = {} cloudconfig = {}
if self.tzc.model.should_set_tz: if self.tzc.model.should_set_tz:
cloudconfig = {'timezone': tz.timezone} cloudconfig = {'timezone': tz.timezone}
@ -104,7 +103,7 @@ class TestTimeZoneController(SubiTestCase):
self.assertEqual('sleep', subprocess_run.call_args.args[0][0]) self.assertEqual('sleep', subprocess_run.call_args.args[0][0])
@mock.patch('subiquity.server.controllers.timezone.timedatectl_settz') @mock.patch('subiquity.server.controllers.timezone.timedatectl_settz')
def test_get_tz_should_not_set(self, tdc_settz): async def test_get_tz_should_not_set(self, tdc_settz):
run_coro(self.tzc.GET()) await self.tzc.GET()
self.assertFalse(self.tzc.model.should_set_tz) self.assertFalse(self.tzc.model.should_set_tz)
tdc_settz.assert_not_called() tdc_settz.assert_not_called()

View File

@ -1,4 +1,18 @@
import asyncio # Copyright 2017-2022 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from unittest import mock from unittest import mock
@ -42,20 +56,6 @@ system_reserved = {
} }
def tearDownModule() -> None:
# Set empty loop policy, so that subsequent get_event_loop() returns a new
# loop. If there is no running event loop set, that function will return
# the result of `get_event_loop_policy().get_event_loop()` call. If there
# is a policy there must be a running loop. IsolatedAsyncioTestCase
# closes the loop during tear down, though. It doesn't touch the policy.
# By having it as None, it autoinits and the next tests run smoothly.
# Another approach would be set a new event_loop for the current policy on
# test fixture tearDown, as pytest-asyncio does.
# Either way we would prevent failure on tests that depend on [run_coro]
# (subiquitycore/tests/util.py), for instance.
asyncio.set_event_loop_policy(None)
class IdentityViewTests(unittest.IsolatedAsyncioTestCase): class IdentityViewTests(unittest.IsolatedAsyncioTestCase):
def make_view(self): def make_view(self):

View File

@ -3,10 +3,10 @@ import os
import shutil import shutil
import tempfile import tempfile
from unittest import TestCase from unittest import IsolatedAsyncioTestCase
class SubiTestCase(TestCase): class SubiTestCase(IsolatedAsyncioTestCase):
def tmp_dir(self, dir=None, cleanup=True): def tmp_dir(self, dir=None, cleanup=True):
# return a full path to a temporary directory that will be cleaned up. # return a full path to a temporary directory that will be cleaned up.
if dir is None: if dir is None:

View File

@ -13,14 +13,9 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import random import random
import string import string
def run_coro(coro):
return asyncio.get_event_loop().run_until_complete(coro)
def random_string(): def random_string():
return ''.join(random.choice(string.ascii_letters) for _ in range(8)) return ''.join(random.choice(string.ascii_letters) for _ in range(8))