tests: s/run_coro/IsolatedAsyncioTestCase/

Tests that use run_coro are at risk of being broken by the introduction
of another test using IsolatedAsyncioTestCase.  Switch over to only use
IsolatedAsyncioTestCase.
This commit is contained in:
Dan Bungert 2022-06-23 17:00:45 -06:00
parent 924e527b7b
commit dffb155203
13 changed files with 181 additions and 269 deletions

View File

@ -21,8 +21,6 @@ import unittest
import aiohttp
from aiohttp import web
from subiquitycore.tests.util import run_coro
from subiquity.common.api.client import make_client
from subiquity.common.api.defs import api, Payload
@ -46,9 +44,9 @@ async def makeE2EClient(api, impl,
yield make_client(api, mr)
class TestEndToEnd(unittest.TestCase):
class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
def test_simple(self):
async def test_simple(self):
@api
class API:
def GET() -> str: ...
@ -57,13 +55,10 @@ class TestEndToEnd(unittest.TestCase):
async def GET(self) -> str:
return 'value'
async def run():
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(await client.GET(), 'value')
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(await client.GET(), 'value')
run_coro(run())
def test_nested(self):
async def test_nested(self):
@api
class API:
class endpoint:
@ -74,13 +69,10 @@ class TestEndToEnd(unittest.TestCase):
async def endpoint_nested_GET(self) -> str:
return 'value'
async def run():
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(await client.endpoint.nested.GET(), 'value')
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(await client.endpoint.nested.GET(), 'value')
run_coro(run())
def test_args(self):
async def test_args(self):
@api
class API:
def GET(arg1: str, arg2: str) -> str: ...
@ -89,14 +81,11 @@ class TestEndToEnd(unittest.TestCase):
async def GET(self, arg1: str, arg2: str) -> str:
return '{}+{}'.format(arg1, arg2)
async def run():
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(
await client.GET(arg1="A", arg2="B"), 'A+B')
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(
await client.GET(arg1="A", arg2="B"), 'A+B')
run_coro(run())
def test_defaults(self):
async def test_defaults(self):
@api
class API:
def GET(arg1: str, arg2: str = "arg2") -> str: ...
@ -105,16 +94,13 @@ class TestEndToEnd(unittest.TestCase):
async def GET(self, arg1: str, arg2: str = "arg2") -> str:
return '{}+{}'.format(arg1, arg2)
async def run():
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(
await client.GET(arg1="A", arg2="B"), 'A+B')
self.assertEqual(
await client.GET(arg1="A"), 'A+arg2')
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(
await client.GET(arg1="A", arg2="B"), 'A+B')
self.assertEqual(
await client.GET(arg1="A"), 'A+arg2')
run_coro(run())
def test_post(self):
async def test_post(self):
@api
class API:
def POST(data: Payload[dict]) -> str: ...
@ -123,14 +109,11 @@ class TestEndToEnd(unittest.TestCase):
async def POST(self, data: dict) -> str:
return data['key']
async def run():
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(
await client.POST({'key': 'value'}), 'value')
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(
await client.POST({'key': 'value'}), 'value')
run_coro(run())
def test_typed(self):
async def test_typed(self):
@attr.s(auto_attribs=True)
class In:
@ -149,14 +132,11 @@ class TestEndToEnd(unittest.TestCase):
async def doubler_POST(self, data: In) -> Out:
return Out(doubled=data.val*2)
async def run():
async with makeE2EClient(API, Impl()) as client:
out = await client.doubler.POST(In(3))
self.assertEqual(out.doubled, 6)
async with makeE2EClient(API, Impl()) as client:
out = await client.doubler.POST(In(3))
self.assertEqual(out.doubled, 6)
run_coro(run())
def test_middleware(self):
async def test_middleware(self):
@api
class API:
def GET() -> int: ...
@ -182,17 +162,14 @@ class TestEndToEnd(unittest.TestCase):
raise Skip
yield resp
async def run():
async with makeE2EClient(
API, Impl(),
middlewares=[middleware],
make_request=custom_make_request) as client:
with self.assertRaises(Skip):
await client.GET()
async with makeE2EClient(
API, Impl(),
middlewares=[middleware],
make_request=custom_make_request) as client:
with self.assertRaises(Skip):
await client.GET()
run_coro(run())
def test_error(self):
async def test_error(self):
@api
class API:
class good:
@ -208,16 +185,13 @@ class TestEndToEnd(unittest.TestCase):
async def bad_GET(self, x: int) -> int:
raise Exception("baz")
async def run():
async with makeE2EClient(API, Impl()) as client:
r = await client.good.GET(2)
self.assertEqual(r, 3)
with self.assertRaises(aiohttp.ClientResponseError):
await client.bad.GET(2)
async with makeE2EClient(API, Impl()) as client:
r = await client.good.GET(2)
self.assertEqual(r, 3)
with self.assertRaises(aiohttp.ClientResponseError):
await client.bad.GET(2)
run_coro(run())
def test_error_middleware(self):
async def test_error_middleware(self):
@api
class API:
class good:
@ -251,14 +225,11 @@ class TestEndToEnd(unittest.TestCase):
raise Abort
yield resp
async def run():
async with makeE2EClient(
API, Impl(),
middlewares=[middleware],
make_request=custom_make_request) as client:
r = await client.good.GET(2)
self.assertEqual(r, 3)
with self.assertRaises(Abort):
await client.bad.GET(2)
run_coro(run())
async with makeE2EClient(
API, Impl(),
middlewares=[middleware],
make_request=custom_make_request) as client:
r = await client.good.GET(2)
self.assertEqual(r, 3)
with self.assertRaises(Abort):
await client.bad.GET(2)

View File

@ -20,7 +20,6 @@ from aiohttp.test_utils import TestClient, TestServer
from aiohttp import web
from subiquitycore.context import Context
from subiquitycore.tests.util import run_coro
from subiquity.common.api.defs import api, Payload
from subiquity.common.api.server import (
@ -56,7 +55,7 @@ async def makeTestClient(api, impl, middlewares=()):
yield client
class TestBind(unittest.TestCase):
class TestBind(unittest.IsolatedAsyncioTestCase):
async def assertResponse(self, coro, value):
resp = await coro
@ -67,7 +66,7 @@ class TestBind(unittest.TestCase):
self.assertEqual(resp.status, 200)
self.assertEqual(await resp.json(), value)
def test_simple(self):
async def test_simple(self):
@api
class API:
def GET() -> str: ...
@ -76,14 +75,11 @@ class TestBind(unittest.TestCase):
async def GET(self) -> str:
return 'value'
async def run():
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.get("/"), 'value')
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.get("/"), 'value')
run_coro(run())
def test_nested(self):
async def test_nested(self):
@api
class API:
class endpoint:
@ -94,14 +90,11 @@ class TestBind(unittest.TestCase):
async def nested_get(self, request, context):
return 'nested'
async def run():
async with makeTestClient(API.endpoint, Impl()) as client:
await self.assertResponse(
client.get("/endpoint/nested"), 'nested')
async with makeTestClient(API.endpoint, Impl()) as client:
await self.assertResponse(
client.get("/endpoint/nested"), 'nested')
run_coro(run())
def test_args(self):
async def test_args(self):
@api
class API:
def GET(arg: str): ...
@ -110,14 +103,11 @@ class TestBind(unittest.TestCase):
async def GET(self, arg: str):
return arg
async def run():
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.get('/?arg="whut"'), 'whut')
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.get('/?arg="whut"'), 'whut')
run_coro(run())
def test_missing_argument(self):
async def test_missing_argument(self):
@api
class API:
def GET(arg: str): ...
@ -126,19 +116,16 @@ class TestBind(unittest.TestCase):
async def GET(self, arg: str):
return arg
async def run():
async with makeTestClient(API, Impl()) as client:
resp = await client.get('/')
self.assertEqual(resp.status, 500)
self.assertEqual(resp.headers['x-status'], 'error')
self.assertEqual(resp.headers['x-error-type'], 'TypeError')
self.assertEqual(
resp.headers['x-error-msg'],
'missing required argument "arg"')
async with makeTestClient(API, Impl()) as client:
resp = await client.get('/')
self.assertEqual(resp.status, 500)
self.assertEqual(resp.headers['x-status'], 'error')
self.assertEqual(resp.headers['x-error-type'], 'TypeError')
self.assertEqual(
resp.headers['x-error-msg'],
'missing required argument "arg"')
run_coro(run())
def test_error(self):
async def test_error(self):
@api
class API:
def GET(): ...
@ -147,17 +134,14 @@ class TestBind(unittest.TestCase):
async def GET(self):
return 1/0
async def run():
async with makeTestClient(API, Impl()) as client:
resp = await client.get('/')
self.assertEqual(resp.status, 500)
self.assertEqual(resp.headers['x-status'], 'error')
self.assertEqual(
resp.headers['x-error-type'], 'ZeroDivisionError')
async with makeTestClient(API, Impl()) as client:
resp = await client.get('/')
self.assertEqual(resp.status, 500)
self.assertEqual(resp.headers['x-status'], 'error')
self.assertEqual(
resp.headers['x-error-type'], 'ZeroDivisionError')
run_coro(run())
def test_post(self):
async def test_post(self):
@api
class API:
def POST(data: Payload[str]) -> str: ...
@ -166,13 +150,10 @@ class TestBind(unittest.TestCase):
async def POST(self, data: str) -> str:
return data
async def run():
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.post("/", json='value'),
'value')
run_coro(run())
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.post("/", json='value'),
'value')
def test_missing_method(self):
@api
@ -201,7 +182,7 @@ class TestBind(unittest.TestCase):
bind(app.router, API, Impl())
self.assertEqual(cm.exception.methname, "API.GET")
def test_middleware(self):
async def test_middleware(self):
@web.middleware
async def middleware(request, handler):
@ -219,16 +200,13 @@ class TestBind(unittest.TestCase):
async def GET(self) -> str:
return 1/0
async def run():
async with makeTestClient(
API, Impl(), middlewares=[middleware]) as client:
resp = await client.get("/")
self.assertEqual(
resp.headers['x-error-type'], 'ZeroDivisionError')
async with makeTestClient(
API, Impl(), middlewares=[middleware]) as client:
resp = await client.get("/")
self.assertEqual(
resp.headers['x-error-type'], 'ZeroDivisionError')
run_coro(run())
def test_controller_for_request(self):
async def test_controller_for_request(self):
seen_controller = None
@ -249,12 +227,9 @@ class TestBind(unittest.TestCase):
impl = Impl()
async def run():
async with makeTestClient(
API.meth, impl, middlewares=[middleware]) as client:
resp = await client.get("/meth")
self.assertEqual(await resp.json(), '')
run_coro(run())
async with makeTestClient(
API.meth, impl, middlewares=[middleware]) as client:
resp = await client.get("/meth")
self.assertEqual(await resp.json(), '')
self.assertIs(impl, seen_controller)

