Merge pull request #1335 from dbungert/use-async-test-case
tests: s/run_coro/IsolatedAsyncioTestCase/
This commit is contained in:
commit
b786b4d876
|
@ -21,8 +21,6 @@ import unittest
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from subiquitycore.tests.util import run_coro
|
|
||||||
|
|
||||||
from subiquity.common.api.client import make_client
|
from subiquity.common.api.client import make_client
|
||||||
from subiquity.common.api.defs import api, Payload
|
from subiquity.common.api.defs import api, Payload
|
||||||
|
|
||||||
|
@ -46,9 +44,9 @@ async def makeE2EClient(api, impl,
|
||||||
yield make_client(api, mr)
|
yield make_client(api, mr)
|
||||||
|
|
||||||
|
|
||||||
class TestEndToEnd(unittest.TestCase):
|
class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
def test_simple(self):
|
async def test_simple(self):
|
||||||
@api
|
@api
|
||||||
class API:
|
class API:
|
||||||
def GET() -> str: ...
|
def GET() -> str: ...
|
||||||
|
@ -57,13 +55,10 @@ class TestEndToEnd(unittest.TestCase):
|
||||||
async def GET(self) -> str:
|
async def GET(self) -> str:
|
||||||
return 'value'
|
return 'value'
|
||||||
|
|
||||||
async def run():
|
async with makeE2EClient(API, Impl()) as client:
|
||||||
async with makeE2EClient(API, Impl()) as client:
|
self.assertEqual(await client.GET(), 'value')
|
||||||
self.assertEqual(await client.GET(), 'value')
|
|
||||||
|
|
||||||
run_coro(run())
|
async def test_nested(self):
|
||||||
|
|
||||||
def test_nested(self):
|
|
||||||
@api
|
@api
|
||||||
class API:
|
class API:
|
||||||
class endpoint:
|
class endpoint:
|
||||||
|
@ -74,13 +69,10 @@ class TestEndToEnd(unittest.TestCase):
|
||||||
async def endpoint_nested_GET(self) -> str:
|
async def endpoint_nested_GET(self) -> str:
|
||||||
return 'value'
|
return 'value'
|
||||||
|
|
||||||
async def run():
|
async with makeE2EClient(API, Impl()) as client:
|
||||||
async with makeE2EClient(API, Impl()) as client:
|
self.assertEqual(await client.endpoint.nested.GET(), 'value')
|
||||||
self.assertEqual(await client.endpoint.nested.GET(), 'value')
|
|
||||||
|
|
||||||
run_coro(run())
|
async def test_args(self):
|
||||||
|
|
||||||
def test_args(self):
|
|
||||||
@api
|
@api
|
||||||
class API:
|
class API:
|
||||||
def GET(arg1: str, arg2: str) -> str: ...
|
def GET(arg1: str, arg2: str) -> str: ...
|
||||||
|
@ -89,14 +81,11 @@ class TestEndToEnd(unittest.TestCase):
|
||||||
async def GET(self, arg1: str, arg2: str) -> str:
|
async def GET(self, arg1: str, arg2: str) -> str:
|
||||||
return '{}+{}'.format(arg1, arg2)
|
return '{}+{}'.format(arg1, arg2)
|
||||||
|
|
||||||
async def run():
|
async with makeE2EClient(API, Impl()) as client:
|
||||||
async with makeE2EClient(API, Impl()) as client:
|
self.assertEqual(
|
||||||
self.assertEqual(
|
await client.GET(arg1="A", arg2="B"), 'A+B')
|
||||||
await client.GET(arg1="A", arg2="B"), 'A+B')
|
|
||||||
|
|
||||||
run_coro(run())
|
async def test_defaults(self):
|
||||||
|
|
||||||
def test_defaults(self):
|
|
||||||
@api
|
@api
|
||||||
class API:
|
class API:
|
||||||
def GET(arg1: str, arg2: str = "arg2") -> str: ...
|
def GET(arg1: str, arg2: str = "arg2") -> str: ...
|
||||||
|
@ -105,16 +94,13 @@ class TestEndToEnd(unittest.TestCase):
|
||||||
async def GET(self, arg1: str, arg2: str = "arg2") -> str:
|
async def GET(self, arg1: str, arg2: str = "arg2") -> str:
|
||||||
return '{}+{}'.format(arg1, arg2)
|
return '{}+{}'.format(arg1, arg2)
|
||||||
|
|
||||||
async def run():
|
async with makeE2EClient(API, Impl()) as client:
|
||||||
async with makeE2EClient(API, Impl()) as client:
|
self.assertEqual(
|
||||||
self.assertEqual(
|
await client.GET(arg1="A", arg2="B"), 'A+B')
|
||||||
await client.GET(arg1="A", arg2="B"), 'A+B')
|
self.assertEqual(
|
||||||
self.assertEqual(
|
await client.GET(arg1="A"), 'A+arg2')
|
||||||
await client.GET(arg1="A"), 'A+arg2')
|
|
||||||
|
|
||||||
run_coro(run())
|
async def test_post(self):
|
||||||
|
|
||||||
def test_post(self):
|
|
||||||
@api
|
@api
|
||||||
class API:
|
class API:
|
||||||
def POST(data: Payload[dict]) -> str: ...
|
def POST(data: Payload[dict]) -> str: ...
|
||||||
|
@ -123,14 +109,11 @@ class TestEndToEnd(unittest.TestCase):
|
||||||
async def POST(self, data: dict) -> str:
|
async def POST(self, data: dict) -> str:
|
||||||
return data['key']
|
return data['key']
|
||||||
|
|
||||||
async def run():
|
async with makeE2EClient(API, Impl()) as client:
|
||||||
async with makeE2EClient(API, Impl()) as client:
|
self.assertEqual(
|
||||||
self.assertEqual(
|
await client.POST({'key': 'value'}), 'value')
|
||||||
await client.POST({'key': 'value'}), 'value')
|
|
||||||
|
|
||||||
run_coro(run())
|
async def test_typed(self):
|
||||||
|
|
||||||
def test_typed(self):
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True)
|
@attr.s(auto_attribs=True)
|
||||||
class In:
|
class In:
|
||||||
|
@ -149,14 +132,11 @@ class TestEndToEnd(unittest.TestCase):
|
||||||
async def doubler_POST(self, data: In) -> Out:
|
async def doubler_POST(self, data: In) -> Out:
|
||||||
return Out(doubled=data.val*2)
|
return Out(doubled=data.val*2)
|
||||||
|
|
||||||
async def run():
|
async with makeE2EClient(API, Impl()) as client:
|
||||||
async with makeE2EClient(API, Impl()) as client:
|
out = await client.doubler.POST(In(3))
|
||||||
out = await client.doubler.POST(In(3))
|
self.assertEqual(out.doubled, 6)
|
||||||
self.assertEqual(out.doubled, 6)
|
|
||||||
|
|
||||||
run_coro(run())
|
async def test_middleware(self):
|
||||||
|
|
||||||
def test_middleware(self):
|
|
||||||
@api
|
@api
|
||||||
class API:
|
class API:
|
||||||
def GET() -> int: ...
|
def GET() -> int: ...
|
||||||
|
@ -182,17 +162,14 @@ class TestEndToEnd(unittest.TestCase):
|
||||||
raise Skip
|
raise Skip
|
||||||
yield resp
|
yield resp
|
||||||
|
|
||||||
async def run():
|
async with makeE2EClient(
|
||||||
async with makeE2EClient(
|
API, Impl(),
|
||||||
API, Impl(),
|
middlewares=[middleware],
|
||||||
middlewares=[middleware],
|
make_request=custom_make_request) as client:
|
||||||
make_request=custom_make_request) as client:
|
with self.assertRaises(Skip):
|
||||||
with self.assertRaises(Skip):
|
await client.GET()
|
||||||
await client.GET()
|
|
||||||
|
|
||||||
run_coro(run())
|
async def test_error(self):
|
||||||
|
|
||||||
def test_error(self):
|
|
||||||
@api
|
@api
|
||||||
class API:
|
class API:
|
||||||
class good:
|
class good:
|
||||||
|
@ -208,16 +185,13 @@ class TestEndToEnd(unittest.TestCase):
|
||||||
async def bad_GET(self, x: int) -> int:
|
async def bad_GET(self, x: int) -> int:
|
||||||
raise Exception("baz")
|
raise Exception("baz")
|
||||||
|
|
||||||
async def run():
|
async with makeE2EClient(API, Impl()) as client:
|
||||||
async with makeE2EClient(API, Impl()) as client:
|
r = await client.good.GET(2)
|
||||||
r = await client.good.GET(2)
|
self.assertEqual(r, 3)
|
||||||
self.assertEqual(r, 3)
|
with self.assertRaises(aiohttp.ClientResponseError):
|
||||||
with self.assertRaises(aiohttp.ClientResponseError):
|
await client.bad.GET(2)
|
||||||
await client.bad.GET(2)
|
|
||||||
|
|
||||||
run_coro(run())
|
async def test_error_middleware(self):
|
||||||
|
|
||||||
def test_error_middleware(self):
|
|
||||||
@api
|
@api
|
||||||
class API:
|
class API:
|
||||||
class good:
|
class good:
|
||||||
|
@ -251,14 +225,11 @@ class TestEndToEnd(unittest.TestCase):
|
||||||
raise Abort
|
raise Abort
|
||||||
yield resp
|
yield resp
|
||||||
|
|
||||||
async def run():
|
async with makeE2EClient(
|
||||||
async with makeE2EClient(
|
API, Impl(),
|
||||||
API, Impl(),
|
middlewares=[middleware],
|
||||||
middlewares=[middleware],
|
make_request=custom_make_request) as client:
|
||||||
make_request=custom_make_request) as client:
|
r = await client.good.GET(2)
|
||||||
r = await client.good.GET(2)
|
self.assertEqual(r, 3)
|
||||||
self.assertEqual(r, 3)
|
with self.assertRaises(Abort):
|
||||||
with self.assertRaises(Abort):
|
await client.bad.GET(2)
|
||||||
await client.bad.GET(2)
|
|
||||||
|
|
||||||
run_coro(run())
|
|
||||||
|
|
|
@ -20,7 +20,6 @@ from aiohttp.test_utils import TestClient, TestServer
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from subiquitycore.context import Context
|
from subiquitycore.context import Context
|
||||||
from subiquitycore.tests.util import run_coro
|
|
||||||
|
|
||||||
from subiquity.common.api.defs import api, Payload
|
from subiquity.common.api.defs import api, Payload
|
||||||
from subiquity.common.api.server import (
|
from subiquity.common.api.server import (
|
||||||
|
@ -56,7 +55,7 @@ async def makeTestClient(api, impl, middlewares=()):
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
||||||
class TestBind(unittest.TestCase):
|
class TestBind(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
async def assertResponse(self, coro, value):
|
async def assertResponse(self, coro, value):
|
||||||
resp = await coro
|
resp = await coro
|
||||||
|
@ -67,7 +66,7 @@ class TestBind(unittest.TestCase):
|
||||||
self.assertEqual(resp.status, 200)
|
self.assertEqual(resp.status, 200)
|
||||||
self.assertEqual(await resp.json(), value)
|
self.assertEqual(await resp.json(), value)
|
||||||
|
|
||||||
def test_simple(self):
|
async def test_simple(self):
|
||||||
@api
|
@api
|
||||||
class API:
|
class API:
|
||||||
def GET() -> str: ...
|
def GET() -> str: ...
|
||||||
|
@ -76,14 +75,11 @@ class TestBind(unittest.TestCase):
|
||||||
async def GET(self) -> str:
|
async def GET(self) -> str:
|
||||||
return 'value'
|
return 'value'
|
||||||
|
|
||||||
async def run():
|
async with makeTestClient(API, Impl()) as client:
|
||||||
async with makeTestClient(API, Impl()) as client:
|
await self.assertResponse(
|
||||||
await self.assertResponse(
|
client.get("/"), 'value')
|
||||||
client.get("/"), 'value')
|
|
||||||
|
|
||||||
run_coro(run())
|
async def test_nested(self):
|
||||||
|
|
||||||
def test_nested(self):
|
|
||||||
@api
|
@api
|
||||||
class API:
|
class API:
|
||||||
class endpoint:
|
class endpoint:
|
||||||
|
@ -94,14 +90,11 @@ class TestBind(unittest.TestCase):
|
||||||
async def nested_get(self, request, context):
|
async def nested_get(self, request, context):
|
||||||
return 'nested'
|
return 'nested'
|
||||||
|
|
||||||
async def run():
|
async with makeTestClient(API.endpoint, Impl()) as client:
|
||||||
async with makeTestClient(API.endpoint, Impl()) as client:
|
await self.assertResponse(
|
||||||
await self.assertResponse(
|
client.get("/endpoint/nested"), 'nested')
|
||||||
client.get("/endpoint/nested"), 'nested')
|
|
||||||
|
|
||||||
run_coro(run())
|
async def test_args(self):
|
||||||
|
|
||||||
def test_args(self):
|
|
||||||
@api
|
@api
|
||||||
class API:
|
class API:
|
||||||
def GET(arg: str): ...
|
def GET(arg: str): ...
|
||||||
|
@ -110,14 +103,11 @@ class TestBind(unittest.TestCase):
|
||||||
async def GET(self, arg: str):
|
async def GET(self, arg: str):
|
||||||
return arg
|
return arg
|
||||||
|
|
||||||
async def run():
|
async with makeTestClient(API, Impl()) as client:
|
||||||
async with makeTestClient(API, Impl()) as client:
|
await self.assertResponse(
|
||||||
await self.assertResponse(
|
client.get('/?arg="whut"'), 'whut')
|
||||||
client.get('/?arg="whut"'), 'whut')
|
|
||||||
|
|
||||||
run_coro(run())
|
async def test_missing_argument(self):
|
||||||
|
|
||||||
def test_missing_argument(self):
|
|
||||||
@api
|
@api
|
||||||
class API:
|
class API:
|
||||||
def GET(arg: str): ...
|
def GET(arg: str): ...
|
||||||
|
@ -126,19 +116,16 @@ class TestBind(unittest.TestCase):
|
||||||
async def GET(self, arg: str):
|
async def GET(self, arg: str):
|
||||||
return arg
|
return arg
|
||||||
|
|
||||||
async def run():
|
async with makeTestClient(API, Impl()) as client:
|
||||||
async with makeTestClient(API, Impl()) as client:
|
resp = await client.get('/')
|
||||||
resp = await client.get('/')
|
self.assertEqual(resp.status, 500)
|
||||||
self.assertEqual(resp.status, 500)
|
self.assertEqual(resp.headers['x-status'], 'error')
|
||||||
self.assertEqual(resp.headers['x-status'], 'error')
|
self.assertEqual(resp.headers['x-error-type'], 'TypeError')
|
||||||
self.assertEqual(resp.headers['x-error-type'], 'TypeError')
|
self.assertEqual(
|
||||||
self.assertEqual(
|
resp.headers['x-error-msg'],
|
||||||
resp.headers['x-error-msg'],
|
'missing required argument "arg"')
|
||||||
'missing required argument "arg"')
|
|
||||||
|
|
||||||
run_coro(run())
|
async def test_error(self):
|
||||||
|
|
||||||
def test_error(self):
|
|
||||||
@api
|
@api
|
||||||
class API:
|
class API:
|
||||||
def GET(): ...
|
def GET(): ...
|
||||||
|
@ -147,17 +134,14 @@ class TestBind(unittest.TestCase):
|
||||||
async def GET(self):
|
async def GET(self):
|
||||||
return 1/0
|
return 1/0
|
||||||
|
|
||||||
async def run():
|
async with makeTestClient(API, Impl()) as client:
|
||||||
async with makeTestClient(API, Impl()) as client:
|
resp = await client.get('/')
|
||||||
resp = await client.get('/')
|
self.assertEqual(resp.status, 500)
|
||||||
self.assertEqual(resp.status, 500)
|
self.assertEqual(resp.headers['x-status'], 'error')
|
||||||
self.assertEqual(resp.headers['x-status'], 'error')
|
self.assertEqual(
|
||||||
self.assertEqual(
|
resp.headers['x-error-type'], 'ZeroDivisionError')
|
||||||
resp.headers['x-error-type'], 'ZeroDivisionError')
|
|
||||||
|
|
||||||
run_coro(run())
|
async def test_post(self):
|
||||||
|
|
||||||
def test_post(self):
|
|
||||||
@api
|
@api
|
||||||
class API:
|
class API:
|
||||||
def POST(data: Payload[str]) -> str: ...
|
def POST(data: Payload[str]) -> str: ...
|
||||||
|
@ -166,13 +150,10 @@ class TestBind(unittest.TestCase):
|
||||||
async def POST(self, data: str) -> str:
|
async def POST(self, data: str) -> str:
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def run():
|
async with makeTestClient(API, Impl()) as client:
|
||||||
async with makeTestClient(API, Impl()) as client:
|
await self.assertResponse(
|
||||||
await self.assertResponse(
|
client.post("/", json='value'),
|
||||||
client.post("/", json='value'),
|
'value')
|
||||||
'value')
|
|
||||||
|
|
||||||
run_coro(run())
|
|
||||||
|
|
||||||
def test_missing_method(self):
|
def test_missing_method(self):
|
||||||
@api
|
@api
|
||||||
|
@ -201,7 +182,7 @@ class TestBind(unittest.TestCase):
|
||||||
bind(app.router, API, Impl())
|
bind(app.router, API, Impl())
|
||||||
self.assertEqual(cm.exception.methname, "API.GET")
|
self.assertEqual(cm.exception.methname, "API.GET")
|
||||||
|
|
||||||
def test_middleware(self):
|
async def test_middleware(self):
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def middleware(request, handler):
|
async def middleware(request, handler):
|
||||||
|
@ -219,16 +200,13 @@ class TestBind(unittest.TestCase):
|
||||||
async def GET(self) -> str:
|
async def GET(self) -> str:
|
||||||
return 1/0
|
return 1/0
|
||||||
|
|
||||||
async def run():
|
async with makeTestClient(
|
||||||
async with makeTestClient(
|
API, Impl(), middlewares=[middleware]) as client:
|
||||||
API, Impl(), middlewares=[middleware]) as client:
|
resp = await client.get("/")
|
||||||
resp = await client.get("/")
|
self.assertEqual(
|
||||||
self.assertEqual(
|
resp.headers['x-error-type'], 'ZeroDivisionError')
|
||||||
resp.headers['x-error-type'], 'ZeroDivisionError')
|
|
||||||
|
|
||||||
run_coro(run())
|
async def test_controller_for_request(self):
|
||||||
|
|
||||||
def test_controller_for_request(self):
|
|
||||||
|
|
||||||
seen_controller = None
|
seen_controller = None
|
||||||
|
|
||||||
|
@ -249,12 +227,9 @@ class TestBind(unittest.TestCase):
|
||||||
|
|
||||||
impl = Impl()
|
impl = Impl()
|
||||||
|
|
||||||
async def run():
|
async with makeTestClient(
|
||||||
async with makeTestClient(
|
API.meth, impl, middlewares=[middleware]) as client:
|
||||||
API.meth, impl, middlewares=[middleware]) as client:
|
resp = await client.get("/meth")
|
||||||
resp = await client.get("/meth")
|
self.assertEqual(await resp.json(), '')
|
||||||
self.assertEqual(await resp.json(), '')
|
|
||||||
|
|
||||||
run_coro(run())
|
|
||||||
|
|
||||||
self.assertIs(impl, seen_controller)
|
self.assertIs(impl, seen_controller)
|
||||||
|
|
|
@ -19,7 +19,6 @@ from unittest import mock
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from subiquitycore.pubsub import MessageHub
|
from subiquitycore.pubsub import MessageHub
|
||||||
from subiquitycore.tests.util import run_coro
|
|
||||||
|
|
||||||
from subiquity.common.types import IdentityData
|
from subiquity.common.types import IdentityData
|
||||||
from subiquity.models.subiquity import ModelNames, SubiquityModel
|
from subiquity.models.subiquity import ModelNames, SubiquityModel
|
||||||
|
@ -63,7 +62,7 @@ class TestModelNames(unittest.TestCase):
|
||||||
self.assertEqual(model_names.all(), {'a', 'b', 'c'})
|
self.assertEqual(model_names.all(), {'a', 'b', 'c'})
|
||||||
|
|
||||||
|
|
||||||
class TestSubiquityModel(unittest.TestCase):
|
class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
def writtenFiles(self, config):
|
def writtenFiles(self, config):
|
||||||
for k, v in config.get('write_files', {}).items():
|
for k, v in config.get('write_files', {}).items():
|
||||||
|
@ -114,7 +113,7 @@ class TestSubiquityModel(unittest.TestCase):
|
||||||
cur = cur[component]
|
cur = cur[component]
|
||||||
self.fail("config has value {} for {}".format(cur, path))
|
self.fail("config has value {} for {}".format(cur, path))
|
||||||
|
|
||||||
async def _test_configure(self):
|
async def test_configure(self):
|
||||||
hub = MessageHub()
|
hub = MessageHub()
|
||||||
model = SubiquityModel(
|
model = SubiquityModel(
|
||||||
'test', hub, ModelNames({'a', 'b'}), ModelNames(set()))
|
'test', hub, ModelNames({'a', 'b'}), ModelNames(set()))
|
||||||
|
@ -124,9 +123,6 @@ class TestSubiquityModel(unittest.TestCase):
|
||||||
await hub.abroadcast((InstallerChannels.CONFIGURED, 'b'))
|
await hub.abroadcast((InstallerChannels.CONFIGURED, 'b'))
|
||||||
self.assertTrue(model._install_event.is_set())
|
self.assertTrue(model._install_event.is_set())
|
||||||
|
|
||||||
def test_configure(self):
|
|
||||||
run_coro(self._test_configure())
|
|
||||||
|
|
||||||
def make_model(self):
|
def make_model(self):
|
||||||
return SubiquityModel(
|
return SubiquityModel(
|
||||||
'test', MessageHub(), INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES)
|
'test', MessageHub(), INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES)
|
||||||
|
|
|
@ -13,13 +13,12 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from unittest import mock, TestCase
|
from unittest import mock, TestCase, IsolatedAsyncioTestCase
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from subiquity.server.controllers.filesystem import FilesystemController
|
from subiquity.server.controllers.filesystem import FilesystemController
|
||||||
|
|
||||||
from subiquitycore.tests.util import run_coro
|
|
||||||
from subiquitycore.tests.mocks import make_app
|
from subiquitycore.tests.mocks import make_app
|
||||||
from subiquity.common.types import Bootloader
|
from subiquity.common.types import Bootloader
|
||||||
from subiquity.models.tests.test_filesystem import (
|
from subiquity.models.tests.test_filesystem import (
|
||||||
|
@ -28,7 +27,7 @@ from subiquity.models.tests.test_filesystem import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSubiquityControllerFilesystem(TestCase):
|
class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.app = make_app()
|
self.app = make_app()
|
||||||
self.app.opts.bootloader = 'UEFI'
|
self.app.opts.bootloader = 'UEFI'
|
||||||
|
@ -38,20 +37,20 @@ class TestSubiquityControllerFilesystem(TestCase):
|
||||||
self.fsc = FilesystemController(app=self.app)
|
self.fsc = FilesystemController(app=self.app)
|
||||||
self.fsc._configured = True
|
self.fsc._configured = True
|
||||||
|
|
||||||
def test_probe_restricted(self):
|
async def test_probe_restricted(self):
|
||||||
run_coro(self.fsc._probe_once(context=None, restricted=True))
|
await self.fsc._probe_once(context=None, restricted=True)
|
||||||
self.app.prober.get_storage.assert_called_with({'blockdev'})
|
self.app.prober.get_storage.assert_called_with({'blockdev'})
|
||||||
|
|
||||||
def test_probe_os_prober_false(self):
|
async def test_probe_os_prober_false(self):
|
||||||
self.app.opts.use_os_prober = False
|
self.app.opts.use_os_prober = False
|
||||||
run_coro(self.fsc._probe_once(context=None, restricted=False))
|
await self.fsc._probe_once(context=None, restricted=False)
|
||||||
actual = self.app.prober.get_storage.call_args.args[0]
|
actual = self.app.prober.get_storage.call_args.args[0]
|
||||||
self.assertTrue({'defaults'} <= actual)
|
self.assertTrue({'defaults'} <= actual)
|
||||||
self.assertNotIn('os', actual)
|
self.assertNotIn('os', actual)
|
||||||
|
|
||||||
def test_probe_os_prober_true(self):
|
async def test_probe_os_prober_true(self):
|
||||||
self.app.opts.use_os_prober = True
|
self.app.opts.use_os_prober = True
|
||||||
run_coro(self.fsc._probe_once(context=None, restricted=False))
|
await self.fsc._probe_once(context=None, restricted=False)
|
||||||
actual = self.app.prober.get_storage.call_args.args[0]
|
actual = self.app.prober.get_storage.call_args.args[0]
|
||||||
self.assertTrue({'defaults', 'os'} <= actual)
|
self.assertTrue({'defaults', 'os'} <= actual)
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,6 @@ from unittest.mock import Mock, patch, AsyncMock
|
||||||
|
|
||||||
from subiquitycore.tests import SubiTestCase
|
from subiquitycore.tests import SubiTestCase
|
||||||
from subiquitycore.tests.mocks import make_app
|
from subiquitycore.tests.mocks import make_app
|
||||||
from subiquitycore.tests.util import run_coro
|
|
||||||
from subiquity.server.apt import (
|
from subiquity.server.apt import (
|
||||||
AptConfigurer,
|
AptConfigurer,
|
||||||
Mountpoint,
|
Mountpoint,
|
||||||
|
@ -51,14 +50,14 @@ class TestAptConfigurer(SubiTestCase):
|
||||||
expected['apt']['https_proxy'] = proxy
|
expected['apt']['https_proxy'] = proxy
|
||||||
self.assertEqual(expected, self.configurer.apt_config())
|
self.assertEqual(expected, self.configurer.apt_config())
|
||||||
|
|
||||||
def test_mount_unmount(self):
|
async def test_mount_unmount(self):
|
||||||
# Make sure we can unmount something that we mounted before.
|
# Make sure we can unmount something that we mounted before.
|
||||||
with patch.object(self.app, "command_runner",
|
with patch.object(self.app, "command_runner",
|
||||||
create=True, new_callable=AsyncMock):
|
create=True, new_callable=AsyncMock):
|
||||||
m = run_coro(self.configurer.mount("/dev/cdrom", "/target"))
|
m = await self.configurer.mount("/dev/cdrom", "/target")
|
||||||
run_coro(self.configurer.unmount(m))
|
await self.configurer.unmount(m)
|
||||||
|
|
||||||
def test_overlay(self):
|
async def test_overlay(self):
|
||||||
self.configurer.install_tree = OverlayMountpoint(
|
self.configurer.install_tree = OverlayMountpoint(
|
||||||
upperdir="upperdir-install-tree",
|
upperdir="upperdir-install-tree",
|
||||||
lowers=["lowers1-install-tree"],
|
lowers=["lowers1-install-tree"],
|
||||||
|
@ -71,13 +70,10 @@ class TestAptConfigurer(SubiTestCase):
|
||||||
)
|
)
|
||||||
self.source = "source"
|
self.source = "source"
|
||||||
|
|
||||||
async def coro():
|
|
||||||
async with self.configurer.overlay():
|
|
||||||
pass
|
|
||||||
|
|
||||||
with patch.object(self.app, "command_runner",
|
with patch.object(self.app, "command_runner",
|
||||||
create=True, new_callable=AsyncMock):
|
create=True, new_callable=AsyncMock):
|
||||||
run_coro(coro())
|
async with self.configurer.overlay():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestLowerDirFor(SubiTestCase):
|
class TestLowerDirFor(SubiTestCase):
|
||||||
|
|
|
@ -17,7 +17,6 @@ from aioresponses import aioresponses
|
||||||
|
|
||||||
from subiquitycore.tests import SubiTestCase
|
from subiquitycore.tests import SubiTestCase
|
||||||
from subiquitycore.tests.mocks import make_app
|
from subiquitycore.tests.mocks import make_app
|
||||||
from subiquitycore.tests.util import run_coro
|
|
||||||
from subiquity.server.geoip import (
|
from subiquity.server.geoip import (
|
||||||
GeoIP,
|
GeoIP,
|
||||||
HTTPGeoIPStrategy,
|
HTTPGeoIPStrategy,
|
||||||
|
@ -48,16 +47,13 @@ empty_cc = '<Response><CountryCode></CountryCode></Response>'
|
||||||
|
|
||||||
|
|
||||||
class TestGeoIP(SubiTestCase):
|
class TestGeoIP(SubiTestCase):
|
||||||
def setUp(self):
|
async def asyncSetUp(self):
|
||||||
strategy = HTTPGeoIPStrategy()
|
strategy = HTTPGeoIPStrategy()
|
||||||
self.geoip = GeoIP(make_app(), strategy)
|
self.geoip = GeoIP(make_app(), strategy)
|
||||||
|
|
||||||
async def fn():
|
|
||||||
self.assertTrue(await self.geoip.lookup())
|
|
||||||
|
|
||||||
with aioresponses() as mocked:
|
with aioresponses() as mocked:
|
||||||
mocked.get("https://geoip.ubuntu.com/lookup", body=xml)
|
mocked.get("https://geoip.ubuntu.com/lookup", body=xml)
|
||||||
run_coro(fn())
|
self.assertTrue(await self.geoip.lookup())
|
||||||
|
|
||||||
def test_countrycode(self):
|
def test_countrycode(self):
|
||||||
self.assertEqual("us", self.geoip.countrycode)
|
self.assertEqual("us", self.geoip.countrycode)
|
||||||
|
@ -71,42 +67,32 @@ class TestGeoIPBadData(SubiTestCase):
|
||||||
strategy = HTTPGeoIPStrategy()
|
strategy = HTTPGeoIPStrategy()
|
||||||
self.geoip = GeoIP(make_app(), strategy)
|
self.geoip = GeoIP(make_app(), strategy)
|
||||||
|
|
||||||
def test_partial_reponse(self):
|
async def test_partial_reponse(self):
|
||||||
async def fn():
|
|
||||||
self.assertFalse(await self.geoip.lookup())
|
|
||||||
with aioresponses() as mocked:
|
with aioresponses() as mocked:
|
||||||
mocked.get("https://geoip.ubuntu.com/lookup", body=partial)
|
mocked.get("https://geoip.ubuntu.com/lookup", body=partial)
|
||||||
run_coro(fn())
|
|
||||||
|
|
||||||
def test_incomplete(self):
|
|
||||||
async def fn():
|
|
||||||
self.assertFalse(await self.geoip.lookup())
|
self.assertFalse(await self.geoip.lookup())
|
||||||
|
|
||||||
|
async def test_incomplete(self):
|
||||||
with aioresponses() as mocked:
|
with aioresponses() as mocked:
|
||||||
mocked.get("https://geoip.ubuntu.com/lookup", body=incomplete)
|
mocked.get("https://geoip.ubuntu.com/lookup", body=incomplete)
|
||||||
run_coro(fn())
|
self.assertFalse(await self.geoip.lookup())
|
||||||
self.assertIsNone(self.geoip.countrycode)
|
self.assertIsNone(self.geoip.countrycode)
|
||||||
self.assertIsNone(self.geoip.timezone)
|
self.assertIsNone(self.geoip.timezone)
|
||||||
|
|
||||||
def test_long_cc(self):
|
async def test_long_cc(self):
|
||||||
async def fn():
|
|
||||||
self.assertFalse(await self.geoip.lookup())
|
|
||||||
with aioresponses() as mocked:
|
with aioresponses() as mocked:
|
||||||
mocked.get("https://geoip.ubuntu.com/lookup", body=long_cc)
|
mocked.get("https://geoip.ubuntu.com/lookup", body=long_cc)
|
||||||
run_coro(fn())
|
self.assertFalse(await self.geoip.lookup())
|
||||||
self.assertIsNone(self.geoip.countrycode)
|
self.assertIsNone(self.geoip.countrycode)
|
||||||
|
|
||||||
def test_empty_cc(self):
|
async def test_empty_cc(self):
|
||||||
async def fn():
|
|
||||||
self.assertFalse(await self.geoip.lookup())
|
|
||||||
with aioresponses() as mocked:
|
with aioresponses() as mocked:
|
||||||
mocked.get("https://geoip.ubuntu.com/lookup", body=empty_cc)
|
mocked.get("https://geoip.ubuntu.com/lookup", body=empty_cc)
|
||||||
run_coro(fn())
|
self.assertFalse(await self.geoip.lookup())
|
||||||
self.assertIsNone(self.geoip.countrycode)
|
self.assertIsNone(self.geoip.countrycode)
|
||||||
|
|
||||||
def test_empty_tz(self):
|
async def test_empty_tz(self):
|
||||||
async def fn():
|
|
||||||
self.assertFalse(await self.geoip.lookup())
|
|
||||||
with aioresponses() as mocked:
|
with aioresponses() as mocked:
|
||||||
mocked.get("https://geoip.ubuntu.com/lookup", body=empty_tz)
|
mocked.get("https://geoip.ubuntu.com/lookup", body=empty_tz)
|
||||||
run_coro(fn())
|
self.assertFalse(await self.geoip.lookup())
|
||||||
self.assertIsNone(self.geoip.timezone)
|
self.assertIsNone(self.geoip.timezone)
|
||||||
|
|
|
@ -26,37 +26,36 @@ from subiquity.server.ubuntu_advantage import (
|
||||||
MockedUAInterfaceStrategy,
|
MockedUAInterfaceStrategy,
|
||||||
UAClientUAInterfaceStrategy,
|
UAClientUAInterfaceStrategy,
|
||||||
)
|
)
|
||||||
from subiquitycore.tests.util import run_coro
|
|
||||||
|
|
||||||
|
|
||||||
class TestMockedUAInterfaceStrategy(unittest.TestCase):
|
class TestMockedUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000)
|
self.strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000)
|
||||||
|
|
||||||
def test_query_info_invalid(self):
|
async def test_query_info_invalid(self):
|
||||||
# Tokens starting with "i" in dry-run mode cause the token to be
|
# Tokens starting with "i" in dry-run mode cause the token to be
|
||||||
# reported as invalid.
|
# reported as invalid.
|
||||||
with self.assertRaises(InvalidTokenError):
|
with self.assertRaises(InvalidTokenError):
|
||||||
run_coro(self.strategy.query_info(token="invalidToken"))
|
await self.strategy.query_info(token="invalidToken")
|
||||||
|
|
||||||
def test_query_info_failure(self):
|
async def test_query_info_failure(self):
|
||||||
# Tokens starting with "f" in dry-run mode simulate an "internal"
|
# Tokens starting with "f" in dry-run mode simulate an "internal"
|
||||||
# error.
|
# error.
|
||||||
with self.assertRaises(CheckSubscriptionError):
|
with self.assertRaises(CheckSubscriptionError):
|
||||||
run_coro(self.strategy.query_info(token="failure"))
|
await self.strategy.query_info(token="failure")
|
||||||
|
|
||||||
def test_query_info_expired(self):
|
async def test_query_info_expired(self):
|
||||||
# Tokens starting with "x" is dry-run mode simulate an expired token.
|
# Tokens starting with "x" is dry-run mode simulate an expired token.
|
||||||
info = run_coro(self.strategy.query_info(token="xpiredToken"))
|
info = await self.strategy.query_info(token="xpiredToken")
|
||||||
self.assertEqual(info["expires"], "2010-12-31T00:00:00+00:00")
|
self.assertEqual(info["expires"], "2010-12-31T00:00:00+00:00")
|
||||||
|
|
||||||
def test_query_info_valid(self):
|
async def test_query_info_valid(self):
|
||||||
# Other tokens are considered valid in dry-run mode.
|
# Other tokens are considered valid in dry-run mode.
|
||||||
info = run_coro(self.strategy.query_info(token="validToken"))
|
info = await self.strategy.query_info(token="validToken")
|
||||||
self.assertEqual(info["expires"], "2035-12-31T00:00:00+00:00")
|
self.assertEqual(info["expires"], "2035-12-31T00:00:00+00:00")
|
||||||
|
|
||||||
|
|
||||||
class TestUAClientUAInterfaceStrategy(unittest.TestCase):
|
class TestUAClientUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase):
|
||||||
arun_command_sym = "subiquity.server.ubuntu_advantage.utils.arun_command"
|
arun_command_sym = "subiquity.server.ubuntu_advantage.utils.arun_command"
|
||||||
|
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
|
@ -75,7 +74,7 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
|
||||||
self.assertEqual(strategy.executable,
|
self.assertEqual(strategy.executable,
|
||||||
["python3", "/usr/bin/ubuntu-advantage"])
|
["python3", "/usr/bin/ubuntu-advantage"])
|
||||||
|
|
||||||
def test_query_info_succeeded(self):
|
async def test_query_info_succeeded(self):
|
||||||
strategy = UAClientUAInterfaceStrategy()
|
strategy = UAClientUAInterfaceStrategy()
|
||||||
command = (
|
command = (
|
||||||
"ubuntu-advantage",
|
"ubuntu-advantage",
|
||||||
|
@ -87,10 +86,10 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
|
||||||
with patch(self.arun_command_sym) as mock_arun:
|
with patch(self.arun_command_sym) as mock_arun:
|
||||||
mock_arun.return_value = CompletedProcess([], 0)
|
mock_arun.return_value = CompletedProcess([], 0)
|
||||||
mock_arun.return_value.stdout = "{}"
|
mock_arun.return_value.stdout = "{}"
|
||||||
run_coro(strategy.query_info(token="123456789"))
|
await strategy.query_info(token="123456789")
|
||||||
mock_arun.assert_called_once_with(command, check=True)
|
mock_arun.assert_called_once_with(command, check=True)
|
||||||
|
|
||||||
def test_query_info_failed(self):
|
async def test_query_info_failed(self):
|
||||||
strategy = UAClientUAInterfaceStrategy()
|
strategy = UAClientUAInterfaceStrategy()
|
||||||
command = (
|
command = (
|
||||||
"ubuntu-advantage",
|
"ubuntu-advantage",
|
||||||
|
@ -104,10 +103,10 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
|
||||||
cmd=command)
|
cmd=command)
|
||||||
mock_arun.return_value.stdout = "{}"
|
mock_arun.return_value.stdout = "{}"
|
||||||
with self.assertRaises(CheckSubscriptionError):
|
with self.assertRaises(CheckSubscriptionError):
|
||||||
run_coro(strategy.query_info(token="123456789"))
|
await strategy.query_info(token="123456789")
|
||||||
mock_arun.assert_called_once_with(command, check=True)
|
mock_arun.assert_called_once_with(command, check=True)
|
||||||
|
|
||||||
def test_query_info_invalid_json(self):
|
async def test_query_info_invalid_json(self):
|
||||||
strategy = UAClientUAInterfaceStrategy()
|
strategy = UAClientUAInterfaceStrategy()
|
||||||
command = (
|
command = (
|
||||||
"ubuntu-advantage",
|
"ubuntu-advantage",
|
||||||
|
@ -120,31 +119,31 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
|
||||||
mock_arun.return_value = CompletedProcess([], 0)
|
mock_arun.return_value = CompletedProcess([], 0)
|
||||||
mock_arun.return_value.stdout = "invalid-json"
|
mock_arun.return_value.stdout = "invalid-json"
|
||||||
with self.assertRaises(CheckSubscriptionError):
|
with self.assertRaises(CheckSubscriptionError):
|
||||||
run_coro(strategy.query_info(token="123456789"))
|
await strategy.query_info(token="123456789")
|
||||||
mock_arun.assert_called_once_with(command, check=True)
|
mock_arun.assert_called_once_with(command, check=True)
|
||||||
|
|
||||||
|
|
||||||
class TestUAInterface(unittest.TestCase):
|
class TestUAInterface(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
def test_mocked_get_activable_services(self):
|
async def test_mocked_get_activable_services(self):
|
||||||
strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000)
|
strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000)
|
||||||
interface = UAInterface(strategy)
|
interface = UAInterface(strategy)
|
||||||
|
|
||||||
with self.assertRaises(InvalidTokenError):
|
with self.assertRaises(InvalidTokenError):
|
||||||
run_coro(interface.get_activable_services(token="invalidToken"))
|
await interface.get_activable_services(token="invalidToken")
|
||||||
# Tokens starting with "f" in dry-run mode simulate an "internal"
|
# Tokens starting with "f" in dry-run mode simulate an "internal"
|
||||||
# error.
|
# error.
|
||||||
with self.assertRaises(CheckSubscriptionError):
|
with self.assertRaises(CheckSubscriptionError):
|
||||||
run_coro(interface.get_activable_services(token="failure"))
|
await interface.get_activable_services(token="failure")
|
||||||
|
|
||||||
# Tokens starting with "x" is dry-run mode simulate an expired token.
|
# Tokens starting with "x" is dry-run mode simulate an expired token.
|
||||||
with self.assertRaises(ExpiredTokenError):
|
with self.assertRaises(ExpiredTokenError):
|
||||||
run_coro(interface.get_activable_services(token="xpiredToken"))
|
await interface.get_activable_services(token="xpiredToken")
|
||||||
|
|
||||||
# Other tokens are considered valid in dry-run mode.
|
# Other tokens are considered valid in dry-run mode.
|
||||||
run_coro(interface.get_activable_services(token="validToken"))
|
await interface.get_activable_services(token="validToken")
|
||||||
|
|
||||||
def test_get_activable_services(self):
|
async def test_get_activable_services(self):
|
||||||
# We use the standard strategy but don't actually run it
|
# We use the standard strategy but don't actually run it
|
||||||
strategy = UAClientUAInterfaceStrategy()
|
strategy = UAClientUAInterfaceStrategy()
|
||||||
interface = UAInterface(strategy)
|
interface = UAInterface(strategy)
|
||||||
|
@ -185,8 +184,7 @@ class TestUAInterface(unittest.TestCase):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
interface.get_subscription = AsyncMock(return_value=subscription)
|
interface.get_subscription = AsyncMock(return_value=subscription)
|
||||||
services = run_coro(
|
services = await interface.get_activable_services(token="XXX")
|
||||||
interface.get_activable_services(token="XXX"))
|
|
||||||
|
|
||||||
self.assertIn(UbuntuProService(
|
self.assertIn(UbuntuProService(
|
||||||
name="esm-infra",
|
name="esm-infra",
|
||||||
|
@ -211,5 +209,4 @@ class TestUAInterface(unittest.TestCase):
|
||||||
|
|
||||||
# Test with "Z" suffix for the expiration date.
|
# Test with "Z" suffix for the expiration date.
|
||||||
subscription["expires"] = "2035-12-31T00:00:00Z"
|
subscription["expires"] = "2035-12-31T00:00:00Z"
|
||||||
services = run_coro(
|
services = await interface.get_activable_services(token="XXX")
|
||||||
interface.get_activable_services(token="XXX"))
|
|
||||||
|
|
|
@ -18,7 +18,6 @@ import unittest
|
||||||
from unittest.mock import patch, AsyncMock, Mock
|
from unittest.mock import patch, AsyncMock, Mock
|
||||||
|
|
||||||
from subiquitycore.tests.mocks import make_app
|
from subiquitycore.tests.mocks import make_app
|
||||||
from subiquitycore.tests.util import run_coro
|
|
||||||
|
|
||||||
from subiquity.server.ubuntu_drivers import (
|
from subiquity.server.ubuntu_drivers import (
|
||||||
UbuntuDriversInterface,
|
UbuntuDriversInterface,
|
||||||
|
@ -28,7 +27,7 @@ from subiquity.server.ubuntu_drivers import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestUbuntuDriversInterface(unittest.TestCase):
|
class TestUbuntuDriversInterface(unittest.IsolatedAsyncioTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.app = make_app()
|
self.app = make_app()
|
||||||
|
|
||||||
|
@ -54,11 +53,11 @@ class TestUbuntuDriversInterface(unittest.TestCase):
|
||||||
|
|
||||||
@patch.multiple(UbuntuDriversInterface, __abstractmethods__=set())
|
@patch.multiple(UbuntuDriversInterface, __abstractmethods__=set())
|
||||||
@patch("subiquity.server.ubuntu_drivers.run_curtin_command")
|
@patch("subiquity.server.ubuntu_drivers.run_curtin_command")
|
||||||
def test_install_drivers(self, mock_run_curtin_command):
|
async def test_install_drivers(self, mock_run_curtin_command):
|
||||||
ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=False)
|
ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=False)
|
||||||
run_coro(ubuntu_drivers.install_drivers(
|
await ubuntu_drivers.install_drivers(
|
||||||
root_dir="/target",
|
root_dir="/target",
|
||||||
context="installing third-party drivers"))
|
context="installing third-party drivers")
|
||||||
mock_run_curtin_command.assert_called_once_with(
|
mock_run_curtin_command.assert_called_once_with(
|
||||||
self.app, "installing third-party drivers",
|
self.app, "installing third-party drivers",
|
||||||
"in-target", "-t", "/target",
|
"in-target", "-t", "/target",
|
||||||
|
@ -91,18 +90,18 @@ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
|
||||||
["nvidia-driver-470", "nvidia-driver-510"])
|
["nvidia-driver-470", "nvidia-driver-510"])
|
||||||
|
|
||||||
|
|
||||||
class TestUbuntuDriversClientInterface(unittest.TestCase):
|
class TestUbuntuDriversClientInterface(unittest.IsolatedAsyncioTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.app = make_app()
|
self.app = make_app()
|
||||||
self.ubuntu_drivers = UbuntuDriversClientInterface(
|
self.ubuntu_drivers = UbuntuDriversClientInterface(
|
||||||
self.app, gpgpu=False)
|
self.app, gpgpu=False)
|
||||||
|
|
||||||
def test_ensure_cmd_exists(self):
|
async def test_ensure_cmd_exists(self):
|
||||||
with patch.object(
|
with patch.object(
|
||||||
self.app, "command_runner",
|
self.app, "command_runner",
|
||||||
create=True, new_callable=AsyncMock) as mock_runner:
|
create=True, new_callable=AsyncMock) as mock_runner:
|
||||||
# On success
|
# On success
|
||||||
run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target"))
|
await self.ubuntu_drivers.ensure_cmd_exists("/target")
|
||||||
mock_runner.run.assert_called_once_with(
|
mock_runner.run.assert_called_once_with(
|
||||||
[
|
[
|
||||||
"chroot", "/target",
|
"chroot", "/target",
|
||||||
|
@ -115,17 +114,17 @@ class TestUbuntuDriversClientInterface(unittest.TestCase):
|
||||||
cmd=["sh", "-c", "command -v ubuntu-drivers"])
|
cmd=["sh", "-c", "command -v ubuntu-drivers"])
|
||||||
|
|
||||||
with self.assertRaises(CommandNotFoundError):
|
with self.assertRaises(CommandNotFoundError):
|
||||||
run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target"))
|
await self.ubuntu_drivers.ensure_cmd_exists("/target")
|
||||||
|
|
||||||
@patch("subiquity.server.ubuntu_drivers.run_curtin_command")
|
@patch("subiquity.server.ubuntu_drivers.run_curtin_command")
|
||||||
def test_list_drivers(self, mock_run_curtin_command):
|
async def test_list_drivers(self, mock_run_curtin_command):
|
||||||
# Make sure this gets decoded as utf-8.
|
# Make sure this gets decoded as utf-8.
|
||||||
mock_run_curtin_command.return_value = Mock(stdout=b"""\
|
mock_run_curtin_command.return_value = Mock(stdout=b"""\
|
||||||
nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
|
nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
|
||||||
""")
|
""")
|
||||||
drivers = run_coro(self.ubuntu_drivers.list_drivers(
|
drivers = await self.ubuntu_drivers.list_drivers(
|
||||||
root_dir="/target",
|
root_dir="/target",
|
||||||
context="listing third-party drivers"))
|
context="listing third-party drivers")
|
||||||
|
|
||||||
mock_run_curtin_command.assert_called_once_with(
|
mock_run_curtin_command.assert_called_once_with(
|
||||||
self.app, "listing third-party drivers",
|
self.app, "listing third-party drivers",
|
||||||
|
@ -137,15 +136,15 @@ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
|
||||||
self.assertEqual(drivers, ["nvidia-driver-510"])
|
self.assertEqual(drivers, ["nvidia-driver-510"])
|
||||||
|
|
||||||
|
|
||||||
class TestUbuntuDriversRunDriversInterface(unittest.TestCase):
|
class TestUbuntuDriversRunDriversInterface(unittest.IsolatedAsyncioTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.app = make_app()
|
self.app = make_app()
|
||||||
self.ubuntu_drivers = UbuntuDriversRunDriversInterface(
|
self.ubuntu_drivers = UbuntuDriversRunDriversInterface(
|
||||||
self.app, gpgpu=False)
|
self.app, gpgpu=False)
|
||||||
|
|
||||||
@patch("subiquity.server.ubuntu_drivers.arun_command")
|
@patch("subiquity.server.ubuntu_drivers.arun_command")
|
||||||
def test_ensure_cmd_exists(self, mock_arun_command):
|
async def test_ensure_cmd_exists(self, mock_arun_command):
|
||||||
run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target"))
|
await self.ubuntu_drivers.ensure_cmd_exists("/target")
|
||||||
mock_arun_command.assert_called_once_with(
|
mock_arun_command.assert_called_once_with(
|
||||||
["sh", "-c", "command -v ubuntu-drivers"],
|
["sh", "-c", "command -v ubuntu-drivers"],
|
||||||
check=True)
|
check=True)
|
||||||
|
|
|
@ -7,7 +7,6 @@ from functools import wraps
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
|
@ -134,7 +133,7 @@ class Server(Client):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestAPI(unittest.IsolatedAsyncioTestCase, SubiTestCase):
|
class TestAPI(SubiTestCase):
|
||||||
class _MachineConfig(os.PathLike):
|
class _MachineConfig(os.PathLike):
|
||||||
def __init__(self, outer, path):
|
def __init__(self, outer, path):
|
||||||
self.outer = outer
|
self.outer = outer
|
||||||
|
|
|
@ -20,7 +20,6 @@ from subiquity.models.timezone import TimeZoneModel
|
||||||
from subiquity.server.controllers.timezone import TimeZoneController
|
from subiquity.server.controllers.timezone import TimeZoneController
|
||||||
from subiquitycore.tests import SubiTestCase
|
from subiquitycore.tests import SubiTestCase
|
||||||
from subiquitycore.tests.mocks import make_app
|
from subiquitycore.tests.mocks import make_app
|
||||||
from subiquitycore.tests.util import run_coro
|
|
||||||
|
|
||||||
|
|
||||||
class MockGeoIP:
|
class MockGeoIP:
|
||||||
|
@ -47,7 +46,7 @@ class TestTimeZoneController(SubiTestCase):
|
||||||
|
|
||||||
@mock.patch('subiquity.server.controllers.timezone.timedatectl_settz')
|
@mock.patch('subiquity.server.controllers.timezone.timedatectl_settz')
|
||||||
@mock.patch('subiquity.server.controllers.timezone.timedatectl_gettz')
|
@mock.patch('subiquity.server.controllers.timezone.timedatectl_gettz')
|
||||||
def test_good_tzs(self, tdc_gettz, tdc_settz):
|
async def test_good_tzs(self, tdc_gettz, tdc_settz):
|
||||||
tdc_gettz.return_value = tz_utc
|
tdc_gettz.return_value = tz_utc
|
||||||
goods = [
|
goods = [
|
||||||
# val - autoinstall value
|
# val - autoinstall value
|
||||||
|
@ -77,7 +76,7 @@ class TestTimeZoneController(SubiTestCase):
|
||||||
self.tzc.model)
|
self.tzc.model)
|
||||||
self.assertEqual(geoip, self.tzc.model.detect_with_geoip,
|
self.assertEqual(geoip, self.tzc.model.detect_with_geoip,
|
||||||
self.tzc.model)
|
self.tzc.model)
|
||||||
self.assertEqual(tz, run_coro(self.tzc.GET()), self.tzc.model)
|
self.assertEqual(tz, await self.tzc.GET(), self.tzc.model)
|
||||||
cloudconfig = {}
|
cloudconfig = {}
|
||||||
if self.tzc.model.should_set_tz:
|
if self.tzc.model.should_set_tz:
|
||||||
cloudconfig = {'timezone': tz.timezone}
|
cloudconfig = {'timezone': tz.timezone}
|
||||||
|
@ -104,7 +103,7 @@ class TestTimeZoneController(SubiTestCase):
|
||||||
self.assertEqual('sleep', subprocess_run.call_args.args[0][0])
|
self.assertEqual('sleep', subprocess_run.call_args.args[0][0])
|
||||||
|
|
||||||
@mock.patch('subiquity.server.controllers.timezone.timedatectl_settz')
|
@mock.patch('subiquity.server.controllers.timezone.timedatectl_settz')
|
||||||
def test_get_tz_should_not_set(self, tdc_settz):
|
async def test_get_tz_should_not_set(self, tdc_settz):
|
||||||
run_coro(self.tzc.GET())
|
await self.tzc.GET()
|
||||||
self.assertFalse(self.tzc.model.should_set_tz)
|
self.assertFalse(self.tzc.model.should_set_tz)
|
||||||
tdc_settz.assert_not_called()
|
tdc_settz.assert_not_called()
|
||||||
|
|
|
@ -1,4 +1,18 @@
|
||||||
import asyncio
|
# Copyright 2017-2022 Canonical, Ltd.
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
@ -42,20 +56,6 @@ system_reserved = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def tearDownModule() -> None:
|
|
||||||
# Set empty loop policy, so that subsequent get_event_loop() returns a new
|
|
||||||
# loop. If there is no running event loop set, that function will return
|
|
||||||
# the result of `get_event_loop_policy().get_event_loop()` call. If there
|
|
||||||
# is a policy there must be a running loop. IsolatedAsyncioTestCase
|
|
||||||
# closes the loop during tear down, though. It doesn't touch the policy.
|
|
||||||
# By having it as None, it autoinits and the next tests run smoothly.
|
|
||||||
# Another approach would be set a new event_loop for the current policy on
|
|
||||||
# test fixture tearDown, as pytest-asyncio does.
|
|
||||||
# Either way we would prevent failure on tests that depend on [run_coro]
|
|
||||||
# (subiquitycore/tests/util.py), for instance.
|
|
||||||
asyncio.set_event_loop_policy(None)
|
|
||||||
|
|
||||||
|
|
||||||
class IdentityViewTests(unittest.IsolatedAsyncioTestCase):
|
class IdentityViewTests(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
def make_view(self):
|
def make_view(self):
|
||||||
|
|
|
@ -3,10 +3,10 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
from unittest import TestCase
|
from unittest import IsolatedAsyncioTestCase
|
||||||
|
|
||||||
|
|
||||||
class SubiTestCase(TestCase):
|
class SubiTestCase(IsolatedAsyncioTestCase):
|
||||||
def tmp_dir(self, dir=None, cleanup=True):
|
def tmp_dir(self, dir=None, cleanup=True):
|
||||||
# return a full path to a temporary directory that will be cleaned up.
|
# return a full path to a temporary directory that will be cleaned up.
|
||||||
if dir is None:
|
if dir is None:
|
||||||
|
|
|
@ -13,14 +13,9 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
|
|
||||||
|
|
||||||
def run_coro(coro):
|
|
||||||
return asyncio.get_event_loop().run_until_complete(coro)
|
|
||||||
|
|
||||||
|
|
||||||
def random_string():
|
def random_string():
|
||||||
return ''.join(random.choice(string.ascii_letters) for _ in range(8))
|
return ''.join(random.choice(string.ascii_letters) for _ in range(8))
|
||||||
|
|
Loading…
Reference in New Issue