Merge pull request #1424 from mwhudson/api-enhancements

api extensions needed to talk to snapd API
This commit is contained in:
Michael Hudson-Doyle 2022-09-21 14:24:25 +12:00 committed by GitHub
commit c063717897
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 266 additions and 23 deletions

View File

@ -22,7 +22,7 @@ from subiquity.common.serialize import Serializer
from .defs import Payload
def _wrap(make_request, path, meth, serializer):
def _wrap(make_request, path, meth, serializer, serialize_query_args):
sig = inspect.signature(meth)
meth_params = sig.parameters
payload_arg = None
@ -32,7 +32,7 @@ def _wrap(make_request, path, meth, serializer):
payload_ann = param.annotation.__args__[0]
r_ann = sig.return_annotation
async def impl(*args, **kw):
async def impl(self, *args, **kw):
args = sig.bind(*args, **kw)
query_args = {}
data = None
@ -40,29 +40,56 @@ def _wrap(make_request, path, meth, serializer):
if arg_name == payload_arg:
data = serializer.serialize(payload_ann, value)
else:
query_args[arg_name] = serializer.to_json(
if serialize_query_args:
value = serializer.to_json(
meth_params[arg_name].annotation, value)
query_args[arg_name] = value
async with make_request(
meth.__name__, path, json=data, params=query_args) as resp:
meth.__name__, path.format(**self.path_args),
json=data, params=query_args) as resp:
resp.raise_for_status()
return serializer.deserialize(r_ann, await resp.json())
return impl
def make_client(endpoint_cls, make_request, serializer=None):
def make_getitem(endpoint_cls, make_request, serializer):
cls = make_client_cls(endpoint_cls, make_request, serializer)
def gi(self, item):
new_args = self.path_args.copy()
new_args[endpoint_cls.__shortname__] = item
return cls(new_args)
return gi
def client_init(self, path_args=None):
if path_args is None:
path_args = {}
self.path_args = path_args
def make_client_cls(endpoint_cls, make_request, serializer=None):
if serializer is None:
serializer = Serializer()
class C:
pass
ns = {'__init__': client_init}
for k, v in endpoint_cls.__dict__.items():
if isinstance(v, type):
setattr(C, k, make_client(v, make_request, serializer))
if getattr(v, '__parameter__', False):
ns['__getitem__'] = make_getitem(v, make_request, serializer)
else:
ns[k] = make_client(v, make_request, serializer)
elif callable(v):
setattr(C, k, _wrap(
make_request, endpoint_cls.fullpath, v, serializer))
return C
ns[k] = _wrap(make_request, endpoint_cls.fullpath, v, serializer,
endpoint_cls.serialize_query_args)
return type('ClientFor({})'.format(endpoint_cls.__name__), (object,), ns)
def make_client(endpoint_cls, make_request, serializer=None, path_args=None):
return make_client_cls(endpoint_cls, make_request, serializer)(path_args)
def make_client_for_conn(

View File

@ -13,18 +13,70 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import inspect
import typing
def api(cls, prefix=(), foo=None):
cls.fullpath = '/' + '/'.join(prefix)
cls.fullname = prefix
class InvalidAPIDefinition(Exception):
pass
class InvalidQueryArgs(InvalidAPIDefinition):
def __init__(self, callable, param):
self.callable = callable
self.param = param
def __str__(self):
return (f"{self.callable.__qualname__} does not serialize query "
f"arguments but has non-str parameter '{self.param}'")
class MultiplePathParameters(InvalidAPIDefinition):
def __init__(self, cls, param1, param2):
self.cls = cls
self.param1 = param1
self.param2 = param2
def __str__(self):
return (f"{self.cls.__name__} has multiple path parameters "
f"{self.param1!r} and {self.param2!r}")
def api(cls, prefix_names=(), prefix_path=(), path_params=(),
serialize_query_args=True):
if hasattr(cls, 'serialize_query_args'):
serialize_query_args = cls.serialize_query_args
else:
cls.serialize_query_args = serialize_query_args
cls.fullpath = '/' + '/'.join(prefix_path)
cls.fullname = prefix_names
seen_path_param = None
for k, v in cls.__dict__.items():
if isinstance(v, type):
v.__shortname__ = k
v.__name__ = cls.__name__ + '.' + k
api(v, prefix + (k,))
path_part = k
path_param = ()
if getattr(v, '__parameter__', False):
if seen_path_param:
raise MultiplePathParameters(cls, seen_path_param, k)
seen_path_param = k
path_part = '{' + path_part + '}'
path_param = (k,)
api(
v,
prefix_names + (k,),
prefix_path + (path_part,),
path_params + path_param,
serialize_query_args)
if callable(v):
v.__qualname__ = cls.__name__ + '.' + k
if not cls.serialize_query_args:
params = inspect.signature(v).parameters
for param_name, param in params.items():
if param.annotation is not str:
raise InvalidQueryArgs(v, param)
v.__path_params__ = path_params
return cls
@ -35,6 +87,11 @@ class Payload(typing.Generic[T]):
pass
def path_parameter(cls):
cls.__parameter__ = True
return cls
def simple_endpoint(typ):
class endpoint:
def GET() -> typ: ...

View File

@ -57,7 +57,8 @@ def trim(text):
return text
def _make_handler(controller, definition, implementation, serializer):
def _make_handler(controller, definition, implementation, serializer,
serialize_query_args):
def_sig = inspect.signature(definition)
def_ret_ann = def_sig.return_annotation
def_params = def_sig.parameters
@ -71,6 +72,13 @@ def _make_handler(controller, definition, implementation, serializer):
check_def_params = []
for param_name in definition.__path_params__:
check_def_params.append(
inspect.Parameter(
param_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=str))
for param_name, param in def_params.items():
if param_name in ('request', 'context'):
raise Exception(
@ -108,13 +116,17 @@ def _make_handler(controller, definition, implementation, serializer):
data_annotation, await request.text())
for arg, ann, default in query_args_anns:
if arg in request.query:
v = serializer.from_json(ann, request.query[arg])
v = request.query[arg]
if serialize_query_args:
v = serializer.from_json(ann, v)
elif default != inspect._empty:
v = default
else:
raise TypeError(
'missing required argument "{}"'.format(arg))
args[arg] = v
for param_name in definition.__path_params__:
args[param_name] = request.match_info[param_name]
if 'context' in impl_params:
args['context'] = context
if 'request' in impl_params:
@ -164,7 +176,9 @@ def bind(router, endpoint, controller, serializer=None, _depth=None):
router.add_route(
method=method,
path=endpoint.fullpath,
handler=_make_handler(controller, v, impl, serializer))
handler=_make_handler(
controller, v, impl, serializer,
endpoint.serialize_query_args))
async def make_server_at_path(socket_path, endpoint, controller, **kw):