View File

@ -19,7 +19,6 @@ from unittest import mock
import yaml
from subiquitycore.pubsub import MessageHub
from subiquitycore.tests.util import run_coro
from subiquity.common.types import IdentityData
from subiquity.models.subiquity import ModelNames, SubiquityModel
@ -63,7 +62,7 @@ class TestModelNames(unittest.TestCase):
self.assertEqual(model_names.all(), {'a', 'b', 'c'})
class TestSubiquityModel(unittest.TestCase):
class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
def writtenFiles(self, config):
for k, v in config.get('write_files', {}).items():
@ -114,7 +113,7 @@ class TestSubiquityModel(unittest.TestCase):
cur = cur[component]
self.fail("config has value {} for {}".format(cur, path))
async def _test_configure(self):
async def test_configure(self):
hub = MessageHub()
model = SubiquityModel(
'test', hub, ModelNames({'a', 'b'}), ModelNames(set()))
@ -124,9 +123,6 @@ class TestSubiquityModel(unittest.TestCase):
await hub.abroadcast((InstallerChannels.CONFIGURED, 'b'))
self.assertTrue(model._install_event.is_set())
def test_configure(self):
run_coro(self._test_configure())
def make_model(self):
return SubiquityModel(
'test', MessageHub(), INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES)

