diff --git a/subiquity/common/api/tests/test_endtoend.py b/subiquity/common/api/tests/test_endtoend.py
index 593a76e2..b611033b 100644
--- a/subiquity/common/api/tests/test_endtoend.py
+++ b/subiquity/common/api/tests/test_endtoend.py
@@ -21,8 +21,6 @@ import unittest
import aiohttp
from aiohttp import web
-from subiquitycore.tests.util import run_coro
-
from subiquity.common.api.client import make_client
from subiquity.common.api.defs import api, Payload
@@ -46,9 +44,9 @@ async def makeE2EClient(api, impl,
yield make_client(api, mr)
-class TestEndToEnd(unittest.TestCase):
+class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
- def test_simple(self):
+ async def test_simple(self):
@api
class API:
def GET() -> str: ...
@@ -57,13 +55,10 @@ class TestEndToEnd(unittest.TestCase):
async def GET(self) -> str:
return 'value'
- async def run():
- async with makeE2EClient(API, Impl()) as client:
- self.assertEqual(await client.GET(), 'value')
+ async with makeE2EClient(API, Impl()) as client:
+ self.assertEqual(await client.GET(), 'value')
- run_coro(run())
-
- def test_nested(self):
+ async def test_nested(self):
@api
class API:
class endpoint:
@@ -74,13 +69,10 @@ class TestEndToEnd(unittest.TestCase):
async def endpoint_nested_GET(self) -> str:
return 'value'
- async def run():
- async with makeE2EClient(API, Impl()) as client:
- self.assertEqual(await client.endpoint.nested.GET(), 'value')
+ async with makeE2EClient(API, Impl()) as client:
+ self.assertEqual(await client.endpoint.nested.GET(), 'value')
- run_coro(run())
-
- def test_args(self):
+ async def test_args(self):
@api
class API:
def GET(arg1: str, arg2: str) -> str: ...
@@ -89,14 +81,11 @@ class TestEndToEnd(unittest.TestCase):
async def GET(self, arg1: str, arg2: str) -> str:
return '{}+{}'.format(arg1, arg2)
- async def run():
- async with makeE2EClient(API, Impl()) as client:
- self.assertEqual(
- await client.GET(arg1="A", arg2="B"), 'A+B')
+ async with makeE2EClient(API, Impl()) as client:
+ self.assertEqual(
+ await client.GET(arg1="A", arg2="B"), 'A+B')
- run_coro(run())
-
- def test_defaults(self):
+ async def test_defaults(self):
@api
class API:
def GET(arg1: str, arg2: str = "arg2") -> str: ...
@@ -105,16 +94,13 @@ class TestEndToEnd(unittest.TestCase):
async def GET(self, arg1: str, arg2: str = "arg2") -> str:
return '{}+{}'.format(arg1, arg2)
- async def run():
- async with makeE2EClient(API, Impl()) as client:
- self.assertEqual(
- await client.GET(arg1="A", arg2="B"), 'A+B')
- self.assertEqual(
- await client.GET(arg1="A"), 'A+arg2')
+ async with makeE2EClient(API, Impl()) as client:
+ self.assertEqual(
+ await client.GET(arg1="A", arg2="B"), 'A+B')
+ self.assertEqual(
+ await client.GET(arg1="A"), 'A+arg2')
- run_coro(run())
-
- def test_post(self):
+ async def test_post(self):
@api
class API:
def POST(data: Payload[dict]) -> str: ...
@@ -123,14 +109,11 @@ class TestEndToEnd(unittest.TestCase):
async def POST(self, data: dict) -> str:
return data['key']
- async def run():
- async with makeE2EClient(API, Impl()) as client:
- self.assertEqual(
- await client.POST({'key': 'value'}), 'value')
+ async with makeE2EClient(API, Impl()) as client:
+ self.assertEqual(
+ await client.POST({'key': 'value'}), 'value')
- run_coro(run())
-
- def test_typed(self):
+ async def test_typed(self):
@attr.s(auto_attribs=True)
class In:
@@ -149,14 +132,11 @@ class TestEndToEnd(unittest.TestCase):
async def doubler_POST(self, data: In) -> Out:
return Out(doubled=data.val*2)
- async def run():
- async with makeE2EClient(API, Impl()) as client:
- out = await client.doubler.POST(In(3))
- self.assertEqual(out.doubled, 6)
+ async with makeE2EClient(API, Impl()) as client:
+ out = await client.doubler.POST(In(3))
+ self.assertEqual(out.doubled, 6)
- run_coro(run())
-
- def test_middleware(self):
+ async def test_middleware(self):
@api
class API:
def GET() -> int: ...
@@ -182,17 +162,14 @@ class TestEndToEnd(unittest.TestCase):
raise Skip
yield resp
- async def run():
- async with makeE2EClient(
- API, Impl(),
- middlewares=[middleware],
- make_request=custom_make_request) as client:
- with self.assertRaises(Skip):
- await client.GET()
+ async with makeE2EClient(
+ API, Impl(),
+ middlewares=[middleware],
+ make_request=custom_make_request) as client:
+ with self.assertRaises(Skip):
+ await client.GET()
- run_coro(run())
-
- def test_error(self):
+ async def test_error(self):
@api
class API:
class good:
@@ -208,16 +185,13 @@ class TestEndToEnd(unittest.TestCase):
async def bad_GET(self, x: int) -> int:
raise Exception("baz")
- async def run():
- async with makeE2EClient(API, Impl()) as client:
- r = await client.good.GET(2)
- self.assertEqual(r, 3)
- with self.assertRaises(aiohttp.ClientResponseError):
- await client.bad.GET(2)
+ async with makeE2EClient(API, Impl()) as client:
+ r = await client.good.GET(2)
+ self.assertEqual(r, 3)
+ with self.assertRaises(aiohttp.ClientResponseError):
+ await client.bad.GET(2)
- run_coro(run())
-
- def test_error_middleware(self):
+ async def test_error_middleware(self):
@api
class API:
class good:
@@ -251,14 +225,11 @@ class TestEndToEnd(unittest.TestCase):
raise Abort
yield resp
- async def run():
- async with makeE2EClient(
- API, Impl(),
- middlewares=[middleware],
- make_request=custom_make_request) as client:
- r = await client.good.GET(2)
- self.assertEqual(r, 3)
- with self.assertRaises(Abort):
- await client.bad.GET(2)
-
- run_coro(run())
+ async with makeE2EClient(
+ API, Impl(),
+ middlewares=[middleware],
+ make_request=custom_make_request) as client:
+ r = await client.good.GET(2)
+ self.assertEqual(r, 3)
+ with self.assertRaises(Abort):
+ await client.bad.GET(2)
diff --git a/subiquity/common/api/tests/test_server.py b/subiquity/common/api/tests/test_server.py
index 89d75413..d7cc3191 100644
--- a/subiquity/common/api/tests/test_server.py
+++ b/subiquity/common/api/tests/test_server.py
@@ -20,7 +20,6 @@ from aiohttp.test_utils import TestClient, TestServer
from aiohttp import web
from subiquitycore.context import Context
-from subiquitycore.tests.util import run_coro
from subiquity.common.api.defs import api, Payload
from subiquity.common.api.server import (
@@ -56,7 +55,7 @@ async def makeTestClient(api, impl, middlewares=()):
yield client
-class TestBind(unittest.TestCase):
+class TestBind(unittest.IsolatedAsyncioTestCase):
async def assertResponse(self, coro, value):
resp = await coro
@@ -67,7 +66,7 @@ class TestBind(unittest.TestCase):
self.assertEqual(resp.status, 200)
self.assertEqual(await resp.json(), value)
- def test_simple(self):
+ async def test_simple(self):
@api
class API:
def GET() -> str: ...
@@ -76,14 +75,11 @@ class TestBind(unittest.TestCase):
async def GET(self) -> str:
return 'value'
- async def run():
- async with makeTestClient(API, Impl()) as client:
- await self.assertResponse(
- client.get("/"), 'value')
+ async with makeTestClient(API, Impl()) as client:
+ await self.assertResponse(
+ client.get("/"), 'value')
- run_coro(run())
-
- def test_nested(self):
+ async def test_nested(self):
@api
class API:
class endpoint:
@@ -94,14 +90,11 @@ class TestBind(unittest.TestCase):
async def nested_get(self, request, context):
return 'nested'
- async def run():
- async with makeTestClient(API.endpoint, Impl()) as client:
- await self.assertResponse(
- client.get("/endpoint/nested"), 'nested')
+ async with makeTestClient(API.endpoint, Impl()) as client:
+ await self.assertResponse(
+ client.get("/endpoint/nested"), 'nested')
- run_coro(run())
-
- def test_args(self):
+ async def test_args(self):
@api
class API:
def GET(arg: str): ...
@@ -110,14 +103,11 @@ class TestBind(unittest.TestCase):
async def GET(self, arg: str):
return arg
- async def run():
- async with makeTestClient(API, Impl()) as client:
- await self.assertResponse(
- client.get('/?arg="whut"'), 'whut')
+ async with makeTestClient(API, Impl()) as client:
+ await self.assertResponse(
+ client.get('/?arg="whut"'), 'whut')
- run_coro(run())
-
- def test_missing_argument(self):
+ async def test_missing_argument(self):
@api
class API:
def GET(arg: str): ...
@@ -126,19 +116,16 @@ class TestBind(unittest.TestCase):
async def GET(self, arg: str):
return arg
- async def run():
- async with makeTestClient(API, Impl()) as client:
- resp = await client.get('/')
- self.assertEqual(resp.status, 500)
- self.assertEqual(resp.headers['x-status'], 'error')
- self.assertEqual(resp.headers['x-error-type'], 'TypeError')
- self.assertEqual(
- resp.headers['x-error-msg'],
- 'missing required argument "arg"')
+ async with makeTestClient(API, Impl()) as client:
+ resp = await client.get('/')
+ self.assertEqual(resp.status, 500)
+ self.assertEqual(resp.headers['x-status'], 'error')
+ self.assertEqual(resp.headers['x-error-type'], 'TypeError')
+ self.assertEqual(
+ resp.headers['x-error-msg'],
+ 'missing required argument "arg"')
- run_coro(run())
-
- def test_error(self):
+ async def test_error(self):
@api
class API:
def GET(): ...
@@ -147,17 +134,14 @@ class TestBind(unittest.TestCase):
async def GET(self):
return 1/0
- async def run():
- async with makeTestClient(API, Impl()) as client:
- resp = await client.get('/')
- self.assertEqual(resp.status, 500)
- self.assertEqual(resp.headers['x-status'], 'error')
- self.assertEqual(
- resp.headers['x-error-type'], 'ZeroDivisionError')
+ async with makeTestClient(API, Impl()) as client:
+ resp = await client.get('/')
+ self.assertEqual(resp.status, 500)
+ self.assertEqual(resp.headers['x-status'], 'error')
+ self.assertEqual(
+ resp.headers['x-error-type'], 'ZeroDivisionError')
- run_coro(run())
-
- def test_post(self):
+ async def test_post(self):
@api
class API:
def POST(data: Payload[str]) -> str: ...
@@ -166,13 +150,10 @@ class TestBind(unittest.TestCase):
async def POST(self, data: str) -> str:
return data
- async def run():
- async with makeTestClient(API, Impl()) as client:
- await self.assertResponse(
- client.post("/", json='value'),
- 'value')
-
- run_coro(run())
+ async with makeTestClient(API, Impl()) as client:
+ await self.assertResponse(
+ client.post("/", json='value'),
+ 'value')
def test_missing_method(self):
@api
@@ -201,7 +182,7 @@ class TestBind(unittest.TestCase):
bind(app.router, API, Impl())
self.assertEqual(cm.exception.methname, "API.GET")
- def test_middleware(self):
+ async def test_middleware(self):
@web.middleware
async def middleware(request, handler):
@@ -219,16 +200,13 @@ class TestBind(unittest.TestCase):
async def GET(self) -> str:
return 1/0
- async def run():
- async with makeTestClient(
- API, Impl(), middlewares=[middleware]) as client:
- resp = await client.get("/")
- self.assertEqual(
- resp.headers['x-error-type'], 'ZeroDivisionError')
+ async with makeTestClient(
+ API, Impl(), middlewares=[middleware]) as client:
+ resp = await client.get("/")
+ self.assertEqual(
+ resp.headers['x-error-type'], 'ZeroDivisionError')
- run_coro(run())
-
- def test_controller_for_request(self):
+ async def test_controller_for_request(self):
seen_controller = None
@@ -249,12 +227,9 @@ class TestBind(unittest.TestCase):
impl = Impl()
- async def run():
- async with makeTestClient(
- API.meth, impl, middlewares=[middleware]) as client:
- resp = await client.get("/meth")
- self.assertEqual(await resp.json(), '')
-
- run_coro(run())
+ async with makeTestClient(
+ API.meth, impl, middlewares=[middleware]) as client:
+ resp = await client.get("/meth")
+ self.assertEqual(await resp.json(), '')
self.assertIs(impl, seen_controller)
diff --git a/subiquity/models/tests/test_subiquity.py b/subiquity/models/tests/test_subiquity.py
index c92b761b..065e243e 100644
--- a/subiquity/models/tests/test_subiquity.py
+++ b/subiquity/models/tests/test_subiquity.py
@@ -19,7 +19,6 @@ from unittest import mock
import yaml
from subiquitycore.pubsub import MessageHub
-from subiquitycore.tests.util import run_coro
from subiquity.common.types import IdentityData
from subiquity.models.subiquity import ModelNames, SubiquityModel
@@ -63,7 +62,7 @@ class TestModelNames(unittest.TestCase):
self.assertEqual(model_names.all(), {'a', 'b', 'c'})
-class TestSubiquityModel(unittest.TestCase):
+class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
def writtenFiles(self, config):
for k, v in config.get('write_files', {}).items():
@@ -114,7 +113,7 @@ class TestSubiquityModel(unittest.TestCase):
cur = cur[component]
self.fail("config has value {} for {}".format(cur, path))
- async def _test_configure(self):
+ async def test_configure(self):
hub = MessageHub()
model = SubiquityModel(
'test', hub, ModelNames({'a', 'b'}), ModelNames(set()))
@@ -124,9 +123,6 @@ class TestSubiquityModel(unittest.TestCase):
await hub.abroadcast((InstallerChannels.CONFIGURED, 'b'))
self.assertTrue(model._install_event.is_set())
- def test_configure(self):
- run_coro(self._test_configure())
-
def make_model(self):
return SubiquityModel(
'test', MessageHub(), INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES)
diff --git a/subiquity/server/controllers/tests/test_filesystem.py b/subiquity/server/controllers/tests/test_filesystem.py
index d52af07c..a117975b 100644
--- a/subiquity/server/controllers/tests/test_filesystem.py
+++ b/subiquity/server/controllers/tests/test_filesystem.py
@@ -13,13 +13,12 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from unittest import mock, TestCase
+from unittest import mock, TestCase, IsolatedAsyncioTestCase
from parameterized import parameterized
from subiquity.server.controllers.filesystem import FilesystemController
-from subiquitycore.tests.util import run_coro
from subiquitycore.tests.mocks import make_app
from subiquity.common.types import Bootloader
from subiquity.models.tests.test_filesystem import (
@@ -28,7 +27,7 @@ from subiquity.models.tests.test_filesystem import (
)
-class TestSubiquityControllerFilesystem(TestCase):
+class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase):
def setUp(self):
self.app = make_app()
self.app.opts.bootloader = 'UEFI'
@@ -38,20 +37,20 @@ class TestSubiquityControllerFilesystem(TestCase):
self.fsc = FilesystemController(app=self.app)
self.fsc._configured = True
- def test_probe_restricted(self):
- run_coro(self.fsc._probe_once(context=None, restricted=True))
+ async def test_probe_restricted(self):
+ await self.fsc._probe_once(context=None, restricted=True)
self.app.prober.get_storage.assert_called_with({'blockdev'})
- def test_probe_os_prober_false(self):
+ async def test_probe_os_prober_false(self):
self.app.opts.use_os_prober = False
- run_coro(self.fsc._probe_once(context=None, restricted=False))
+ await self.fsc._probe_once(context=None, restricted=False)
actual = self.app.prober.get_storage.call_args.args[0]
self.assertTrue({'defaults'} <= actual)
self.assertNotIn('os', actual)
- def test_probe_os_prober_true(self):
+ async def test_probe_os_prober_true(self):
self.app.opts.use_os_prober = True
- run_coro(self.fsc._probe_once(context=None, restricted=False))
+ await self.fsc._probe_once(context=None, restricted=False)
actual = self.app.prober.get_storage.call_args.args[0]
self.assertTrue({'defaults', 'os'} <= actual)
diff --git a/subiquity/server/tests/test_apt.py b/subiquity/server/tests/test_apt.py
index 44a73d81..0a2e4960 100644
--- a/subiquity/server/tests/test_apt.py
+++ b/subiquity/server/tests/test_apt.py
@@ -17,7 +17,6 @@ from unittest.mock import Mock, patch, AsyncMock
from subiquitycore.tests import SubiTestCase
from subiquitycore.tests.mocks import make_app
-from subiquitycore.tests.util import run_coro
from subiquity.server.apt import (
AptConfigurer,
Mountpoint,
@@ -51,14 +50,14 @@ class TestAptConfigurer(SubiTestCase):
expected['apt']['https_proxy'] = proxy
self.assertEqual(expected, self.configurer.apt_config())
- def test_mount_unmount(self):
+ async def test_mount_unmount(self):
# Make sure we can unmount something that we mounted before.
with patch.object(self.app, "command_runner",
create=True, new_callable=AsyncMock):
- m = run_coro(self.configurer.mount("/dev/cdrom", "/target"))
- run_coro(self.configurer.unmount(m))
+ m = await self.configurer.mount("/dev/cdrom", "/target")
+ await self.configurer.unmount(m)
- def test_overlay(self):
+ async def test_overlay(self):
self.configurer.install_tree = OverlayMountpoint(
upperdir="upperdir-install-tree",
lowers=["lowers1-install-tree"],
@@ -71,13 +70,10 @@ class TestAptConfigurer(SubiTestCase):
)
self.source = "source"
- async def coro():
- async with self.configurer.overlay():
- pass
-
with patch.object(self.app, "command_runner",
create=True, new_callable=AsyncMock):
- run_coro(coro())
+ async with self.configurer.overlay():
+ pass
class TestLowerDirFor(SubiTestCase):
diff --git a/subiquity/server/tests/test_geoip.py b/subiquity/server/tests/test_geoip.py
index 835292e8..5807331b 100644
--- a/subiquity/server/tests/test_geoip.py
+++ b/subiquity/server/tests/test_geoip.py
@@ -17,7 +17,6 @@ from aioresponses import aioresponses
from subiquitycore.tests import SubiTestCase
from subiquitycore.tests.mocks import make_app
-from subiquitycore.tests.util import run_coro
from subiquity.server.geoip import (
GeoIP,
HTTPGeoIPStrategy,
@@ -48,16 +47,13 @@ empty_cc = ''
class TestGeoIP(SubiTestCase):
- def setUp(self):
+ async def asyncSetUp(self):
strategy = HTTPGeoIPStrategy()
self.geoip = GeoIP(make_app(), strategy)
- async def fn():
- self.assertTrue(await self.geoip.lookup())
-
with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=xml)
- run_coro(fn())
+ self.assertTrue(await self.geoip.lookup())
def test_countrycode(self):
self.assertEqual("us", self.geoip.countrycode)
@@ -71,42 +67,32 @@ class TestGeoIPBadData(SubiTestCase):
strategy = HTTPGeoIPStrategy()
self.geoip = GeoIP(make_app(), strategy)
- def test_partial_reponse(self):
- async def fn():
- self.assertFalse(await self.geoip.lookup())
+ async def test_partial_reponse(self):
with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=partial)
- run_coro(fn())
-
- def test_incomplete(self):
- async def fn():
self.assertFalse(await self.geoip.lookup())
+
+ async def test_incomplete(self):
with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=incomplete)
- run_coro(fn())
+ self.assertFalse(await self.geoip.lookup())
self.assertIsNone(self.geoip.countrycode)
self.assertIsNone(self.geoip.timezone)
- def test_long_cc(self):
- async def fn():
- self.assertFalse(await self.geoip.lookup())
+ async def test_long_cc(self):
with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=long_cc)
- run_coro(fn())
+ self.assertFalse(await self.geoip.lookup())
self.assertIsNone(self.geoip.countrycode)
- def test_empty_cc(self):
- async def fn():
- self.assertFalse(await self.geoip.lookup())
+ async def test_empty_cc(self):
with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=empty_cc)
- run_coro(fn())
+ self.assertFalse(await self.geoip.lookup())
self.assertIsNone(self.geoip.countrycode)
- def test_empty_tz(self):
- async def fn():
- self.assertFalse(await self.geoip.lookup())
+ async def test_empty_tz(self):
with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=empty_tz)
- run_coro(fn())
+ self.assertFalse(await self.geoip.lookup())
self.assertIsNone(self.geoip.timezone)
diff --git a/subiquity/server/tests/test_ubuntu_advantage.py b/subiquity/server/tests/test_ubuntu_advantage.py
index e6d9a217..4df6a76c 100644
--- a/subiquity/server/tests/test_ubuntu_advantage.py
+++ b/subiquity/server/tests/test_ubuntu_advantage.py
@@ -26,37 +26,36 @@ from subiquity.server.ubuntu_advantage import (
MockedUAInterfaceStrategy,
UAClientUAInterfaceStrategy,
)
-from subiquitycore.tests.util import run_coro
-class TestMockedUAInterfaceStrategy(unittest.TestCase):
+class TestMockedUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000)
- def test_query_info_invalid(self):
+ async def test_query_info_invalid(self):
# Tokens starting with "i" in dry-run mode cause the token to be
# reported as invalid.
with self.assertRaises(InvalidTokenError):
- run_coro(self.strategy.query_info(token="invalidToken"))
+ await self.strategy.query_info(token="invalidToken")
- def test_query_info_failure(self):
+ async def test_query_info_failure(self):
# Tokens starting with "f" in dry-run mode simulate an "internal"
# error.
with self.assertRaises(CheckSubscriptionError):
- run_coro(self.strategy.query_info(token="failure"))
+ await self.strategy.query_info(token="failure")
- def test_query_info_expired(self):
+ async def test_query_info_expired(self):
# Tokens starting with "x" is dry-run mode simulate an expired token.
- info = run_coro(self.strategy.query_info(token="xpiredToken"))
+ info = await self.strategy.query_info(token="xpiredToken")
self.assertEqual(info["expires"], "2010-12-31T00:00:00+00:00")
- def test_query_info_valid(self):
+ async def test_query_info_valid(self):
# Other tokens are considered valid in dry-run mode.
- info = run_coro(self.strategy.query_info(token="validToken"))
+ info = await self.strategy.query_info(token="validToken")
self.assertEqual(info["expires"], "2035-12-31T00:00:00+00:00")
-class TestUAClientUAInterfaceStrategy(unittest.TestCase):
+class TestUAClientUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase):
arun_command_sym = "subiquity.server.ubuntu_advantage.utils.arun_command"
def test_init(self):
@@ -75,7 +74,7 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
self.assertEqual(strategy.executable,
["python3", "/usr/bin/ubuntu-advantage"])
- def test_query_info_succeeded(self):
+ async def test_query_info_succeeded(self):
strategy = UAClientUAInterfaceStrategy()
command = (
"ubuntu-advantage",
@@ -87,10 +86,10 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
with patch(self.arun_command_sym) as mock_arun:
mock_arun.return_value = CompletedProcess([], 0)
mock_arun.return_value.stdout = "{}"
- run_coro(strategy.query_info(token="123456789"))
+ await strategy.query_info(token="123456789")
mock_arun.assert_called_once_with(command, check=True)
- def test_query_info_failed(self):
+ async def test_query_info_failed(self):
strategy = UAClientUAInterfaceStrategy()
command = (
"ubuntu-advantage",
@@ -104,10 +103,10 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
cmd=command)
mock_arun.return_value.stdout = "{}"
with self.assertRaises(CheckSubscriptionError):
- run_coro(strategy.query_info(token="123456789"))
+ await strategy.query_info(token="123456789")
mock_arun.assert_called_once_with(command, check=True)
- def test_query_info_invalid_json(self):
+ async def test_query_info_invalid_json(self):
strategy = UAClientUAInterfaceStrategy()
command = (
"ubuntu-advantage",
@@ -120,31 +119,31 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
mock_arun.return_value = CompletedProcess([], 0)
mock_arun.return_value.stdout = "invalid-json"
with self.assertRaises(CheckSubscriptionError):
- run_coro(strategy.query_info(token="123456789"))
+ await strategy.query_info(token="123456789")
mock_arun.assert_called_once_with(command, check=True)
-class TestUAInterface(unittest.TestCase):
+class TestUAInterface(unittest.IsolatedAsyncioTestCase):
- def test_mocked_get_activable_services(self):
+ async def test_mocked_get_activable_services(self):
strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000)
interface = UAInterface(strategy)
with self.assertRaises(InvalidTokenError):
- run_coro(interface.get_activable_services(token="invalidToken"))
+ await interface.get_activable_services(token="invalidToken")
# Tokens starting with "f" in dry-run mode simulate an "internal"
# error.
with self.assertRaises(CheckSubscriptionError):
- run_coro(interface.get_activable_services(token="failure"))
+ await interface.get_activable_services(token="failure")
# Tokens starting with "x" is dry-run mode simulate an expired token.
with self.assertRaises(ExpiredTokenError):
- run_coro(interface.get_activable_services(token="xpiredToken"))
+ await interface.get_activable_services(token="xpiredToken")
# Other tokens are considered valid in dry-run mode.
- run_coro(interface.get_activable_services(token="validToken"))
+ await interface.get_activable_services(token="validToken")
- def test_get_activable_services(self):
+ async def test_get_activable_services(self):
# We use the standard strategy but don't actually run it
strategy = UAClientUAInterfaceStrategy()
interface = UAInterface(strategy)
@@ -185,8 +184,7 @@ class TestUAInterface(unittest.TestCase):
]
}
interface.get_subscription = AsyncMock(return_value=subscription)
- services = run_coro(
- interface.get_activable_services(token="XXX"))
+ services = await interface.get_activable_services(token="XXX")
self.assertIn(UbuntuProService(
name="esm-infra",
@@ -211,5 +209,4 @@ class TestUAInterface(unittest.TestCase):
# Test with "Z" suffix for the expiration date.
subscription["expires"] = "2035-12-31T00:00:00Z"
- services = run_coro(
- interface.get_activable_services(token="XXX"))
+ services = await interface.get_activable_services(token="XXX")
diff --git a/subiquity/server/tests/test_ubuntu_drivers.py b/subiquity/server/tests/test_ubuntu_drivers.py
index 4a658e6f..91c917bf 100644
--- a/subiquity/server/tests/test_ubuntu_drivers.py
+++ b/subiquity/server/tests/test_ubuntu_drivers.py
@@ -18,7 +18,6 @@ import unittest
from unittest.mock import patch, AsyncMock, Mock
from subiquitycore.tests.mocks import make_app
-from subiquitycore.tests.util import run_coro
from subiquity.server.ubuntu_drivers import (
UbuntuDriversInterface,
@@ -28,7 +27,7 @@ from subiquity.server.ubuntu_drivers import (
)
-class TestUbuntuDriversInterface(unittest.TestCase):
+class TestUbuntuDriversInterface(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.app = make_app()
@@ -54,11 +53,11 @@ class TestUbuntuDriversInterface(unittest.TestCase):
@patch.multiple(UbuntuDriversInterface, __abstractmethods__=set())
@patch("subiquity.server.ubuntu_drivers.run_curtin_command")
- def test_install_drivers(self, mock_run_curtin_command):
+ async def test_install_drivers(self, mock_run_curtin_command):
ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=False)
- run_coro(ubuntu_drivers.install_drivers(
+ await ubuntu_drivers.install_drivers(
root_dir="/target",
- context="installing third-party drivers"))
+ context="installing third-party drivers")
mock_run_curtin_command.assert_called_once_with(
self.app, "installing third-party drivers",
"in-target", "-t", "/target",
@@ -91,18 +90,18 @@ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
["nvidia-driver-470", "nvidia-driver-510"])
-class TestUbuntuDriversClientInterface(unittest.TestCase):
+class TestUbuntuDriversClientInterface(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.app = make_app()
self.ubuntu_drivers = UbuntuDriversClientInterface(
self.app, gpgpu=False)
- def test_ensure_cmd_exists(self):
+ async def test_ensure_cmd_exists(self):
with patch.object(
self.app, "command_runner",
create=True, new_callable=AsyncMock) as mock_runner:
# On success
- run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target"))
+ await self.ubuntu_drivers.ensure_cmd_exists("/target")
mock_runner.run.assert_called_once_with(
[
"chroot", "/target",
@@ -115,17 +114,17 @@ class TestUbuntuDriversClientInterface(unittest.TestCase):
cmd=["sh", "-c", "command -v ubuntu-drivers"])
with self.assertRaises(CommandNotFoundError):
- run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target"))
+ await self.ubuntu_drivers.ensure_cmd_exists("/target")
@patch("subiquity.server.ubuntu_drivers.run_curtin_command")
- def test_list_drivers(self, mock_run_curtin_command):
+ async def test_list_drivers(self, mock_run_curtin_command):
# Make sure this gets decoded as utf-8.
mock_run_curtin_command.return_value = Mock(stdout=b"""\
nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
""")
- drivers = run_coro(self.ubuntu_drivers.list_drivers(
+ drivers = await self.ubuntu_drivers.list_drivers(
root_dir="/target",
- context="listing third-party drivers"))
+ context="listing third-party drivers")
mock_run_curtin_command.assert_called_once_with(
self.app, "listing third-party drivers",
@@ -137,15 +136,15 @@ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
self.assertEqual(drivers, ["nvidia-driver-510"])
-class TestUbuntuDriversRunDriversInterface(unittest.TestCase):
+class TestUbuntuDriversRunDriversInterface(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.app = make_app()
self.ubuntu_drivers = UbuntuDriversRunDriversInterface(
self.app, gpgpu=False)
@patch("subiquity.server.ubuntu_drivers.arun_command")
- def test_ensure_cmd_exists(self, mock_arun_command):
- run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target"))
+ async def test_ensure_cmd_exists(self, mock_arun_command):
+ await self.ubuntu_drivers.ensure_cmd_exists("/target")
mock_arun_command.assert_called_once_with(
["sh", "-c", "command -v ubuntu-drivers"],
check=True)
diff --git a/subiquity/tests/api/test_api.py b/subiquity/tests/api/test_api.py
index f89c29cd..b495aa3d 100755
--- a/subiquity/tests/api/test_api.py
+++ b/subiquity/tests/api/test_api.py
@@ -7,7 +7,6 @@ from functools import wraps
import json
import os
import tempfile
-import unittest
from unittest.mock import patch
from urllib.parse import unquote
@@ -134,7 +133,7 @@ class Server(Client):
pass
-class TestAPI(unittest.IsolatedAsyncioTestCase, SubiTestCase):
+class TestAPI(SubiTestCase):
class _MachineConfig(os.PathLike):
def __init__(self, outer, path):
self.outer = outer
diff --git a/subiquity/tests/test_timezonecontroller.py b/subiquity/tests/test_timezonecontroller.py
index 050fcd2e..08c2dc4a 100644
--- a/subiquity/tests/test_timezonecontroller.py
+++ b/subiquity/tests/test_timezonecontroller.py
@@ -20,7 +20,6 @@ from subiquity.models.timezone import TimeZoneModel
from subiquity.server.controllers.timezone import TimeZoneController
from subiquitycore.tests import SubiTestCase
from subiquitycore.tests.mocks import make_app
-from subiquitycore.tests.util import run_coro
class MockGeoIP:
@@ -47,7 +46,7 @@ class TestTimeZoneController(SubiTestCase):
@mock.patch('subiquity.server.controllers.timezone.timedatectl_settz')
@mock.patch('subiquity.server.controllers.timezone.timedatectl_gettz')
- def test_good_tzs(self, tdc_gettz, tdc_settz):
+ async def test_good_tzs(self, tdc_gettz, tdc_settz):
tdc_gettz.return_value = tz_utc
goods = [
# val - autoinstall value
@@ -77,7 +76,7 @@ class TestTimeZoneController(SubiTestCase):
self.tzc.model)
self.assertEqual(geoip, self.tzc.model.detect_with_geoip,
self.tzc.model)
- self.assertEqual(tz, run_coro(self.tzc.GET()), self.tzc.model)
+ self.assertEqual(tz, await self.tzc.GET(), self.tzc.model)
cloudconfig = {}
if self.tzc.model.should_set_tz:
cloudconfig = {'timezone': tz.timezone}
@@ -104,7 +103,7 @@ class TestTimeZoneController(SubiTestCase):
self.assertEqual('sleep', subprocess_run.call_args.args[0][0])
@mock.patch('subiquity.server.controllers.timezone.timedatectl_settz')
- def test_get_tz_should_not_set(self, tdc_settz):
- run_coro(self.tzc.GET())
+ async def test_get_tz_should_not_set(self, tdc_settz):
+ await self.tzc.GET()
self.assertFalse(self.tzc.model.should_set_tz)
tdc_settz.assert_not_called()
diff --git a/subiquity/ui/views/tests/test_identity.py b/subiquity/ui/views/tests/test_identity.py
index ccfb8f18..62bad17f 100644
--- a/subiquity/ui/views/tests/test_identity.py
+++ b/subiquity/ui/views/tests/test_identity.py
@@ -1,4 +1,18 @@
-import asyncio
+# Copyright 2017-2022 Canonical, Ltd.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
import unittest
from unittest import mock
@@ -42,20 +56,6 @@ system_reserved = {
}
-def tearDownModule() -> None:
- # Set empty loop policy, so that subsequent get_event_loop() returns a new
- # loop. If there is no running event loop set, that function will return
- # the result of `get_event_loop_policy().get_event_loop()` call. If there
- # is a policy there must be a running loop. IsolatedAsyncioTestCase
- # closes the loop during tear down, though. It doesn't touch the policy.
- # By having it as None, it autoinits and the next tests run smoothly.
- # Another approach would be set a new event_loop for the current policy on
- # test fixture tearDown, as pytest-asyncio does.
- # Either way we would prevent failure on tests that depend on [run_coro]
- # (subiquitycore/tests/util.py), for instance.
- asyncio.set_event_loop_policy(None)
-
-
class IdentityViewTests(unittest.IsolatedAsyncioTestCase):
def make_view(self):
diff --git a/subiquitycore/tests/__init__.py b/subiquitycore/tests/__init__.py
index 0d117de5..4d6a5870 100644
--- a/subiquitycore/tests/__init__.py
+++ b/subiquitycore/tests/__init__.py
@@ -3,10 +3,10 @@ import os
import shutil
import tempfile
-from unittest import TestCase
+from unittest import IsolatedAsyncioTestCase
-class SubiTestCase(TestCase):
+class SubiTestCase(IsolatedAsyncioTestCase):
def tmp_dir(self, dir=None, cleanup=True):
# return a full path to a temporary directory that will be cleaned up.
if dir is None:
diff --git a/subiquitycore/tests/util.py b/subiquitycore/tests/util.py
index 3deab2d8..63cef422 100644
--- a/subiquitycore/tests/util.py
+++ b/subiquitycore/tests/util.py
@@ -13,14 +13,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-import asyncio
import random
import string
-def run_coro(coro):
- return asyncio.get_event_loop().run_until_complete(coro)
-
-
def random_string():
return ''.join(random.choice(string.ascii_letters) for _ in range(8))