From dffb15520392edfe428643686addccc0f2209527 Mon Sep 17 00:00:00 2001 From: Dan Bungert Date: Thu, 23 Jun 2022 17:00:45 -0600 Subject: [PATCH] tests: s/run_coro/IsolatedAsyncioTestCase/ Tests that use run_coro are at risk of being broken by the introduction of another test using IsolatedAsyncioTestCase. Switch over to only use IsolatedAsyncioTestCase. --- subiquity/common/api/tests/test_endtoend.py | 123 +++++++----------- subiquity/common/api/tests/test_server.py | 115 +++++++--------- subiquity/models/tests/test_subiquity.py | 8 +- .../controllers/tests/test_filesystem.py | 17 ++- subiquity/server/tests/test_apt.py | 16 +-- subiquity/server/tests/test_geoip.py | 38 ++---- .../server/tests/test_ubuntu_advantage.py | 53 ++++---- subiquity/server/tests/test_ubuntu_drivers.py | 29 ++--- subiquity/tests/api/test_api.py | 3 +- subiquity/tests/test_timezonecontroller.py | 9 +- subiquity/ui/views/tests/test_identity.py | 30 ++--- subiquitycore/tests/__init__.py | 4 +- subiquitycore/tests/util.py | 5 - 13 files changed, 181 insertions(+), 269 deletions(-) diff --git a/subiquity/common/api/tests/test_endtoend.py b/subiquity/common/api/tests/test_endtoend.py index 593a76e2..b611033b 100644 --- a/subiquity/common/api/tests/test_endtoend.py +++ b/subiquity/common/api/tests/test_endtoend.py @@ -21,8 +21,6 @@ import unittest import aiohttp from aiohttp import web -from subiquitycore.tests.util import run_coro - from subiquity.common.api.client import make_client from subiquity.common.api.defs import api, Payload @@ -46,9 +44,9 @@ async def makeE2EClient(api, impl, yield make_client(api, mr) -class TestEndToEnd(unittest.TestCase): +class TestEndToEnd(unittest.IsolatedAsyncioTestCase): - def test_simple(self): + async def test_simple(self): @api class API: def GET() -> str: ... @@ -57,13 +55,10 @@ class TestEndToEnd(unittest.TestCase): async def GET(self) -> str: return 'value' - async def run(): - async with makeE2EClient(API, Impl()) as client: - self.assertEqual(await client.GET(), 'value') + async with makeE2EClient(API, Impl()) as client: + self.assertEqual(await client.GET(), 'value') - run_coro(run()) - - def test_nested(self): + async def test_nested(self): @api class API: class endpoint: @@ -74,13 +69,10 @@ class TestEndToEnd(unittest.TestCase): async def endpoint_nested_GET(self) -> str: return 'value' - async def run(): - async with makeE2EClient(API, Impl()) as client: - self.assertEqual(await client.endpoint.nested.GET(), 'value') + async with makeE2EClient(API, Impl()) as client: + self.assertEqual(await client.endpoint.nested.GET(), 'value') - run_coro(run()) - - def test_args(self): + async def test_args(self): @api class API: def GET(arg1: str, arg2: str) -> str: ... @@ -89,14 +81,11 @@ class TestEndToEnd(unittest.TestCase): async def GET(self, arg1: str, arg2: str) -> str: return '{}+{}'.format(arg1, arg2) - async def run(): - async with makeE2EClient(API, Impl()) as client: - self.assertEqual( - await client.GET(arg1="A", arg2="B"), 'A+B') + async with makeE2EClient(API, Impl()) as client: + self.assertEqual( + await client.GET(arg1="A", arg2="B"), 'A+B') - run_coro(run()) - - def test_defaults(self): + async def test_defaults(self): @api class API: def GET(arg1: str, arg2: str = "arg2") -> str: ... @@ -105,16 +94,13 @@ class TestEndToEnd(unittest.TestCase): async def GET(self, arg1: str, arg2: str = "arg2") -> str: return '{}+{}'.format(arg1, arg2) - async def run(): - async with makeE2EClient(API, Impl()) as client: - self.assertEqual( - await client.GET(arg1="A", arg2="B"), 'A+B') - self.assertEqual( - await client.GET(arg1="A"), 'A+arg2') + async with makeE2EClient(API, Impl()) as client: + self.assertEqual( + await client.GET(arg1="A", arg2="B"), 'A+B') + self.assertEqual( + await client.GET(arg1="A"), 'A+arg2') - run_coro(run()) - - def test_post(self): + async def test_post(self): @api class API: def POST(data: Payload[dict]) -> str: ... @@ -123,14 +109,11 @@ class TestEndToEnd(unittest.TestCase): async def POST(self, data: dict) -> str: return data['key'] - async def run(): - async with makeE2EClient(API, Impl()) as client: - self.assertEqual( - await client.POST({'key': 'value'}), 'value') + async with makeE2EClient(API, Impl()) as client: + self.assertEqual( + await client.POST({'key': 'value'}), 'value') - run_coro(run()) - - def test_typed(self): + async def test_typed(self): @attr.s(auto_attribs=True) class In: @@ -149,14 +132,11 @@ class TestEndToEnd(unittest.TestCase): async def doubler_POST(self, data: In) -> Out: return Out(doubled=data.val*2) - async def run(): - async with makeE2EClient(API, Impl()) as client: - out = await client.doubler.POST(In(3)) - self.assertEqual(out.doubled, 6) + async with makeE2EClient(API, Impl()) as client: + out = await client.doubler.POST(In(3)) + self.assertEqual(out.doubled, 6) - run_coro(run()) - - def test_middleware(self): + async def test_middleware(self): @api class API: def GET() -> int: ... @@ -182,17 +162,14 @@ class TestEndToEnd(unittest.TestCase): raise Skip yield resp - async def run(): - async with makeE2EClient( - API, Impl(), - middlewares=[middleware], - make_request=custom_make_request) as client: - with self.assertRaises(Skip): - await client.GET() + async with makeE2EClient( + API, Impl(), + middlewares=[middleware], + make_request=custom_make_request) as client: + with self.assertRaises(Skip): + await client.GET() - run_coro(run()) - - def test_error(self): + async def test_error(self): @api class API: class good: @@ -208,16 +185,13 @@ class TestEndToEnd(unittest.TestCase): async def bad_GET(self, x: int) -> int: raise Exception("baz") - async def run(): - async with makeE2EClient(API, Impl()) as client: - r = await client.good.GET(2) - self.assertEqual(r, 3) - with self.assertRaises(aiohttp.ClientResponseError): - await client.bad.GET(2) + async with makeE2EClient(API, Impl()) as client: + r = await client.good.GET(2) + self.assertEqual(r, 3) + with self.assertRaises(aiohttp.ClientResponseError): + await client.bad.GET(2) - run_coro(run()) - - def test_error_middleware(self): + async def test_error_middleware(self): @api class API: class good: @@ -251,14 +225,11 @@ class TestEndToEnd(unittest.TestCase): raise Abort yield resp - async def run(): - async with makeE2EClient( - API, Impl(), - middlewares=[middleware], - make_request=custom_make_request) as client: - r = await client.good.GET(2) - self.assertEqual(r, 3) - with self.assertRaises(Abort): - await client.bad.GET(2) - - run_coro(run()) + async with makeE2EClient( + API, Impl(), + middlewares=[middleware], + make_request=custom_make_request) as client: + r = await client.good.GET(2) + self.assertEqual(r, 3) + with self.assertRaises(Abort): + await client.bad.GET(2) diff --git a/subiquity/common/api/tests/test_server.py b/subiquity/common/api/tests/test_server.py index 89d75413..d7cc3191 100644 --- a/subiquity/common/api/tests/test_server.py +++ b/subiquity/common/api/tests/test_server.py @@ -20,7 +20,6 @@ from aiohttp.test_utils import TestClient, TestServer from aiohttp import web from subiquitycore.context import Context -from subiquitycore.tests.util import run_coro from subiquity.common.api.defs import api, Payload from subiquity.common.api.server import ( @@ -56,7 +55,7 @@ async def makeTestClient(api, impl, middlewares=()): yield client -class TestBind(unittest.TestCase): +class TestBind(unittest.IsolatedAsyncioTestCase): async def assertResponse(self, coro, value): resp = await coro @@ -67,7 +66,7 @@ class TestBind(unittest.TestCase): self.assertEqual(resp.status, 200) self.assertEqual(await resp.json(), value) - def test_simple(self): + async def test_simple(self): @api class API: def GET() -> str: ... @@ -76,14 +75,11 @@ class TestBind(unittest.TestCase): async def GET(self) -> str: return 'value' - async def run(): - async with makeTestClient(API, Impl()) as client: - await self.assertResponse( - client.get("/"), 'value') + async with makeTestClient(API, Impl()) as client: + await self.assertResponse( + client.get("/"), 'value') - run_coro(run()) - - def test_nested(self): + async def test_nested(self): @api class API: class endpoint: @@ -94,14 +90,11 @@ class TestBind(unittest.TestCase): async def nested_get(self, request, context): return 'nested' - async def run(): - async with makeTestClient(API.endpoint, Impl()) as client: - await self.assertResponse( - client.get("/endpoint/nested"), 'nested') + async with makeTestClient(API.endpoint, Impl()) as client: + await self.assertResponse( + client.get("/endpoint/nested"), 'nested') - run_coro(run()) - - def test_args(self): + async def test_args(self): @api class API: def GET(arg: str): ... @@ -110,14 +103,11 @@ class TestBind(unittest.TestCase): async def GET(self, arg: str): return arg - async def run(): - async with makeTestClient(API, Impl()) as client: - await self.assertResponse( - client.get('/?arg="whut"'), 'whut') + async with makeTestClient(API, Impl()) as client: + await self.assertResponse( + client.get('/?arg="whut"'), 'whut') - run_coro(run()) - - def test_missing_argument(self): + async def test_missing_argument(self): @api class API: def GET(arg: str): ... @@ -126,19 +116,16 @@ class TestBind(unittest.TestCase): async def GET(self, arg: str): return arg - async def run(): - async with makeTestClient(API, Impl()) as client: - resp = await client.get('/') - self.assertEqual(resp.status, 500) - self.assertEqual(resp.headers['x-status'], 'error') - self.assertEqual(resp.headers['x-error-type'], 'TypeError') - self.assertEqual( - resp.headers['x-error-msg'], - 'missing required argument "arg"') + async with makeTestClient(API, Impl()) as client: + resp = await client.get('/') + self.assertEqual(resp.status, 500) + self.assertEqual(resp.headers['x-status'], 'error') + self.assertEqual(resp.headers['x-error-type'], 'TypeError') + self.assertEqual( + resp.headers['x-error-msg'], + 'missing required argument "arg"') - run_coro(run()) - - def test_error(self): + async def test_error(self): @api class API: def GET(): ... @@ -147,17 +134,14 @@ class TestBind(unittest.TestCase): async def GET(self): return 1/0 - async def run(): - async with makeTestClient(API, Impl()) as client: - resp = await client.get('/') - self.assertEqual(resp.status, 500) - self.assertEqual(resp.headers['x-status'], 'error') - self.assertEqual( - resp.headers['x-error-type'], 'ZeroDivisionError') + async with makeTestClient(API, Impl()) as client: + resp = await client.get('/') + self.assertEqual(resp.status, 500) + self.assertEqual(resp.headers['x-status'], 'error') + self.assertEqual( + resp.headers['x-error-type'], 'ZeroDivisionError') - run_coro(run()) - - def test_post(self): + async def test_post(self): @api class API: def POST(data: Payload[str]) -> str: ... @@ -166,13 +150,10 @@ class TestBind(unittest.TestCase): async def POST(self, data: str) -> str: return data - async def run(): - async with makeTestClient(API, Impl()) as client: - await self.assertResponse( - client.post("/", json='value'), - 'value') - - run_coro(run()) + async with makeTestClient(API, Impl()) as client: + await self.assertResponse( + client.post("/", json='value'), + 'value') def test_missing_method(self): @api @@ -201,7 +182,7 @@ class TestBind(unittest.TestCase): bind(app.router, API, Impl()) self.assertEqual(cm.exception.methname, "API.GET") - def test_middleware(self): + async def test_middleware(self): @web.middleware async def middleware(request, handler): @@ -219,16 +200,13 @@ class TestBind(unittest.TestCase): async def GET(self) -> str: return 1/0 - async def run(): - async with makeTestClient( - API, Impl(), middlewares=[middleware]) as client: - resp = await client.get("/") - self.assertEqual( - resp.headers['x-error-type'], 'ZeroDivisionError') + async with makeTestClient( + API, Impl(), middlewares=[middleware]) as client: + resp = await client.get("/") + self.assertEqual( + resp.headers['x-error-type'], 'ZeroDivisionError') - run_coro(run()) - - def test_controller_for_request(self): + async def test_controller_for_request(self): seen_controller = None @@ -249,12 +227,9 @@ class TestBind(unittest.TestCase): impl = Impl() - async def run(): - async with makeTestClient( - API.meth, impl, middlewares=[middleware]) as client: - resp = await client.get("/meth") - self.assertEqual(await resp.json(), '') - - run_coro(run()) + async with makeTestClient( + API.meth, impl, middlewares=[middleware]) as client: + resp = await client.get("/meth") + self.assertEqual(await resp.json(), '') self.assertIs(impl, seen_controller) diff --git a/subiquity/models/tests/test_subiquity.py b/subiquity/models/tests/test_subiquity.py index c92b761b..065e243e 100644 --- a/subiquity/models/tests/test_subiquity.py +++ b/subiquity/models/tests/test_subiquity.py @@ -19,7 +19,6 @@ from unittest import mock import yaml from subiquitycore.pubsub import MessageHub -from subiquitycore.tests.util import run_coro from subiquity.common.types import IdentityData from subiquity.models.subiquity import ModelNames, SubiquityModel @@ -63,7 +62,7 @@ class TestModelNames(unittest.TestCase): self.assertEqual(model_names.all(), {'a', 'b', 'c'}) -class TestSubiquityModel(unittest.TestCase): +class TestSubiquityModel(unittest.IsolatedAsyncioTestCase): def writtenFiles(self, config): for k, v in config.get('write_files', {}).items(): @@ -114,7 +113,7 @@ class TestSubiquityModel(unittest.TestCase): cur = cur[component] self.fail("config has value {} for {}".format(cur, path)) - async def _test_configure(self): + async def test_configure(self): hub = MessageHub() model = SubiquityModel( 'test', hub, ModelNames({'a', 'b'}), ModelNames(set())) @@ -124,9 +123,6 @@ class TestSubiquityModel(unittest.TestCase): await hub.abroadcast((InstallerChannels.CONFIGURED, 'b')) self.assertTrue(model._install_event.is_set()) - def test_configure(self): - run_coro(self._test_configure()) - def make_model(self): return SubiquityModel( 'test', MessageHub(), INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES) diff --git a/subiquity/server/controllers/tests/test_filesystem.py b/subiquity/server/controllers/tests/test_filesystem.py index d52af07c..a117975b 100644 --- a/subiquity/server/controllers/tests/test_filesystem.py +++ b/subiquity/server/controllers/tests/test_filesystem.py @@ -13,13 +13,12 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from unittest import mock, TestCase +from unittest import mock, TestCase, IsolatedAsyncioTestCase from parameterized import parameterized from subiquity.server.controllers.filesystem import FilesystemController -from subiquitycore.tests.util import run_coro from subiquitycore.tests.mocks import make_app from subiquity.common.types import Bootloader from subiquity.models.tests.test_filesystem import ( @@ -28,7 +27,7 @@ from subiquity.models.tests.test_filesystem import ( ) -class TestSubiquityControllerFilesystem(TestCase): +class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase): def setUp(self): self.app = make_app() self.app.opts.bootloader = 'UEFI' @@ -38,20 +37,20 @@ class TestSubiquityControllerFilesystem(TestCase): self.fsc = FilesystemController(app=self.app) self.fsc._configured = True - def test_probe_restricted(self): - run_coro(self.fsc._probe_once(context=None, restricted=True)) + async def test_probe_restricted(self): + await self.fsc._probe_once(context=None, restricted=True) self.app.prober.get_storage.assert_called_with({'blockdev'}) - def test_probe_os_prober_false(self): + async def test_probe_os_prober_false(self): self.app.opts.use_os_prober = False - run_coro(self.fsc._probe_once(context=None, restricted=False)) + await self.fsc._probe_once(context=None, restricted=False) actual = self.app.prober.get_storage.call_args.args[0] self.assertTrue({'defaults'} <= actual) self.assertNotIn('os', actual) - def test_probe_os_prober_true(self): + async def test_probe_os_prober_true(self): self.app.opts.use_os_prober = True - run_coro(self.fsc._probe_once(context=None, restricted=False)) + await self.fsc._probe_once(context=None, restricted=False) actual = self.app.prober.get_storage.call_args.args[0] self.assertTrue({'defaults', 'os'} <= actual) diff --git a/subiquity/server/tests/test_apt.py b/subiquity/server/tests/test_apt.py index 44a73d81..0a2e4960 100644 --- a/subiquity/server/tests/test_apt.py +++ b/subiquity/server/tests/test_apt.py @@ -17,7 +17,6 @@ from unittest.mock import Mock, patch, AsyncMock from subiquitycore.tests import SubiTestCase from subiquitycore.tests.mocks import make_app -from subiquitycore.tests.util import run_coro from subiquity.server.apt import ( AptConfigurer, Mountpoint, @@ -51,14 +50,14 @@ class TestAptConfigurer(SubiTestCase): expected['apt']['https_proxy'] = proxy self.assertEqual(expected, self.configurer.apt_config()) - def test_mount_unmount(self): + async def test_mount_unmount(self): # Make sure we can unmount something that we mounted before. with patch.object(self.app, "command_runner", create=True, new_callable=AsyncMock): - m = run_coro(self.configurer.mount("/dev/cdrom", "/target")) - run_coro(self.configurer.unmount(m)) + m = await self.configurer.mount("/dev/cdrom", "/target") + await self.configurer.unmount(m) - def test_overlay(self): + async def test_overlay(self): self.configurer.install_tree = OverlayMountpoint( upperdir="upperdir-install-tree", lowers=["lowers1-install-tree"], @@ -71,13 +70,10 @@ class TestAptConfigurer(SubiTestCase): ) self.source = "source" - async def coro(): - async with self.configurer.overlay(): - pass - with patch.object(self.app, "command_runner", create=True, new_callable=AsyncMock): - run_coro(coro()) + async with self.configurer.overlay(): + pass class TestLowerDirFor(SubiTestCase): diff --git a/subiquity/server/tests/test_geoip.py b/subiquity/server/tests/test_geoip.py index 835292e8..5807331b 100644 --- a/subiquity/server/tests/test_geoip.py +++ b/subiquity/server/tests/test_geoip.py @@ -17,7 +17,6 @@ from aioresponses import aioresponses from subiquitycore.tests import SubiTestCase from subiquitycore.tests.mocks import make_app -from subiquitycore.tests.util import run_coro from subiquity.server.geoip import ( GeoIP, HTTPGeoIPStrategy, @@ -48,16 +47,13 @@ empty_cc = '' class TestGeoIP(SubiTestCase): - def setUp(self): + async def asyncSetUp(self): strategy = HTTPGeoIPStrategy() self.geoip = GeoIP(make_app(), strategy) - async def fn(): - self.assertTrue(await self.geoip.lookup()) - with aioresponses() as mocked: mocked.get("https://geoip.ubuntu.com/lookup", body=xml) - run_coro(fn()) + self.assertTrue(await self.geoip.lookup()) def test_countrycode(self): self.assertEqual("us", self.geoip.countrycode) @@ -71,42 +67,32 @@ class TestGeoIPBadData(SubiTestCase): strategy = HTTPGeoIPStrategy() self.geoip = GeoIP(make_app(), strategy) - def test_partial_reponse(self): - async def fn(): - self.assertFalse(await self.geoip.lookup()) + async def test_partial_reponse(self): with aioresponses() as mocked: mocked.get("https://geoip.ubuntu.com/lookup", body=partial) - run_coro(fn()) - - def test_incomplete(self): - async def fn(): self.assertFalse(await self.geoip.lookup()) + + async def test_incomplete(self): with aioresponses() as mocked: mocked.get("https://geoip.ubuntu.com/lookup", body=incomplete) - run_coro(fn()) + self.assertFalse(await self.geoip.lookup()) self.assertIsNone(self.geoip.countrycode) self.assertIsNone(self.geoip.timezone) - def test_long_cc(self): - async def fn(): - self.assertFalse(await self.geoip.lookup()) + async def test_long_cc(self): with aioresponses() as mocked: mocked.get("https://geoip.ubuntu.com/lookup", body=long_cc) - run_coro(fn()) + self.assertFalse(await self.geoip.lookup()) self.assertIsNone(self.geoip.countrycode) - def test_empty_cc(self): - async def fn(): - self.assertFalse(await self.geoip.lookup()) + async def test_empty_cc(self): with aioresponses() as mocked: mocked.get("https://geoip.ubuntu.com/lookup", body=empty_cc) - run_coro(fn()) + self.assertFalse(await self.geoip.lookup()) self.assertIsNone(self.geoip.countrycode) - def test_empty_tz(self): - async def fn(): - self.assertFalse(await self.geoip.lookup()) + async def test_empty_tz(self): with aioresponses() as mocked: mocked.get("https://geoip.ubuntu.com/lookup", body=empty_tz) - run_coro(fn()) + self.assertFalse(await self.geoip.lookup()) self.assertIsNone(self.geoip.timezone) diff --git a/subiquity/server/tests/test_ubuntu_advantage.py b/subiquity/server/tests/test_ubuntu_advantage.py index e6d9a217..4df6a76c 100644 --- a/subiquity/server/tests/test_ubuntu_advantage.py +++ b/subiquity/server/tests/test_ubuntu_advantage.py @@ -26,37 +26,36 @@ from subiquity.server.ubuntu_advantage import ( MockedUAInterfaceStrategy, UAClientUAInterfaceStrategy, ) -from subiquitycore.tests.util import run_coro -class TestMockedUAInterfaceStrategy(unittest.TestCase): +class TestMockedUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase): def setUp(self): self.strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000) - def test_query_info_invalid(self): + async def test_query_info_invalid(self): # Tokens starting with "i" in dry-run mode cause the token to be # reported as invalid. with self.assertRaises(InvalidTokenError): - run_coro(self.strategy.query_info(token="invalidToken")) + await self.strategy.query_info(token="invalidToken") - def test_query_info_failure(self): + async def test_query_info_failure(self): # Tokens starting with "f" in dry-run mode simulate an "internal" # error. with self.assertRaises(CheckSubscriptionError): - run_coro(self.strategy.query_info(token="failure")) + await self.strategy.query_info(token="failure") - def test_query_info_expired(self): + async def test_query_info_expired(self): # Tokens starting with "x" is dry-run mode simulate an expired token. - info = run_coro(self.strategy.query_info(token="xpiredToken")) + info = await self.strategy.query_info(token="xpiredToken") self.assertEqual(info["expires"], "2010-12-31T00:00:00+00:00") - def test_query_info_valid(self): + async def test_query_info_valid(self): # Other tokens are considered valid in dry-run mode. - info = run_coro(self.strategy.query_info(token="validToken")) + info = await self.strategy.query_info(token="validToken") self.assertEqual(info["expires"], "2035-12-31T00:00:00+00:00") -class TestUAClientUAInterfaceStrategy(unittest.TestCase): +class TestUAClientUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase): arun_command_sym = "subiquity.server.ubuntu_advantage.utils.arun_command" def test_init(self): @@ -75,7 +74,7 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase): self.assertEqual(strategy.executable, ["python3", "/usr/bin/ubuntu-advantage"]) - def test_query_info_succeeded(self): + async def test_query_info_succeeded(self): strategy = UAClientUAInterfaceStrategy() command = ( "ubuntu-advantage", @@ -87,10 +86,10 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase): with patch(self.arun_command_sym) as mock_arun: mock_arun.return_value = CompletedProcess([], 0) mock_arun.return_value.stdout = "{}" - run_coro(strategy.query_info(token="123456789")) + await strategy.query_info(token="123456789") mock_arun.assert_called_once_with(command, check=True) - def test_query_info_failed(self): + async def test_query_info_failed(self): strategy = UAClientUAInterfaceStrategy() command = ( "ubuntu-advantage", @@ -104,10 +103,10 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase): cmd=command) mock_arun.return_value.stdout = "{}" with self.assertRaises(CheckSubscriptionError): - run_coro(strategy.query_info(token="123456789")) + await strategy.query_info(token="123456789") mock_arun.assert_called_once_with(command, check=True) - def test_query_info_invalid_json(self): + async def test_query_info_invalid_json(self): strategy = UAClientUAInterfaceStrategy() command = ( "ubuntu-advantage", @@ -120,31 +119,31 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase): mock_arun.return_value = CompletedProcess([], 0) mock_arun.return_value.stdout = "invalid-json" with self.assertRaises(CheckSubscriptionError): - run_coro(strategy.query_info(token="123456789")) + await strategy.query_info(token="123456789") mock_arun.assert_called_once_with(command, check=True) -class TestUAInterface(unittest.TestCase): +class TestUAInterface(unittest.IsolatedAsyncioTestCase): - def test_mocked_get_activable_services(self): + async def test_mocked_get_activable_services(self): strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000) interface = UAInterface(strategy) with self.assertRaises(InvalidTokenError): - run_coro(interface.get_activable_services(token="invalidToken")) + await interface.get_activable_services(token="invalidToken") # Tokens starting with "f" in dry-run mode simulate an "internal" # error. with self.assertRaises(CheckSubscriptionError): - run_coro(interface.get_activable_services(token="failure")) + await interface.get_activable_services(token="failure") # Tokens starting with "x" is dry-run mode simulate an expired token. with self.assertRaises(ExpiredTokenError): - run_coro(interface.get_activable_services(token="xpiredToken")) + await interface.get_activable_services(token="xpiredToken") # Other tokens are considered valid in dry-run mode. - run_coro(interface.get_activable_services(token="validToken")) + await interface.get_activable_services(token="validToken") - def test_get_activable_services(self): + async def test_get_activable_services(self): # We use the standard strategy but don't actually run it strategy = UAClientUAInterfaceStrategy() interface = UAInterface(strategy) @@ -185,8 +184,7 @@ class TestUAInterface(unittest.TestCase): ] } interface.get_subscription = AsyncMock(return_value=subscription) - services = run_coro( - interface.get_activable_services(token="XXX")) + services = await interface.get_activable_services(token="XXX") self.assertIn(UbuntuProService( name="esm-infra", @@ -211,5 +209,4 @@ class TestUAInterface(unittest.TestCase): # Test with "Z" suffix for the expiration date. subscription["expires"] = "2035-12-31T00:00:00Z" - services = run_coro( - interface.get_activable_services(token="XXX")) + services = await interface.get_activable_services(token="XXX") diff --git a/subiquity/server/tests/test_ubuntu_drivers.py b/subiquity/server/tests/test_ubuntu_drivers.py index 4a658e6f..91c917bf 100644 --- a/subiquity/server/tests/test_ubuntu_drivers.py +++ b/subiquity/server/tests/test_ubuntu_drivers.py @@ -18,7 +18,6 @@ import unittest from unittest.mock import patch, AsyncMock, Mock from subiquitycore.tests.mocks import make_app -from subiquitycore.tests.util import run_coro from subiquity.server.ubuntu_drivers import ( UbuntuDriversInterface, @@ -28,7 +27,7 @@ from subiquity.server.ubuntu_drivers import ( ) -class TestUbuntuDriversInterface(unittest.TestCase): +class TestUbuntuDriversInterface(unittest.IsolatedAsyncioTestCase): def setUp(self): self.app = make_app() @@ -54,11 +53,11 @@ class TestUbuntuDriversInterface(unittest.TestCase): @patch.multiple(UbuntuDriversInterface, __abstractmethods__=set()) @patch("subiquity.server.ubuntu_drivers.run_curtin_command") - def test_install_drivers(self, mock_run_curtin_command): + async def test_install_drivers(self, mock_run_curtin_command): ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=False) - run_coro(ubuntu_drivers.install_drivers( + await ubuntu_drivers.install_drivers( root_dir="/target", - context="installing third-party drivers")) + context="installing third-party drivers") mock_run_curtin_command.assert_called_once_with( self.app, "installing third-party drivers", "in-target", "-t", "/target", @@ -91,18 +90,18 @@ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04 ["nvidia-driver-470", "nvidia-driver-510"]) -class TestUbuntuDriversClientInterface(unittest.TestCase): +class TestUbuntuDriversClientInterface(unittest.IsolatedAsyncioTestCase): def setUp(self): self.app = make_app() self.ubuntu_drivers = UbuntuDriversClientInterface( self.app, gpgpu=False) - def test_ensure_cmd_exists(self): + async def test_ensure_cmd_exists(self): with patch.object( self.app, "command_runner", create=True, new_callable=AsyncMock) as mock_runner: # On success - run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target")) + await self.ubuntu_drivers.ensure_cmd_exists("/target") mock_runner.run.assert_called_once_with( [ "chroot", "/target", @@ -115,17 +114,17 @@ class TestUbuntuDriversClientInterface(unittest.TestCase): cmd=["sh", "-c", "command -v ubuntu-drivers"]) with self.assertRaises(CommandNotFoundError): - run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target")) + await self.ubuntu_drivers.ensure_cmd_exists("/target") @patch("subiquity.server.ubuntu_drivers.run_curtin_command") - def test_list_drivers(self, mock_run_curtin_command): + async def test_list_drivers(self, mock_run_curtin_command): # Make sure this gets decoded as utf-8. mock_run_curtin_command.return_value = Mock(stdout=b"""\ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04 """) - drivers = run_coro(self.ubuntu_drivers.list_drivers( + drivers = await self.ubuntu_drivers.list_drivers( root_dir="/target", - context="listing third-party drivers")) + context="listing third-party drivers") mock_run_curtin_command.assert_called_once_with( self.app, "listing third-party drivers", @@ -137,15 +136,15 @@ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04 self.assertEqual(drivers, ["nvidia-driver-510"]) -class TestUbuntuDriversRunDriversInterface(unittest.TestCase): +class TestUbuntuDriversRunDriversInterface(unittest.IsolatedAsyncioTestCase): def setUp(self): self.app = make_app() self.ubuntu_drivers = UbuntuDriversRunDriversInterface( self.app, gpgpu=False) @patch("subiquity.server.ubuntu_drivers.arun_command") - def test_ensure_cmd_exists(self, mock_arun_command): - run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target")) + async def test_ensure_cmd_exists(self, mock_arun_command): + await self.ubuntu_drivers.ensure_cmd_exists("/target") mock_arun_command.assert_called_once_with( ["sh", "-c", "command -v ubuntu-drivers"], check=True) diff --git a/subiquity/tests/api/test_api.py b/subiquity/tests/api/test_api.py index f89c29cd..b495aa3d 100755 --- a/subiquity/tests/api/test_api.py +++ b/subiquity/tests/api/test_api.py @@ -7,7 +7,6 @@ from functools import wraps import json import os import tempfile -import unittest from unittest.mock import patch from urllib.parse import unquote @@ -134,7 +133,7 @@ class Server(Client): pass -class TestAPI(unittest.IsolatedAsyncioTestCase, SubiTestCase): +class TestAPI(SubiTestCase): class _MachineConfig(os.PathLike): def __init__(self, outer, path): self.outer = outer diff --git a/subiquity/tests/test_timezonecontroller.py b/subiquity/tests/test_timezonecontroller.py index 050fcd2e..08c2dc4a 100644 --- a/subiquity/tests/test_timezonecontroller.py +++ b/subiquity/tests/test_timezonecontroller.py @@ -20,7 +20,6 @@ from subiquity.models.timezone import TimeZoneModel from subiquity.server.controllers.timezone import TimeZoneController from subiquitycore.tests import SubiTestCase from subiquitycore.tests.mocks import make_app -from subiquitycore.tests.util import run_coro class MockGeoIP: @@ -47,7 +46,7 @@ class TestTimeZoneController(SubiTestCase): @mock.patch('subiquity.server.controllers.timezone.timedatectl_settz') @mock.patch('subiquity.server.controllers.timezone.timedatectl_gettz') - def test_good_tzs(self, tdc_gettz, tdc_settz): + async def test_good_tzs(self, tdc_gettz, tdc_settz): tdc_gettz.return_value = tz_utc goods = [ # val - autoinstall value @@ -77,7 +76,7 @@ class TestTimeZoneController(SubiTestCase): self.tzc.model) self.assertEqual(geoip, self.tzc.model.detect_with_geoip, self.tzc.model) - self.assertEqual(tz, run_coro(self.tzc.GET()), self.tzc.model) + self.assertEqual(tz, await self.tzc.GET(), self.tzc.model) cloudconfig = {} if self.tzc.model.should_set_tz: cloudconfig = {'timezone': tz.timezone} @@ -104,7 +103,7 @@ class TestTimeZoneController(SubiTestCase): self.assertEqual('sleep', subprocess_run.call_args.args[0][0]) @mock.patch('subiquity.server.controllers.timezone.timedatectl_settz') - def test_get_tz_should_not_set(self, tdc_settz): - run_coro(self.tzc.GET()) + async def test_get_tz_should_not_set(self, tdc_settz): + await self.tzc.GET() self.assertFalse(self.tzc.model.should_set_tz) tdc_settz.assert_not_called() diff --git a/subiquity/ui/views/tests/test_identity.py b/subiquity/ui/views/tests/test_identity.py index ccfb8f18..62bad17f 100644 --- a/subiquity/ui/views/tests/test_identity.py +++ b/subiquity/ui/views/tests/test_identity.py @@ -1,4 +1,18 @@ -import asyncio +# Copyright 2017-2022 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + import unittest from unittest import mock @@ -42,20 +56,6 @@ system_reserved = { } -def tearDownModule() -> None: - # Set empty loop policy, so that subsequent get_event_loop() returns a new - # loop. If there is no running event loop set, that function will return - # the result of `get_event_loop_policy().get_event_loop()` call. If there - # is a policy there must be a running loop. IsolatedAsyncioTestCase - # closes the loop during tear down, though. It doesn't touch the policy. - # By having it as None, it autoinits and the next tests run smoothly. - # Another approach would be set a new event_loop for the current policy on - # test fixture tearDown, as pytest-asyncio does. - # Either way we would prevent failure on tests that depend on [run_coro] - # (subiquitycore/tests/util.py), for instance. - asyncio.set_event_loop_policy(None) - - class IdentityViewTests(unittest.IsolatedAsyncioTestCase): def make_view(self): diff --git a/subiquitycore/tests/__init__.py b/subiquitycore/tests/__init__.py index 0d117de5..4d6a5870 100644 --- a/subiquitycore/tests/__init__.py +++ b/subiquitycore/tests/__init__.py @@ -3,10 +3,10 @@ import os import shutil import tempfile -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase -class SubiTestCase(TestCase): +class SubiTestCase(IsolatedAsyncioTestCase): def tmp_dir(self, dir=None, cleanup=True): # return a full path to a temporary directory that will be cleaned up. if dir is None: diff --git a/subiquitycore/tests/util.py b/subiquitycore/tests/util.py index 3deab2d8..63cef422 100644 --- a/subiquitycore/tests/util.py +++ b/subiquitycore/tests/util.py @@ -13,14 +13,9 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import asyncio import random import string -def run_coro(coro): - return asyncio.get_event_loop().run_until_complete(coro) - - def random_string(): return ''.join(random.choice(string.ascii_letters) for _ in range(8))