View File

@ -13,13 +13,12 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from unittest import mock, TestCase
from unittest import mock, TestCase, IsolatedAsyncioTestCase
from parameterized import parameterized
from subiquity.server.controllers.filesystem import FilesystemController
from subiquitycore.tests.util import run_coro
from subiquitycore.tests.mocks import make_app
from subiquity.common.types import Bootloader
from subiquity.models.tests.test_filesystem import (
@ -28,7 +27,7 @@ from subiquity.models.tests.test_filesystem import (
)
class TestSubiquityControllerFilesystem(TestCase):
class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase):
def setUp(self):
self.app = make_app()
self.app.opts.bootloader = 'UEFI'
@ -38,20 +37,20 @@ class TestSubiquityControllerFilesystem(TestCase):
self.fsc = FilesystemController(app=self.app)
self.fsc._configured = True
def test_probe_restricted(self):
run_coro(self.fsc._probe_once(context=None, restricted=True))
async def test_probe_restricted(self):
await self.fsc._probe_once(context=None, restricted=True)
self.app.prober.get_storage.assert_called_with({'blockdev'})
def test_probe_os_prober_false(self):
async def test_probe_os_prober_false(self):
self.app.opts.use_os_prober = False
run_coro(self.fsc._probe_once(context=None, restricted=False))
await self.fsc._probe_once(context=None, restricted=False)
actual = self.app.prober.get_storage.call_args.args[0]
self.assertTrue({'defaults'} <= actual)
self.assertNotIn('os', actual)
def test_probe_os_prober_true(self):
async def test_probe_os_prober_true(self):
self.app.opts.use_os_prober = True
run_coro(self.fsc._probe_once(context=None, restricted=False))
await self.fsc._probe_once(context=None, restricted=False)
actual = self.app.prober.get_storage.call_args.args[0]
self.assertTrue({'defaults', 'os'} <= actual)

