Merge pull request #1335 from dbungert/use-async-test-case
tests: s/run_coro/IsolatedAsyncioTestCase/
This commit is contained in:
commit
b786b4d876
|
@ -21,8 +21,6 @@ import unittest
|
|||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
from subiquitycore.tests.util import run_coro
|
||||
|
||||
from subiquity.common.api.client import make_client
|
||||
from subiquity.common.api.defs import api, Payload
|
||||
|
||||
|
@ -46,9 +44,9 @@ async def makeE2EClient(api, impl,
|
|||
yield make_client(api, mr)
|
||||
|
||||
|
||||
class TestEndToEnd(unittest.TestCase):
|
||||
class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def test_simple(self):
|
||||
async def test_simple(self):
|
||||
@api
|
||||
class API:
|
||||
def GET() -> str: ...
|
||||
|
@ -57,13 +55,10 @@ class TestEndToEnd(unittest.TestCase):
|
|||
async def GET(self) -> str:
|
||||
return 'value'
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(await client.GET(), 'value')
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(await client.GET(), 'value')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_nested(self):
|
||||
async def test_nested(self):
|
||||
@api
|
||||
class API:
|
||||
class endpoint:
|
||||
|
@ -74,13 +69,10 @@ class TestEndToEnd(unittest.TestCase):
|
|||
async def endpoint_nested_GET(self) -> str:
|
||||
return 'value'
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(await client.endpoint.nested.GET(), 'value')
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(await client.endpoint.nested.GET(), 'value')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_args(self):
|
||||
async def test_args(self):
|
||||
@api
|
||||
class API:
|
||||
def GET(arg1: str, arg2: str) -> str: ...
|
||||
|
@ -89,14 +81,11 @@ class TestEndToEnd(unittest.TestCase):
|
|||
async def GET(self, arg1: str, arg2: str) -> str:
|
||||
return '{}+{}'.format(arg1, arg2)
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(
|
||||
await client.GET(arg1="A", arg2="B"), 'A+B')
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(
|
||||
await client.GET(arg1="A", arg2="B"), 'A+B')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_defaults(self):
|
||||
async def test_defaults(self):
|
||||
@api
|
||||
class API:
|
||||
def GET(arg1: str, arg2: str = "arg2") -> str: ...
|
||||
|
@ -105,16 +94,13 @@ class TestEndToEnd(unittest.TestCase):
|
|||
async def GET(self, arg1: str, arg2: str = "arg2") -> str:
|
||||
return '{}+{}'.format(arg1, arg2)
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(
|
||||
await client.GET(arg1="A", arg2="B"), 'A+B')
|
||||
self.assertEqual(
|
||||
await client.GET(arg1="A"), 'A+arg2')
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(
|
||||
await client.GET(arg1="A", arg2="B"), 'A+B')
|
||||
self.assertEqual(
|
||||
await client.GET(arg1="A"), 'A+arg2')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_post(self):
|
||||
async def test_post(self):
|
||||
@api
|
||||
class API:
|
||||
def POST(data: Payload[dict]) -> str: ...
|
||||
|
@ -123,14 +109,11 @@ class TestEndToEnd(unittest.TestCase):
|
|||
async def POST(self, data: dict) -> str:
|
||||
return data['key']
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(
|
||||
await client.POST({'key': 'value'}), 'value')
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(
|
||||
await client.POST({'key': 'value'}), 'value')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_typed(self):
|
||||
async def test_typed(self):
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class In:
|
||||
|
@ -149,14 +132,11 @@ class TestEndToEnd(unittest.TestCase):
|
|||
async def doubler_POST(self, data: In) -> Out:
|
||||
return Out(doubled=data.val*2)
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
out = await client.doubler.POST(In(3))
|
||||
self.assertEqual(out.doubled, 6)
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
out = await client.doubler.POST(In(3))
|
||||
self.assertEqual(out.doubled, 6)
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_middleware(self):
|
||||
async def test_middleware(self):
|
||||
@api
|
||||
class API:
|
||||
def GET() -> int: ...
|
||||
|
@ -182,17 +162,14 @@ class TestEndToEnd(unittest.TestCase):
|
|||
raise Skip
|
||||
yield resp
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(
|
||||
API, Impl(),
|
||||
middlewares=[middleware],
|
||||
make_request=custom_make_request) as client:
|
||||
with self.assertRaises(Skip):
|
||||
await client.GET()
|
||||
async with makeE2EClient(
|
||||
API, Impl(),
|
||||
middlewares=[middleware],
|
||||
make_request=custom_make_request) as client:
|
||||
with self.assertRaises(Skip):
|
||||
await client.GET()
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_error(self):
|
||||
async def test_error(self):
|
||||
@api
|
||||
class API:
|
||||
class good:
|
||||
|
@ -208,16 +185,13 @@ class TestEndToEnd(unittest.TestCase):
|
|||
async def bad_GET(self, x: int) -> int:
|
||||
raise Exception("baz")
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
r = await client.good.GET(2)
|
||||
self.assertEqual(r, 3)
|
||||
with self.assertRaises(aiohttp.ClientResponseError):
|
||||
await client.bad.GET(2)
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
r = await client.good.GET(2)
|
||||
self.assertEqual(r, 3)
|
||||
with self.assertRaises(aiohttp.ClientResponseError):
|
||||
await client.bad.GET(2)
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_error_middleware(self):
|
||||
async def test_error_middleware(self):
|
||||
@api
|
||||
class API:
|
||||
class good:
|
||||
|
@ -251,14 +225,11 @@ class TestEndToEnd(unittest.TestCase):
|
|||
raise Abort
|
||||
yield resp
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(
|
||||
API, Impl(),
|
||||
middlewares=[middleware],
|
||||
make_request=custom_make_request) as client:
|
||||
r = await client.good.GET(2)
|
||||
self.assertEqual(r, 3)
|
||||
with self.assertRaises(Abort):
|
||||
await client.bad.GET(2)
|
||||
|
||||
run_coro(run())
|
||||
async with makeE2EClient(
|
||||
API, Impl(),
|
||||
middlewares=[middleware],
|
||||
make_request=custom_make_request) as client:
|
||||
r = await client.good.GET(2)
|
||||
self.assertEqual(r, 3)
|
||||
with self.assertRaises(Abort):
|
||||
await client.bad.GET(2)
|
||||
|
|
|
@ -20,7 +20,6 @@ from aiohttp.test_utils import TestClient, TestServer
|
|||
from aiohttp import web
|
||||
|
||||
from subiquitycore.context import Context
|
||||
from subiquitycore.tests.util import run_coro
|
||||
|
||||
from subiquity.common.api.defs import api, Payload
|
||||
from subiquity.common.api.server import (
|
||||
|
@ -56,7 +55,7 @@ async def makeTestClient(api, impl, middlewares=()):
|
|||
yield client
|
||||
|
||||
|
||||
class TestBind(unittest.TestCase):
|
||||
class TestBind(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def assertResponse(self, coro, value):
|
||||
resp = await coro
|
||||
|
@ -67,7 +66,7 @@ class TestBind(unittest.TestCase):
|
|||
self.assertEqual(resp.status, 200)
|
||||
self.assertEqual(await resp.json(), value)
|
||||
|
||||
def test_simple(self):
|
||||
async def test_simple(self):
|
||||
@api
|
||||
class API:
|
||||
def GET() -> str: ...
|
||||
|
@ -76,14 +75,11 @@ class TestBind(unittest.TestCase):
|
|||
async def GET(self) -> str:
|
||||
return 'value'
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
await self.assertResponse(
|
||||
client.get("/"), 'value')
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
await self.assertResponse(
|
||||
client.get("/"), 'value')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_nested(self):
|
||||
async def test_nested(self):
|
||||
@api
|
||||
class API:
|
||||
class endpoint:
|
||||
|
@ -94,14 +90,11 @@ class TestBind(unittest.TestCase):
|
|||
async def nested_get(self, request, context):
|
||||
return 'nested'
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(API.endpoint, Impl()) as client:
|
||||
await self.assertResponse(
|
||||
client.get("/endpoint/nested"), 'nested')
|
||||
async with makeTestClient(API.endpoint, Impl()) as client:
|
||||
await self.assertResponse(
|
||||
client.get("/endpoint/nested"), 'nested')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_args(self):
|
||||
async def test_args(self):
|
||||
@api
|
||||
class API:
|
||||
def GET(arg: str): ...
|
||||
|
@ -110,14 +103,11 @@ class TestBind(unittest.TestCase):
|
|||
async def GET(self, arg: str):
|
||||
return arg
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
await self.assertResponse(
|
||||
client.get('/?arg="whut"'), 'whut')
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
await self.assertResponse(
|
||||
client.get('/?arg="whut"'), 'whut')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_missing_argument(self):
|
||||
async def test_missing_argument(self):
|
||||
@api
|
||||
class API:
|
||||
def GET(arg: str): ...
|
||||
|
@ -126,19 +116,16 @@ class TestBind(unittest.TestCase):
|
|||
async def GET(self, arg: str):
|
||||
return arg
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
resp = await client.get('/')
|
||||
self.assertEqual(resp.status, 500)
|
||||
self.assertEqual(resp.headers['x-status'], 'error')
|
||||
self.assertEqual(resp.headers['x-error-type'], 'TypeError')
|
||||
self.assertEqual(
|
||||
resp.headers['x-error-msg'],
|
||||
'missing required argument "arg"')
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
resp = await client.get('/')
|
||||
self.assertEqual(resp.status, 500)
|
||||
self.assertEqual(resp.headers['x-status'], 'error')
|
||||
self.assertEqual(resp.headers['x-error-type'], 'TypeError')
|
||||
self.assertEqual(
|
||||
resp.headers['x-error-msg'],
|
||||
'missing required argument "arg"')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_error(self):
|
||||
async def test_error(self):
|
||||
@api
|
||||
class API:
|
||||
def GET(): ...
|
||||
|
@ -147,17 +134,14 @@ class TestBind(unittest.TestCase):
|
|||
async def GET(self):
|
||||
return 1/0
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
resp = await client.get('/')
|
||||
self.assertEqual(resp.status, 500)
|
||||
self.assertEqual(resp.headers['x-status'], 'error')
|
||||
self.assertEqual(
|
||||
resp.headers['x-error-type'], 'ZeroDivisionError')
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
resp = await client.get('/')
|
||||
self.assertEqual(resp.status, 500)
|
||||
self.assertEqual(resp.headers['x-status'], 'error')
|
||||
self.assertEqual(
|
||||
resp.headers['x-error-type'], 'ZeroDivisionError')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_post(self):
|
||||
async def test_post(self):
|
||||
@api
|
||||
class API:
|
||||
def POST(data: Payload[str]) -> str: ...
|
||||
|
@ -166,13 +150,10 @@ class TestBind(unittest.TestCase):
|
|||
async def POST(self, data: str) -> str:
|
||||
return data
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
await self.assertResponse(
|
||||
client.post("/", json='value'),
|
||||
'value')
|
||||
|
||||
run_coro(run())
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
await self.assertResponse(
|
||||
client.post("/", json='value'),
|
||||
'value')
|
||||
|
||||
def test_missing_method(self):
|
||||
@api
|
||||
|
@ -201,7 +182,7 @@ class TestBind(unittest.TestCase):
|
|||
bind(app.router, API, Impl())
|
||||
self.assertEqual(cm.exception.methname, "API.GET")
|
||||
|
||||
def test_middleware(self):
|
||||
async def test_middleware(self):
|
||||
|
||||
@web.middleware
|
||||
async def middleware(request, handler):
|
||||
|
@ -219,16 +200,13 @@ class TestBind(unittest.TestCase):
|
|||
async def GET(self) -> str:
|
||||
return 1/0
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(
|
||||
API, Impl(), middlewares=[middleware]) as client:
|
||||
resp = await client.get("/")
|
||||
self.assertEqual(
|
||||
resp.headers['x-error-type'], 'ZeroDivisionError')
|
||||
async with makeTestClient(
|
||||
API, Impl(), middlewares=[middleware]) as client:
|
||||
resp = await client.get("/")
|
||||
self.assertEqual(
|
||||
resp.headers['x-error-type'], 'ZeroDivisionError')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_controller_for_request(self):
|
||||
async def test_controller_for_request(self):
|
||||
|
||||
seen_controller = None
|
||||
|
||||
|
@ -249,12 +227,9 @@ class TestBind(unittest.TestCase):
|
|||
|
||||
impl = Impl()
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(
|
||||
API.meth, impl, middlewares=[middleware]) as client:
|
||||
resp = await client.get("/meth")
|
||||
self.assertEqual(await resp.json(), '')
|
||||
|
||||
run_coro(run())
|
||||
async with makeTestClient(
|
||||
API.meth, impl, middlewares=[middleware]) as client:
|
||||
resp = await client.get("/meth")
|
||||
self.assertEqual(await resp.json(), '')
|
||||
|
||||
self.assertIs(impl, seen_controller)
|
||||
|
|
|
@ -19,7 +19,6 @@ from unittest import mock
|
|||
import yaml
|
||||
|
||||
from subiquitycore.pubsub import MessageHub
|
||||
from subiquitycore.tests.util import run_coro
|
||||
|
||||
from subiquity.common.types import IdentityData
|
||||
from subiquity.models.subiquity import ModelNames, SubiquityModel
|
||||
|
@ -63,7 +62,7 @@ class TestModelNames(unittest.TestCase):
|
|||
self.assertEqual(model_names.all(), {'a', 'b', 'c'})
|
||||
|
||||
|
||||
class TestSubiquityModel(unittest.TestCase):
|
||||
class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def writtenFiles(self, config):
|
||||
for k, v in config.get('write_files', {}).items():
|
||||
|
@ -114,7 +113,7 @@ class TestSubiquityModel(unittest.TestCase):
|
|||
cur = cur[component]
|
||||
self.fail("config has value {} for {}".format(cur, path))
|
||||
|
||||
async def _test_configure(self):
|
||||
async def test_configure(self):
|
||||
hub = MessageHub()
|
||||
model = SubiquityModel(
|
||||
'test', hub, ModelNames({'a', 'b'}), ModelNames(set()))
|
||||
|
@ -124,9 +123,6 @@ class TestSubiquityModel(unittest.TestCase):
|
|||
await hub.abroadcast((InstallerChannels.CONFIGURED, 'b'))
|
||||
self.assertTrue(model._install_event.is_set())
|
||||
|
||||
def test_configure(self):
|
||||
run_coro(self._test_configure())
|
||||
|
||||
def make_model(self):
|
||||
return SubiquityModel(
|
||||
'test', MessageHub(), INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES)
|
||||
|
|
|
@ -13,13 +13,12 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from unittest import mock, TestCase
|
||||
from unittest import mock, TestCase, IsolatedAsyncioTestCase
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from subiquity.server.controllers.filesystem import FilesystemController
|
||||
|
||||
from subiquitycore.tests.util import run_coro
|
||||
from subiquitycore.tests.mocks import make_app
|
||||
from subiquity.common.types import Bootloader
|
||||
from subiquity.models.tests.test_filesystem import (
|
||||
|
@ -28,7 +27,7 @@ from subiquity.models.tests.test_filesystem import (
|
|||
)
|
||||
|
||||
|
||||
class TestSubiquityControllerFilesystem(TestCase):
|
||||
class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.app = make_app()
|
||||
self.app.opts.bootloader = 'UEFI'
|
||||
|
@ -38,20 +37,20 @@ class TestSubiquityControllerFilesystem(TestCase):
|
|||
self.fsc = FilesystemController(app=self.app)
|
||||
self.fsc._configured = True
|
||||
|
||||
def test_probe_restricted(self):
|
||||
run_coro(self.fsc._probe_once(context=None, restricted=True))
|
||||
async def test_probe_restricted(self):
|
||||
await self.fsc._probe_once(context=None, restricted=True)
|
||||
self.app.prober.get_storage.assert_called_with({'blockdev'})
|
||||
|
||||
def test_probe_os_prober_false(self):
|
||||
async def test_probe_os_prober_false(self):
|
||||
self.app.opts.use_os_prober = False
|
||||
run_coro(self.fsc._probe_once(context=None, restricted=False))
|
||||
await self.fsc._probe_once(context=None, restricted=False)
|
||||
actual = self.app.prober.get_storage.call_args.args[0]
|
||||
self.assertTrue({'defaults'} <= actual)
|
||||
self.assertNotIn('os', actual)
|
||||
|
||||
def test_probe_os_prober_true(self):
|
||||
async def test_probe_os_prober_true(self):
|
||||
self.app.opts.use_os_prober = True
|
||||
run_coro(self.fsc._probe_once(context=None, restricted=False))
|
||||
await self.fsc._probe_once(context=None, restricted=False)
|
||||
actual = self.app.prober.get_storage.call_args.args[0]
|
||||
self.assertTrue({'defaults', 'os'} <= actual)
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@ from unittest.mock import Mock, patch, AsyncMock
|
|||
|
||||
from subiquitycore.tests import SubiTestCase
|
||||
from subiquitycore.tests.mocks import make_app
|
||||
from subiquitycore.tests.util import run_coro
|
||||
from subiquity.server.apt import (
|
||||
AptConfigurer,
|
||||
Mountpoint,
|
||||
|
@ -51,14 +50,14 @@ class TestAptConfigurer(SubiTestCase):
|
|||
expected['apt']['https_proxy'] = proxy
|
||||
self.assertEqual(expected, self.configurer.apt_config())
|
||||
|
||||
def test_mount_unmount(self):
|
||||
async def test_mount_unmount(self):
|
||||
# Make sure we can unmount something that we mounted before.
|
||||
with patch.object(self.app, "command_runner",
|
||||
create=True, new_callable=AsyncMock):
|
||||
m = run_coro(self.configurer.mount("/dev/cdrom", "/target"))
|
||||
run_coro(self.configurer.unmount(m))
|
||||
m = await self.configurer.mount("/dev/cdrom", "/target")
|
||||
await self.configurer.unmount(m)
|
||||
|
||||
def test_overlay(self):
|
||||
async def test_overlay(self):
|
||||
self.configurer.install_tree = OverlayMountpoint(
|
||||
upperdir="upperdir-install-tree",
|
||||
lowers=["lowers1-install-tree"],
|
||||
|
@ -71,13 +70,10 @@ class TestAptConfigurer(SubiTestCase):
|
|||
)
|
||||
self.source = "source"
|
||||
|
||||
async def coro():
|
||||
async with self.configurer.overlay():
|
||||
pass
|
||||
|
||||
with patch.object(self.app, "command_runner",
|
||||
create=True, new_callable=AsyncMock):
|
||||
run_coro(coro())
|
||||
async with self.configurer.overlay():
|
||||
pass
|
||||
|
||||
|
||||
class TestLowerDirFor(SubiTestCase):
|
||||
|
|
|
@ -17,7 +17,6 @@ from aioresponses import aioresponses
|
|||
|
||||
from subiquitycore.tests import SubiTestCase
|
||||
from subiquitycore.tests.mocks import make_app
|
||||
from subiquitycore.tests.util import run_coro
|
||||
from subiquity.server.geoip import (
|
||||
GeoIP,
|
||||
HTTPGeoIPStrategy,
|
||||
|
@ -48,16 +47,13 @@ empty_cc = '<Response><CountryCode></CountryCode></Response>'
|
|||
|
||||
|
||||
class TestGeoIP(SubiTestCase):
|
||||
def setUp(self):
|
||||
async def asyncSetUp(self):
|
||||
strategy = HTTPGeoIPStrategy()
|
||||
self.geoip = GeoIP(make_app(), strategy)
|
||||
|
||||
async def fn():
|
||||
self.assertTrue(await self.geoip.lookup())
|
||||
|
||||
with aioresponses() as mocked:
|
||||
mocked.get("https://geoip.ubuntu.com/lookup", body=xml)
|
||||
run_coro(fn())
|
||||
self.assertTrue(await self.geoip.lookup())
|
||||
|
||||
def test_countrycode(self):
|
||||
self.assertEqual("us", self.geoip.countrycode)
|
||||
|
@ -71,42 +67,32 @@ class TestGeoIPBadData(SubiTestCase):
|
|||
strategy = HTTPGeoIPStrategy()
|
||||
self.geoip = GeoIP(make_app(), strategy)
|
||||
|
||||
def test_partial_reponse(self):
|
||||
async def fn():
|
||||
self.assertFalse(await self.geoip.lookup())
|
||||
async def test_partial_reponse(self):
|
||||
with aioresponses() as mocked:
|
||||
mocked.get("https://geoip.ubuntu.com/lookup", body=partial)
|
||||
run_coro(fn())
|
||||
|
||||
def test_incomplete(self):
|
||||
async def fn():
|
||||
self.assertFalse(await self.geoip.lookup())
|
||||
|
||||
async def test_incomplete(self):
|
||||
with aioresponses() as mocked:
|
||||
mocked.get("https://geoip.ubuntu.com/lookup", body=incomplete)
|
||||
run_coro(fn())
|
||||
self.assertFalse(await self.geoip.lookup())
|
||||
self.assertIsNone(self.geoip.countrycode)
|
||||
self.assertIsNone(self.geoip.timezone)
|
||||
|
||||
def test_long_cc(self):
|
||||
async def fn():
|
||||
self.assertFalse(await self.geoip.lookup())
|
||||
async def test_long_cc(self):
|
||||
with aioresponses() as mocked:
|
||||
mocked.get("https://geoip.ubuntu.com/lookup", body=long_cc)
|
||||
run_coro(fn())
|
||||
self.assertFalse(await self.geoip.lookup())
|
||||
self.assertIsNone(self.geoip.countrycode)
|
||||
|
||||
def test_empty_cc(self):
|
||||
async def fn():
|
||||
self.assertFalse(await self.geoip.lookup())
|
||||
async def test_empty_cc(self):
|
||||
with aioresponses() as mocked:
|
||||
mocked.get("https://geoip.ubuntu.com/lookup", body=empty_cc)
|
||||
run_coro(fn())
|
||||
self.assertFalse(await self.geoip.lookup())
|
||||
self.assertIsNone(self.geoip.countrycode)
|
||||
|
||||
def test_empty_tz(self):
|
||||
async def fn():
|
||||
self.assertFalse(await self.geoip.lookup())
|
||||
async def test_empty_tz(self):
|
||||
with aioresponses() as mocked:
|
||||
mocked.get("https://geoip.ubuntu.com/lookup", body=empty_tz)
|
||||
run_coro(fn())
|
||||
self.assertFalse(await self.geoip.lookup())
|
||||
self.assertIsNone(self.geoip.timezone)
|
||||
|
|
|
@ -26,37 +26,36 @@ from subiquity.server.ubuntu_advantage import (
|
|||
MockedUAInterfaceStrategy,
|
||||
UAClientUAInterfaceStrategy,
|
||||
)
|
||||
from subiquitycore.tests.util import run_coro
|
||||
|
||||
|
||||
class TestMockedUAInterfaceStrategy(unittest.TestCase):
|
||||
class TestMockedUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000)
|
||||
|
||||
def test_query_info_invalid(self):
|
||||
async def test_query_info_invalid(self):
|
||||
# Tokens starting with "i" in dry-run mode cause the token to be
|
||||
# reported as invalid.
|
||||
with self.assertRaises(InvalidTokenError):
|
||||
run_coro(self.strategy.query_info(token="invalidToken"))
|
||||
await self.strategy.query_info(token="invalidToken")
|
||||
|
||||
def test_query_info_failure(self):
|
||||
async def test_query_info_failure(self):
|
||||
# Tokens starting with "f" in dry-run mode simulate an "internal"
|
||||
# error.
|
||||
with self.assertRaises(CheckSubscriptionError):
|
||||
run_coro(self.strategy.query_info(token="failure"))
|
||||
await self.strategy.query_info(token="failure")
|
||||
|
||||
def test_query_info_expired(self):
|
||||
async def test_query_info_expired(self):
|
||||
# Tokens starting with "x" is dry-run mode simulate an expired token.
|
||||
info = run_coro(self.strategy.query_info(token="xpiredToken"))
|
||||
info = await self.strategy.query_info(token="xpiredToken")
|
||||
self.assertEqual(info["expires"], "2010-12-31T00:00:00+00:00")
|
||||
|
||||
def test_query_info_valid(self):
|
||||
async def test_query_info_valid(self):
|
||||
# Other tokens are considered valid in dry-run mode.
|
||||
info = run_coro(self.strategy.query_info(token="validToken"))
|
||||
info = await self.strategy.query_info(token="validToken")
|
||||
self.assertEqual(info["expires"], "2035-12-31T00:00:00+00:00")
|
||||
|
||||
|
||||
class TestUAClientUAInterfaceStrategy(unittest.TestCase):
|
||||
class TestUAClientUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase):
|
||||
arun_command_sym = "subiquity.server.ubuntu_advantage.utils.arun_command"
|
||||
|
||||
def test_init(self):
|
||||
|
@ -75,7 +74,7 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
|
|||
self.assertEqual(strategy.executable,
|
||||
["python3", "/usr/bin/ubuntu-advantage"])
|
||||
|
||||
def test_query_info_succeeded(self):
|
||||
async def test_query_info_succeeded(self):
|
||||
strategy = UAClientUAInterfaceStrategy()
|
||||
command = (
|
||||
"ubuntu-advantage",
|
||||
|
@ -87,10 +86,10 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
|
|||
with patch(self.arun_command_sym) as mock_arun:
|
||||
mock_arun.return_value = CompletedProcess([], 0)
|
||||
mock_arun.return_value.stdout = "{}"
|
||||
run_coro(strategy.query_info(token="123456789"))
|
||||
await strategy.query_info(token="123456789")
|
||||
mock_arun.assert_called_once_with(command, check=True)
|
||||
|
||||
def test_query_info_failed(self):
|
||||
async def test_query_info_failed(self):
|
||||
strategy = UAClientUAInterfaceStrategy()
|
||||
command = (
|
||||
"ubuntu-advantage",
|
||||
|
@ -104,10 +103,10 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
|
|||
cmd=command)
|
||||
mock_arun.return_value.stdout = "{}"
|
||||
with self.assertRaises(CheckSubscriptionError):
|
||||
run_coro(strategy.query_info(token="123456789"))
|
||||
await strategy.query_info(token="123456789")
|
||||
mock_arun.assert_called_once_with(command, check=True)
|
||||
|
||||
def test_query_info_invalid_json(self):
|
||||
async def test_query_info_invalid_json(self):
|
||||
strategy = UAClientUAInterfaceStrategy()
|
||||
command = (
|
||||
"ubuntu-advantage",
|
||||
|
@ -120,31 +119,31 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
|
|||
mock_arun.return_value = CompletedProcess([], 0)
|
||||
mock_arun.return_value.stdout = "invalid-json"
|
||||
with self.assertRaises(CheckSubscriptionError):
|
||||
run_coro(strategy.query_info(token="123456789"))
|
||||
await strategy.query_info(token="123456789")
|
||||
mock_arun.assert_called_once_with(command, check=True)
|
||||
|
||||
|
||||
class TestUAInterface(unittest.TestCase):
|
||||
class TestUAInterface(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def test_mocked_get_activable_services(self):
|
||||
async def test_mocked_get_activable_services(self):
|
||||
strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000)
|
||||
interface = UAInterface(strategy)
|
||||
|
||||
with self.assertRaises(InvalidTokenError):
|
||||
run_coro(interface.get_activable_services(token="invalidToken"))
|
||||
await interface.get_activable_services(token="invalidToken")
|
||||
# Tokens starting with "f" in dry-run mode simulate an "internal"
|
||||
# error.
|
||||
with self.assertRaises(CheckSubscriptionError):
|
||||
run_coro(interface.get_activable_services(token="failure"))
|
||||
await interface.get_activable_services(token="failure")
|
||||
|
||||
# Tokens starting with "x" is dry-run mode simulate an expired token.
|
||||
with self.assertRaises(ExpiredTokenError):
|
||||
run_coro(interface.get_activable_services(token="xpiredToken"))
|
||||
await interface.get_activable_services(token="xpiredToken")
|
||||
|
||||
# Other tokens are considered valid in dry-run mode.
|
||||
run_coro(interface.get_activable_services(token="validToken"))
|
||||
await interface.get_activable_services(token="validToken")
|
||||
|
||||
def test_get_activable_services(self):
|
||||
async def test_get_activable_services(self):
|
||||
# We use the standard strategy but don't actually run it
|
||||
strategy = UAClientUAInterfaceStrategy()
|
||||
interface = UAInterface(strategy)
|
||||
|
@ -185,8 +184,7 @@ class TestUAInterface(unittest.TestCase):
|
|||
]
|
||||
}
|
||||
interface.get_subscription = AsyncMock(return_value=subscription)
|
||||
services = run_coro(
|
||||
interface.get_activable_services(token="XXX"))
|
||||
services = await interface.get_activable_services(token="XXX")
|
||||
|
||||
self.assertIn(UbuntuProService(
|
||||
name="esm-infra",
|
||||
|
@ -211,5 +209,4 @@ class TestUAInterface(unittest.TestCase):
|
|||
|
||||
# Test with "Z" suffix for the expiration date.
|
||||
subscription["expires"] = "2035-12-31T00:00:00Z"
|
||||
services = run_coro(
|
||||
interface.get_activable_services(token="XXX"))
|
||||
services = await interface.get_activable_services(token="XXX")
|
||||
|
|
|
@ -18,7 +18,6 @@ import unittest
|
|||
from unittest.mock import patch, AsyncMock, Mock
|
||||
|
||||
from subiquitycore.tests.mocks import make_app
|
||||
from subiquitycore.tests.util import run_coro
|
||||
|
||||
from subiquity.server.ubuntu_drivers import (
|
||||
UbuntuDriversInterface,
|
||||
|
@ -28,7 +27,7 @@ from subiquity.server.ubuntu_drivers import (
|
|||
)
|
||||
|
||||
|
||||
class TestUbuntuDriversInterface(unittest.TestCase):
|
||||
class TestUbuntuDriversInterface(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.app = make_app()
|
||||
|
||||
|
@ -54,11 +53,11 @@ class TestUbuntuDriversInterface(unittest.TestCase):
|
|||
|
||||
@patch.multiple(UbuntuDriversInterface, __abstractmethods__=set())
|
||||
@patch("subiquity.server.ubuntu_drivers.run_curtin_command")
|
||||
def test_install_drivers(self, mock_run_curtin_command):
|
||||
async def test_install_drivers(self, mock_run_curtin_command):
|
||||
ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=False)
|
||||
run_coro(ubuntu_drivers.install_drivers(
|
||||
await ubuntu_drivers.install_drivers(
|
||||
root_dir="/target",
|
||||
context="installing third-party drivers"))
|
||||
context="installing third-party drivers")
|
||||
mock_run_curtin_command.assert_called_once_with(
|
||||
self.app, "installing third-party drivers",
|
||||
"in-target", "-t", "/target",
|
||||
|
@ -91,18 +90,18 @@ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
|
|||
["nvidia-driver-470", "nvidia-driver-510"])
|
||||
|
||||
|
||||
class TestUbuntuDriversClientInterface(unittest.TestCase):
|
||||
class TestUbuntuDriversClientInterface(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.app = make_app()
|
||||
self.ubuntu_drivers = UbuntuDriversClientInterface(
|
||||
self.app, gpgpu=False)
|
||||
|
||||
def test_ensure_cmd_exists(self):
|
||||
async def test_ensure_cmd_exists(self):
|
||||
with patch.object(
|
||||
self.app, "command_runner",
|
||||
create=True, new_callable=AsyncMock) as mock_runner:
|
||||
# On success
|
||||
run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target"))
|
||||
await self.ubuntu_drivers.ensure_cmd_exists("/target")
|
||||
mock_runner.run.assert_called_once_with(
|
||||
[
|
||||
"chroot", "/target",
|
||||
|
@ -115,17 +114,17 @@ class TestUbuntuDriversClientInterface(unittest.TestCase):
|
|||
cmd=["sh", "-c", "command -v ubuntu-drivers"])
|
||||
|
||||
with self.assertRaises(CommandNotFoundError):
|
||||
run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target"))
|
||||
await self.ubuntu_drivers.ensure_cmd_exists("/target")
|
||||
|
||||
@patch("subiquity.server.ubuntu_drivers.run_curtin_command")
|
||||
def test_list_drivers(self, mock_run_curtin_command):
|
||||
async def test_list_drivers(self, mock_run_curtin_command):
|
||||
# Make sure this gets decoded as utf-8.
|
||||
mock_run_curtin_command.return_value = Mock(stdout=b"""\
|
||||
nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
|
||||
""")
|
||||
drivers = run_coro(self.ubuntu_drivers.list_drivers(
|
||||
drivers = await self.ubuntu_drivers.list_drivers(
|
||||
root_dir="/target",
|
||||
context="listing third-party drivers"))
|
||||
context="listing third-party drivers")
|
||||
|
||||
mock_run_curtin_command.assert_called_once_with(
|
||||
self.app, "listing third-party drivers",
|
||||
|
@ -137,15 +136,15 @@ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
|
|||
self.assertEqual(drivers, ["nvidia-driver-510"])
|
||||
|
||||
|
||||
class TestUbuntuDriversRunDriversInterface(unittest.TestCase):
|
||||
class TestUbuntuDriversRunDriversInterface(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.app = make_app()
|
||||
self.ubuntu_drivers = UbuntuDriversRunDriversInterface(
|
||||
self.app, gpgpu=False)
|
||||
|
||||
@patch("subiquity.server.ubuntu_drivers.arun_command")
|
||||
def test_ensure_cmd_exists(self, mock_arun_command):
|
||||
run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target"))
|
||||
async def test_ensure_cmd_exists(self, mock_arun_command):
|
||||
await self.ubuntu_drivers.ensure_cmd_exists("/target")
|
||||
mock_arun_command.assert_called_once_with(
|
||||
["sh", "-c", "command -v ubuntu-drivers"],
|
||||
check=True)
|
||||
|
|
|
@ -7,7 +7,6 @@ from functools import wraps
|
|||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
from urllib.parse import unquote
|
||||
|
||||
|
@ -134,7 +133,7 @@ class Server(Client):
|
|||
pass
|
||||
|
||||
|
||||
class TestAPI(unittest.IsolatedAsyncioTestCase, SubiTestCase):
|
||||
class TestAPI(SubiTestCase):
|
||||
class _MachineConfig(os.PathLike):
|
||||
def __init__(self, outer, path):
|
||||
self.outer = outer
|
||||
|
|
|
@ -20,7 +20,6 @@ from subiquity.models.timezone import TimeZoneModel
|
|||
from subiquity.server.controllers.timezone import TimeZoneController
|
||||
from subiquitycore.tests import SubiTestCase
|
||||
from subiquitycore.tests.mocks import make_app
|
||||
from subiquitycore.tests.util import run_coro
|
||||
|
||||
|
||||
class MockGeoIP:
|
||||
|
@ -47,7 +46,7 @@ class TestTimeZoneController(SubiTestCase):
|
|||
|
||||
@mock.patch('subiquity.server.controllers.timezone.timedatectl_settz')
|
||||
@mock.patch('subiquity.server.controllers.timezone.timedatectl_gettz')
|
||||
def test_good_tzs(self, tdc_gettz, tdc_settz):
|
||||
async def test_good_tzs(self, tdc_gettz, tdc_settz):
|
||||
tdc_gettz.return_value = tz_utc
|
||||
goods = [
|
||||
# val - autoinstall value
|
||||
|
@ -77,7 +76,7 @@ class TestTimeZoneController(SubiTestCase):
|
|||
self.tzc.model)
|
||||
self.assertEqual(geoip, self.tzc.model.detect_with_geoip,
|
||||
self.tzc.model)
|
||||
self.assertEqual(tz, run_coro(self.tzc.GET()), self.tzc.model)
|
||||
self.assertEqual(tz, await self.tzc.GET(), self.tzc.model)
|
||||
cloudconfig = {}
|
||||
if self.tzc.model.should_set_tz:
|
||||
cloudconfig = {'timezone': tz.timezone}
|
||||
|
@ -104,7 +103,7 @@ class TestTimeZoneController(SubiTestCase):
|
|||
self.assertEqual('sleep', subprocess_run.call_args.args[0][0])
|
||||
|
||||
@mock.patch('subiquity.server.controllers.timezone.timedatectl_settz')
|
||||
def test_get_tz_should_not_set(self, tdc_settz):
|
||||
run_coro(self.tzc.GET())
|
||||
async def test_get_tz_should_not_set(self, tdc_settz):
|
||||
await self.tzc.GET()
|
||||
self.assertFalse(self.tzc.model.should_set_tz)
|
||||
tdc_settz.assert_not_called()
|
||||
|
|
|
@ -1,4 +1,18 @@
|
|||
import asyncio
|
||||
# Copyright 2017-2022 Canonical, Ltd.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
|
@ -42,20 +56,6 @@ system_reserved = {
|
|||
}
|
||||
|
||||
|
||||
def tearDownModule() -> None:
|
||||
# Set empty loop policy, so that subsequent get_event_loop() returns a new
|
||||
# loop. If there is no running event loop set, that function will return
|
||||
# the result of `get_event_loop_policy().get_event_loop()` call. If there
|
||||
# is a policy there must be a running loop. IsolatedAsyncioTestCase
|
||||
# closes the loop during tear down, though. It doesn't touch the policy.
|
||||
# By having it as None, it autoinits and the next tests run smoothly.
|
||||
# Another approach would be set a new event_loop for the current policy on
|
||||
# test fixture tearDown, as pytest-asyncio does.
|
||||
# Either way we would prevent failure on tests that depend on [run_coro]
|
||||
# (subiquitycore/tests/util.py), for instance.
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class IdentityViewTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def make_view(self):
|
||||
|
|
|
@ -3,10 +3,10 @@ import os
|
|||
import shutil
|
||||
import tempfile
|
||||
|
||||
from unittest import TestCase
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
|
||||
class SubiTestCase(TestCase):
|
||||
class SubiTestCase(IsolatedAsyncioTestCase):
|
||||
def tmp_dir(self, dir=None, cleanup=True):
|
||||
# return a full path to a temporary directory that will be cleaned up.
|
||||
if dir is None:
|
||||
|
|
|
@ -13,14 +13,9 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import string
|
||||
|
||||
|
||||
def run_coro(coro):
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
|
||||
def random_string():
|
||||
return ''.join(random.choice(string.ascii_letters) for _ in range(8))
|
||||
|
|
Loading…
Reference in New Issue