View File

@ -17,7 +17,12 @@ import contextlib
import unittest
from subiquity.common.api.client import make_client
from subiquity.common.api.defs import api, Payload
from subiquity.common.api.defs import (
api,
InvalidQueryArgs,
path_parameter,
Payload,
)
def extract(c):
@ -89,3 +94,59 @@ class TestClient(unittest.TestCase):
r = extract(client.GET(arg='v'))
self.assertEqual(r, '"v"')
self.assertEqual(requests, [("GET", '/', {'arg': '"v"'}, None)])
def test_path_params(self):
@api
class API:
@path_parameter
class param:
def GET(arg: str) -> str: ...
@contextlib.asynccontextmanager
async def make_request(method, path, *, params, json):
requests.append((method, path, params, json))
yield FakeResponse(params['arg'])
client = make_client(API, make_request)
requests = []
r = extract(client['foo'].GET(arg='v'))
self.assertEqual(r, '"v"')
self.assertEqual(requests, [("GET", '/foo', {'arg': '"v"'}, None)])
def test_serialize_query_args(self):
@api
class API:
serialize_query_args = False
def GET(arg: str) -> str: ...
class meth:
def GET(arg: str) -> str: ...
class more:
serialize_query_args = True
def GET(arg: str) -> str: ...
@contextlib.asynccontextmanager
async def make_request(method, path, *, params, json):
requests.append((method, path, params, json))
yield FakeResponse(params['arg'])
client = make_client(API, make_request)
requests = []
extract(client.GET(arg='v'))
extract(client.meth.GET(arg='v'))
extract(client.meth.more.GET(arg='v'))
self.assertEqual(requests, [
("GET", '/', {'arg': 'v'}, None),
("GET", '/meth', {'arg': 'v'}, None),
("GET", '/meth/more', {'arg': '"v"'}, None)])
class API2:
serialize_query_args = False
def GET(arg: int) -> str: ...
with self.assertRaises(InvalidQueryArgs):
api(API2)