View File

@ -17,7 +17,6 @@ from unittest.mock import Mock, patch, AsyncMock
from subiquitycore.tests import SubiTestCase
from subiquitycore.tests.mocks import make_app
from subiquitycore.tests.util import run_coro
from subiquity.server.apt import (
AptConfigurer,
Mountpoint,
@ -51,14 +50,14 @@ class TestAptConfigurer(SubiTestCase):
expected['apt']['https_proxy'] = proxy
self.assertEqual(expected, self.configurer.apt_config())
def test_mount_unmount(self):
async def test_mount_unmount(self):
# Make sure we can unmount something that we mounted before.
with patch.object(self.app, "command_runner",
create=True, new_callable=AsyncMock):
m = run_coro(self.configurer.mount("/dev/cdrom", "/target"))
run_coro(self.configurer.unmount(m))
m = await self.configurer.mount("/dev/cdrom", "/target")
await self.configurer.unmount(m)
def test_overlay(self):
async def test_overlay(self):
self.configurer.install_tree = OverlayMountpoint(
upperdir="upperdir-install-tree",
lowers=["lowers1-install-tree"],
@ -71,13 +70,10 @@ class TestAptConfigurer(SubiTestCase):
)
self.source = "source"
async def coro():
async with self.configurer.overlay():
pass
with patch.object(self.app, "command_runner",
create=True, new_callable=AsyncMock):
run_coro(coro())
async with self.configurer.overlay():
pass
class TestLowerDirFor(SubiTestCase):

View File

@ -17,7 +17,6 @@ from aioresponses import aioresponses
from subiquitycore.tests import SubiTestCase
from subiquitycore.tests.mocks import make_app
from subiquitycore.tests.util import run_coro
from subiquity.server.geoip import (
GeoIP,
HTTPGeoIPStrategy,
@ -48,16 +47,13 @@ empty_cc = '<Response><CountryCode></CountryCode></Response>'
class TestGeoIP(SubiTestCase):
def setUp(self):
async def asyncSetUp(self):
strategy = HTTPGeoIPStrategy()
self.geoip = GeoIP(make_app(), strategy)
async def fn():
self.assertTrue(await self.geoip.lookup())
with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=xml)
run_coro(fn())
self.assertTrue(await self.geoip.lookup())
def test_countrycode(self):
self.assertEqual("us", self.geoip.countrycode)
@ -71,42 +67,32 @@ class TestGeoIPBadData(SubiTestCase):
strategy = HTTPGeoIPStrategy()
self.geoip = GeoIP(make_app(), strategy)
def test_partial_reponse(self):
async def fn():
self.assertFalse(await self.geoip.lookup())
async def test_partial_reponse(self):
with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=partial)
run_coro(fn())
def test_incomplete(self):
async def fn():
self.assertFalse(await self.geoip.lookup())
async def test_incomplete(self):
with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=incomplete)
run_coro(fn())
self.assertFalse(await self.geoip.lookup())
self.assertIsNone(self.geoip.countrycode)
self.assertIsNone(self.geoip.timezone)
def test_long_cc(self):
async def fn():
self.assertFalse(await self.geoip.lookup())
async def test_long_cc(self):
with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=long_cc)
run_coro(fn())
self.assertFalse(await self.geoip.lookup())
self.assertIsNone(self.geoip.countrycode)
def test_empty_cc(self):
async def fn():
self.assertFalse(await self.geoip.lookup())
async def test_empty_cc(self):
with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=empty_cc)
run_coro(fn())
self.assertFalse(await self.geoip.lookup())
self.assertIsNone(self.geoip.countrycode)
def test_empty_tz(self):
async def fn():
self.assertFalse(await self.geoip.lookup())
async def test_empty_tz(self):
with aioresponses() as mocked:
mocked.get("https://geoip.ubuntu.com/lookup", body=empty_tz)
run_coro(fn())
self.assertFalse(await self.geoip.lookup())
self.assertIsNone(self.geoip.timezone)