View File

@ -22,7 +22,12 @@ import aiohttp
from aiohttp import web
from subiquity.common.api.client import make_client
from subiquity.common.api.defs import api, Payload
from subiquity.common.api.defs import (
api,
MultiplePathParameters,
path_parameter,
Payload,
)
from .test_server import (
makeTestClient,
@ -85,13 +90,44 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
self.assertEqual(
await client.GET(arg1="A", arg2="B"), 'A+B')
async def test_path_args(self):
@api
class API:
@path_parameter
class param1:
@path_parameter
class param2:
def GET(arg1: str, arg2: str) -> str: ...
class Impl(ControllerBase):
async def param1_param2_GET(self, param1: str, param2: str,
arg1: str, arg2: str) -> str:
return '{}+{}+{}+{}'.format(param1, param2, arg1, arg2)
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(
await client["1"]["2"].GET(arg1="A", arg2="B"), '1+2+A+B')
def test_only_one_path_parameter(self):
class API:
@path_parameter
class param1:
def GET(arg: str) -> str: ...
@path_parameter
class param2:
def GET(arg: str) -> str: ...
with self.assertRaises(MultiplePathParameters):
api(API)
async def test_defaults(self):
@api
class API:
def GET(arg1: str, arg2: str = "arg2") -> str: ...
def GET(arg1: str = "arg1", arg2: str = "arg2") -> str: ...
class Impl(ControllerBase):
async def GET(self, arg1: str, arg2: str = "arg2") -> str:
async def GET(self, arg1: str = "arg1", arg2: str = "arg2") -> str:
return '{}+{}'.format(arg1, arg2)
async with makeE2EClient(API, Impl()) as client:
@ -99,6 +135,8 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
await client.GET(arg1="A", arg2="B"), 'A+B')
self.assertEqual(
await client.GET(arg1="A"), 'A+arg2')
self.assertEqual(
await client.GET(arg2="B"), 'arg1+B')
async def test_post(self):
@api

View File

@ -21,7 +21,7 @@ from aiohttp import web
from subiquitycore.context import Context
from subiquity.common.api.defs import api, Payload
from subiquity.common.api.defs import api, path_parameter, Payload
from subiquity.common.api.server import (
bind,
controller_for_request,
@ -233,3 +233,49 @@ class TestBind(unittest.IsolatedAsyncioTestCase):
self.assertEqual(await resp.json(), '')
self.assertIs(impl, seen_controller)
async def test_serialize_query_args(self):
@api
class API:
serialize_query_args = False
def GET(arg: str) -> str: ...
class meth:
def GET(arg: str) -> str: ...
class more:
serialize_query_args = True
def GET(arg: str) -> str: ...
class Impl(ControllerBase):
async def GET(self, arg: str) -> str:
return arg
async def meth_GET(self, arg: str) -> str:
return arg
async def meth_more_GET(self, arg: str) -> str:
return arg
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.get('/?arg=whut'), 'whut')
await self.assertResponse(
client.get('/meth?arg=whut'), 'whut')
await self.assertResponse(
client.get('/meth/more?arg="whut"'), 'whut')
async def test_path_parameters(self):
@api
class API:
@path_parameter
class param:
def GET(arg: int): ...
class Impl(ControllerBase):
async def param_GET(self, param: str, arg: int):
return param + str(arg)
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.get('/value?arg=2'), 'value2')