View File

@ -26,37 +26,36 @@ from subiquity.server.ubuntu_advantage import (
MockedUAInterfaceStrategy,
UAClientUAInterfaceStrategy,
)
from subiquitycore.tests.util import run_coro
class TestMockedUAInterfaceStrategy(unittest.TestCase):
class TestMockedUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000)
def test_query_info_invalid(self):
async def test_query_info_invalid(self):
# Tokens starting with "i" in dry-run mode cause the token to be
# reported as invalid.
with self.assertRaises(InvalidTokenError):
run_coro(self.strategy.query_info(token="invalidToken"))
await self.strategy.query_info(token="invalidToken")
def test_query_info_failure(self):
async def test_query_info_failure(self):
# Tokens starting with "f" in dry-run mode simulate an "internal"
# error.
with self.assertRaises(CheckSubscriptionError):
run_coro(self.strategy.query_info(token="failure"))
await self.strategy.query_info(token="failure")
def test_query_info_expired(self):
async def test_query_info_expired(self):
# Tokens starting with "x" is dry-run mode simulate an expired token.
info = run_coro(self.strategy.query_info(token="xpiredToken"))
info = await self.strategy.query_info(token="xpiredToken")
self.assertEqual(info["expires"], "2010-12-31T00:00:00+00:00")
def test_query_info_valid(self):
async def test_query_info_valid(self):
# Other tokens are considered valid in dry-run mode.
info = run_coro(self.strategy.query_info(token="validToken"))
info = await self.strategy.query_info(token="validToken")
self.assertEqual(info["expires"], "2035-12-31T00:00:00+00:00")
class TestUAClientUAInterfaceStrategy(unittest.TestCase):
class TestUAClientUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase):
arun_command_sym = "subiquity.server.ubuntu_advantage.utils.arun_command"
def test_init(self):
@ -75,7 +74,7 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
self.assertEqual(strategy.executable,
["python3", "/usr/bin/ubuntu-advantage"])
def test_query_info_succeeded(self):
async def test_query_info_succeeded(self):
strategy = UAClientUAInterfaceStrategy()
command = (
"ubuntu-advantage",
@ -87,10 +86,10 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
with patch(self.arun_command_sym) as mock_arun:
mock_arun.return_value = CompletedProcess([], 0)
mock_arun.return_value.stdout = "{}"
run_coro(strategy.query_info(token="123456789"))
await strategy.query_info(token="123456789")
mock_arun.assert_called_once_with(command, check=True)
def test_query_info_failed(self):
async def test_query_info_failed(self):
strategy = UAClientUAInterfaceStrategy()
command = (
"ubuntu-advantage",
@ -104,10 +103,10 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
cmd=command)
mock_arun.return_value.stdout = "{}"
with self.assertRaises(CheckSubscriptionError):
run_coro(strategy.query_info(token="123456789"))
await strategy.query_info(token="123456789")
mock_arun.assert_called_once_with(command, check=True)
def test_query_info_invalid_json(self):
async def test_query_info_invalid_json(self):
strategy = UAClientUAInterfaceStrategy()
command = (
"ubuntu-advantage",
@ -120,31 +119,31 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
mock_arun.return_value = CompletedProcess([], 0)
mock_arun.return_value.stdout = "invalid-json"
with self.assertRaises(CheckSubscriptionError):
run_coro(strategy.query_info(token="123456789"))
await strategy.query_info(token="123456789")
mock_arun.assert_called_once_with(command, check=True)
class TestUAInterface(unittest.TestCase):
class TestUAInterface(unittest.IsolatedAsyncioTestCase):
def test_mocked_get_activable_services(self):
async def test_mocked_get_activable_services(self):
strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000)
interface = UAInterface(strategy)
with self.assertRaises(InvalidTokenError):
run_coro(interface.get_activable_services(token="invalidToken"))
await interface.get_activable_services(token="invalidToken")
# Tokens starting with "f" in dry-run mode simulate an "internal"
# error.
with self.assertRaises(CheckSubscriptionError):
run_coro(interface.get_activable_services(token="failure"))
await interface.get_activable_services(token="failure")
# Tokens starting with "x" is dry-run mode simulate an expired token.
with self.assertRaises(ExpiredTokenError):
run_coro(interface.get_activable_services(token="xpiredToken"))
await interface.get_activable_services(token="xpiredToken")
# Other tokens are considered valid in dry-run mode.
run_coro(interface.get_activable_services(token="validToken"))
await interface.get_activable_services(token="validToken")
def test_get_activable_services(self):
async def test_get_activable_services(self):
# We use the standard strategy but don't actually run it
strategy = UAClientUAInterfaceStrategy()
interface = UAInterface(strategy)
@ -185,8 +184,7 @@ class TestUAInterface(unittest.TestCase):
]
}
interface.get_subscription = AsyncMock(return_value=subscription)
services = run_coro(
interface.get_activable_services(token="XXX"))
services = await interface.get_activable_services(token="XXX")
self.assertIn(UbuntuProService(
name="esm-infra",
@ -211,5 +209,4 @@ class TestUAInterface(unittest.TestCase):
# Test with "Z" suffix for the expiration date.
subscription["expires"] = "2035-12-31T00:00:00Z"
services = run_coro(
interface.get_activable_services(token="XXX"))
services = await interface.get_activable_services(token="XXX")

View File

@ -18,7 +18,6 @@ import unittest
from unittest.mock import patch, AsyncMock, Mock
from subiquitycore.tests.mocks import make_app
from subiquitycore.tests.util import run_coro
from subiquity.server.ubuntu_drivers import (
UbuntuDriversInterface,
@ -28,7 +27,7 @@ from subiquity.server.ubuntu_drivers import (
)
class TestUbuntuDriversInterface(unittest.TestCase):
class TestUbuntuDriversInterface(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.app = make_app()
@ -54,11 +53,11 @@ class TestUbuntuDriversInterface(unittest.TestCase):
@patch.multiple(UbuntuDriversInterface, __abstractmethods__=set())
@patch("subiquity.server.ubuntu_drivers.run_curtin_command")
def test_install_drivers(self, mock_run_curtin_command):
async def test_install_drivers(self, mock_run_curtin_command):
ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=False)
run_coro(ubuntu_drivers.install_drivers(
await ubuntu_drivers.install_drivers(
root_dir="/target",
context="installing third-party drivers"))
context="installing third-party drivers")
mock_run_curtin_command.assert_called_once_with(
self.app, "installing third-party drivers",
"in-target", "-t", "/target",
@ -91,18 +90,18 @@ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
["nvidia-driver-470", "nvidia-driver-510"])
class TestUbuntuDriversClientInterface(unittest.TestCase):
class TestUbuntuDriversClientInterface(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.app = make_app()
self.ubuntu_drivers = UbuntuDriversClientInterface(
self.app, gpgpu=False)
def test_ensure_cmd_exists(self):
async def test_ensure_cmd_exists(self):
with patch.object(
self.app, "command_runner",
create=True, new_callable=AsyncMock) as mock_runner:
# On success
run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target"))
await self.ubuntu_drivers.ensure_cmd_exists("/target")
mock_runner.run.assert_called_once_with(
[
"chroot", "/target",
@ -115,17 +114,17 @@ class TestUbuntuDriversClientInterface(unittest.TestCase):
cmd=["sh", "-c", "command -v ubuntu-drivers"])
with self.assertRaises(CommandNotFoundError):
run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target"))
await self.ubuntu_drivers.ensure_cmd_exists("/target")
@patch("subiquity.server.ubuntu_drivers.run_curtin_command")
def test_list_drivers(self, mock_run_curtin_command):
async def test_list_drivers(self, mock_run_curtin_command):
# Make sure this gets decoded as utf-8.
mock_run_curtin_command.return_value = Mock(stdout=b"""\
nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
""")
drivers = run_coro(self.ubuntu_drivers.list_drivers(
drivers = await self.ubuntu_drivers.list_drivers(
root_dir="/target",
context="listing third-party drivers"))
context="listing third-party drivers")
mock_run_curtin_command.assert_called_once_with(
self.app, "listing third-party drivers",
@ -137,15 +136,15 @@ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
self.assertEqual(drivers, ["nvidia-driver-510"])
class TestUbuntuDriversRunDriversInterface(unittest.TestCase):
class TestUbuntuDriversRunDriversInterface(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.app = make_app()
self.ubuntu_drivers = UbuntuDriversRunDriversInterface(
self.app, gpgpu=False)
@patch("subiquity.server.ubuntu_drivers.arun_command")
def test_ensure_cmd_exists(self, mock_arun_command):
run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target"))
async def test_ensure_cmd_exists(self, mock_arun_command):
await self.ubuntu_drivers.ensure_cmd_exists("/target")
mock_arun_command.assert_called_once_with(
["sh", "-c", "command -v ubuntu-drivers"],
check=True)

View File

@ -7,7 +7,6 @@ from functools import wraps
import json
import os
import tempfile
import unittest
from unittest.mock import patch
from urllib.parse import unquote
@ -134,7 +133,7 @@ class Server(Client):
pass
class TestAPI(unittest.IsolatedAsyncioTestCase, SubiTestCase):
class TestAPI(SubiTestCase):
class _MachineConfig(os.PathLike):
def __init__(self, outer, path):
self.outer = outer

View File

@ -20,7 +20,6 @@ from subiquity.models.timezone import TimeZoneModel
from subiquity.server.controllers.timezone import TimeZoneController
from subiquitycore.tests import SubiTestCase
from subiquitycore.tests.mocks import make_app
from subiquitycore.tests.util import run_coro
class MockGeoIP:
@ -47,7 +46,7 @@ class TestTimeZoneController(SubiTestCase):
@mock.patch('subiquity.server.controllers.timezone.timedatectl_settz')
@mock.patch('subiquity.server.controllers.timezone.timedatectl_gettz')
def test_good_tzs(self, tdc_gettz, tdc_settz):
async def test_good_tzs(self, tdc_gettz, tdc_settz):
tdc_gettz.return_value = tz_utc
goods = [
# val - autoinstall value
@ -77,7 +76,7 @@ class TestTimeZoneController(SubiTestCase):
self.tzc.model)
self.assertEqual(geoip, self.tzc.model.detect_with_geoip,
self.tzc.model)
self.assertEqual(tz, run_coro(self.tzc.GET()), self.tzc.model)
self.assertEqual(tz, await self.tzc.GET(), self.tzc.model)
cloudconfig = {}
if self.tzc.model.should_set_tz:
cloudconfig = {'timezone': tz.timezone}
@ -104,7 +103,7 @@ class TestTimeZoneController(SubiTestCase):
self.assertEqual('sleep', subprocess_run.call_args.args[0][0])
@mock.patch('subiquity.server.controllers.timezone.timedatectl_settz')
def test_get_tz_should_not_set(self, tdc_settz):
run_coro(self.tzc.GET())
async def test_get_tz_should_not_set(self, tdc_settz):
await self.tzc.GET()
self.assertFalse(self.tzc.model.should_set_tz)
tdc_settz.assert_not_called()

View File

@ -1,4 +1,18 @@
import asyncio
# Copyright 2017-2022 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from unittest import mock
@ -42,20 +56,6 @@ system_reserved = {
}
def tearDownModule() -> None:
# Set empty loop policy, so that subsequent get_event_loop() returns a new
# loop. If there is no running event loop set, that function will return
# the result of `get_event_loop_policy().get_event_loop()` call. If there
# is a policy there must be a running loop. IsolatedAsyncioTestCase
# closes the loop during tear down, though. It doesn't touch the policy.
# By having it as None, it autoinits and the next tests run smoothly.
# Another approach would be set a new event_loop for the current policy on
# test fixture tearDown, as pytest-asyncio does.
# Either way we would prevent failure on tests that depend on [run_coro]
# (subiquitycore/tests/util.py), for instance.
asyncio.set_event_loop_policy(None)
class IdentityViewTests(unittest.IsolatedAsyncioTestCase):
def make_view(self):

View File

@ -3,10 +3,10 @@ import os
import shutil
import tempfile
from unittest import TestCase
from unittest import IsolatedAsyncioTestCase
class SubiTestCase(TestCase):
class SubiTestCase(IsolatedAsyncioTestCase):
def tmp_dir(self, dir=None, cleanup=True):
# return a full path to a temporary directory that will be cleaned up.
if dir is None:

View File

@ -13,14 +13,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import random
import string
def run_coro(coro):
return asyncio.get_event_loop().run_until_complete(coro)
def random_string():
return ''.join(random.choice(string.ascii_letters) for _ in range(8))