Merge pull request #1424 from mwhudson/api-enhancements
api extensions needed to talk to snapd API
This commit is contained in:
commit
c063717897
|
@ -22,7 +22,7 @@ from subiquity.common.serialize import Serializer
|
|||
from .defs import Payload
|
||||
|
||||
|
||||
def _wrap(make_request, path, meth, serializer):
|
||||
def _wrap(make_request, path, meth, serializer, serialize_query_args):
|
||||
sig = inspect.signature(meth)
|
||||
meth_params = sig.parameters
|
||||
payload_arg = None
|
||||
|
@ -32,7 +32,7 @@ def _wrap(make_request, path, meth, serializer):
|
|||
payload_ann = param.annotation.__args__[0]
|
||||
r_ann = sig.return_annotation
|
||||
|
||||
async def impl(*args, **kw):
|
||||
async def impl(self, *args, **kw):
|
||||
args = sig.bind(*args, **kw)
|
||||
query_args = {}
|
||||
data = None
|
||||
|
@ -40,29 +40,56 @@ def _wrap(make_request, path, meth, serializer):
|
|||
if arg_name == payload_arg:
|
||||
data = serializer.serialize(payload_ann, value)
|
||||
else:
|
||||
query_args[arg_name] = serializer.to_json(
|
||||
if serialize_query_args:
|
||||
value = serializer.to_json(
|
||||
meth_params[arg_name].annotation, value)
|
||||
query_args[arg_name] = value
|
||||
async with make_request(
|
||||
meth.__name__, path, json=data, params=query_args) as resp:
|
||||
meth.__name__, path.format(**self.path_args),
|
||||
json=data, params=query_args) as resp:
|
||||
resp.raise_for_status()
|
||||
return serializer.deserialize(r_ann, await resp.json())
|
||||
return impl
|
||||
|
||||
|
||||
def make_client(endpoint_cls, make_request, serializer=None):
|
||||
def make_getitem(endpoint_cls, make_request, serializer):
|
||||
cls = make_client_cls(endpoint_cls, make_request, serializer)
|
||||
|
||||
def gi(self, item):
|
||||
new_args = self.path_args.copy()
|
||||
new_args[endpoint_cls.__shortname__] = item
|
||||
return cls(new_args)
|
||||
|
||||
return gi
|
||||
|
||||
|
||||
def client_init(self, path_args=None):
|
||||
if path_args is None:
|
||||
path_args = {}
|
||||
self.path_args = path_args
|
||||
|
||||
|
||||
def make_client_cls(endpoint_cls, make_request, serializer=None):
|
||||
if serializer is None:
|
||||
serializer = Serializer()
|
||||
|
||||
class C:
|
||||
pass
|
||||
ns = {'__init__': client_init}
|
||||
|
||||
for k, v in endpoint_cls.__dict__.items():
|
||||
if isinstance(v, type):
|
||||
setattr(C, k, make_client(v, make_request, serializer))
|
||||
if getattr(v, '__parameter__', False):
|
||||
ns['__getitem__'] = make_getitem(v, make_request, serializer)
|
||||
else:
|
||||
ns[k] = make_client(v, make_request, serializer)
|
||||
elif callable(v):
|
||||
setattr(C, k, _wrap(
|
||||
make_request, endpoint_cls.fullpath, v, serializer))
|
||||
return C
|
||||
ns[k] = _wrap(make_request, endpoint_cls.fullpath, v, serializer,
|
||||
endpoint_cls.serialize_query_args)
|
||||
|
||||
return type('ClientFor({})'.format(endpoint_cls.__name__), (object,), ns)
|
||||
|
||||
|
||||
def make_client(endpoint_cls, make_request, serializer=None, path_args=None):
|
||||
return make_client_cls(endpoint_cls, make_request, serializer)(path_args)
|
||||
|
||||
|
||||
def make_client_for_conn(
|
||||
|
|
|
@ -13,18 +13,70 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import inspect
|
||||
import typing
|
||||
|
||||
|
||||
def api(cls, prefix=(), foo=None):
|
||||
cls.fullpath = '/' + '/'.join(prefix)
|
||||
cls.fullname = prefix
|
||||
class InvalidAPIDefinition(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidQueryArgs(InvalidAPIDefinition):
|
||||
def __init__(self, callable, param):
|
||||
self.callable = callable
|
||||
self.param = param
|
||||
|
||||
def __str__(self):
|
||||
return (f"{self.callable.__qualname__} does not serialize query "
|
||||
f"arguments but has non-str parameter '{self.param}'")
|
||||
|
||||
|
||||
class MultiplePathParameters(InvalidAPIDefinition):
|
||||
def __init__(self, cls, param1, param2):
|
||||
self.cls = cls
|
||||
self.param1 = param1
|
||||
self.param2 = param2
|
||||
|
||||
def __str__(self):
|
||||
return (f"{self.cls.__name__} has multiple path parameters "
|
||||
f"{self.param1!r} and {self.param2!r}")
|
||||
|
||||
|
||||
def api(cls, prefix_names=(), prefix_path=(), path_params=(),
|
||||
serialize_query_args=True):
|
||||
if hasattr(cls, 'serialize_query_args'):
|
||||
serialize_query_args = cls.serialize_query_args
|
||||
else:
|
||||
cls.serialize_query_args = serialize_query_args
|
||||
cls.fullpath = '/' + '/'.join(prefix_path)
|
||||
cls.fullname = prefix_names
|
||||
seen_path_param = None
|
||||
for k, v in cls.__dict__.items():
|
||||
if isinstance(v, type):
|
||||
v.__shortname__ = k
|
||||
v.__name__ = cls.__name__ + '.' + k
|
||||
api(v, prefix + (k,))
|
||||
path_part = k
|
||||
path_param = ()
|
||||
if getattr(v, '__parameter__', False):
|
||||
if seen_path_param:
|
||||
raise MultiplePathParameters(cls, seen_path_param, k)
|
||||
seen_path_param = k
|
||||
path_part = '{' + path_part + '}'
|
||||
path_param = (k,)
|
||||
api(
|
||||
v,
|
||||
prefix_names + (k,),
|
||||
prefix_path + (path_part,),
|
||||
path_params + path_param,
|
||||
serialize_query_args)
|
||||
if callable(v):
|
||||
v.__qualname__ = cls.__name__ + '.' + k
|
||||
if not cls.serialize_query_args:
|
||||
params = inspect.signature(v).parameters
|
||||
for param_name, param in params.items():
|
||||
if param.annotation is not str:
|
||||
raise InvalidQueryArgs(v, param)
|
||||
v.__path_params__ = path_params
|
||||
return cls
|
||||
|
||||
|
||||
|
@ -35,6 +87,11 @@ class Payload(typing.Generic[T]):
|
|||
pass
|
||||
|
||||
|
||||
def path_parameter(cls):
|
||||
cls.__parameter__ = True
|
||||
return cls
|
||||
|
||||
|
||||
def simple_endpoint(typ):
|
||||
class endpoint:
|
||||
def GET() -> typ: ...
|
||||
|
|
|
@ -57,7 +57,8 @@ def trim(text):
|
|||
return text
|
||||
|
||||
|
||||
def _make_handler(controller, definition, implementation, serializer):
|
||||
def _make_handler(controller, definition, implementation, serializer,
|
||||
serialize_query_args):
|
||||
def_sig = inspect.signature(definition)
|
||||
def_ret_ann = def_sig.return_annotation
|
||||
def_params = def_sig.parameters
|
||||
|
@ -71,6 +72,13 @@ def _make_handler(controller, definition, implementation, serializer):
|
|||
|
||||
check_def_params = []
|
||||
|
||||
for param_name in definition.__path_params__:
|
||||
check_def_params.append(
|
||||
inspect.Parameter(
|
||||
param_name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=str))
|
||||
|
||||
for param_name, param in def_params.items():
|
||||
if param_name in ('request', 'context'):
|
||||
raise Exception(
|
||||
|
@ -108,13 +116,17 @@ def _make_handler(controller, definition, implementation, serializer):
|
|||
data_annotation, await request.text())
|
||||
for arg, ann, default in query_args_anns:
|
||||
if arg in request.query:
|
||||
v = serializer.from_json(ann, request.query[arg])
|
||||
v = request.query[arg]
|
||||
if serialize_query_args:
|
||||
v = serializer.from_json(ann, v)
|
||||
elif default != inspect._empty:
|
||||
v = default
|
||||
else:
|
||||
raise TypeError(
|
||||
'missing required argument "{}"'.format(arg))
|
||||
args[arg] = v
|
||||
for param_name in definition.__path_params__:
|
||||
args[param_name] = request.match_info[param_name]
|
||||
if 'context' in impl_params:
|
||||
args['context'] = context
|
||||
if 'request' in impl_params:
|
||||
|
@ -164,7 +176,9 @@ def bind(router, endpoint, controller, serializer=None, _depth=None):
|
|||
router.add_route(
|
||||
method=method,
|
||||
path=endpoint.fullpath,
|
||||
handler=_make_handler(controller, v, impl, serializer))
|
||||
handler=_make_handler(
|
||||
controller, v, impl, serializer,
|
||||
endpoint.serialize_query_args))
|
||||
|
||||
|
||||
async def make_server_at_path(socket_path, endpoint, controller, **kw):
|
||||
|
|
|
@ -17,7 +17,12 @@ import contextlib
|
|||
import unittest
|
||||
|
||||
from subiquity.common.api.client import make_client
|
||||
from subiquity.common.api.defs import api, Payload
|
||||
from subiquity.common.api.defs import (
|
||||
api,
|
||||
InvalidQueryArgs,
|
||||
path_parameter,
|
||||
Payload,
|
||||
)
|
||||
|
||||
|
||||
def extract(c):
|
||||
|
@ -89,3 +94,59 @@ class TestClient(unittest.TestCase):
|
|||
r = extract(client.GET(arg='v'))
|
||||
self.assertEqual(r, '"v"')
|
||||
self.assertEqual(requests, [("GET", '/', {'arg': '"v"'}, None)])
|
||||
|
||||
def test_path_params(self):
|
||||
|
||||
@api
|
||||
class API:
|
||||
@path_parameter
|
||||
class param:
|
||||
def GET(arg: str) -> str: ...
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def make_request(method, path, *, params, json):
|
||||
requests.append((method, path, params, json))
|
||||
yield FakeResponse(params['arg'])
|
||||
|
||||
client = make_client(API, make_request)
|
||||
|
||||
requests = []
|
||||
r = extract(client['foo'].GET(arg='v'))
|
||||
self.assertEqual(r, '"v"')
|
||||
self.assertEqual(requests, [("GET", '/foo', {'arg': '"v"'}, None)])
|
||||
|
||||
def test_serialize_query_args(self):
|
||||
@api
|
||||
class API:
|
||||
serialize_query_args = False
|
||||
def GET(arg: str) -> str: ...
|
||||
|
||||
class meth:
|
||||
def GET(arg: str) -> str: ...
|
||||
|
||||
class more:
|
||||
serialize_query_args = True
|
||||
def GET(arg: str) -> str: ...
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def make_request(method, path, *, params, json):
|
||||
requests.append((method, path, params, json))
|
||||
yield FakeResponse(params['arg'])
|
||||
|
||||
client = make_client(API, make_request)
|
||||
|
||||
requests = []
|
||||
extract(client.GET(arg='v'))
|
||||
extract(client.meth.GET(arg='v'))
|
||||
extract(client.meth.more.GET(arg='v'))
|
||||
self.assertEqual(requests, [
|
||||
("GET", '/', {'arg': 'v'}, None),
|
||||
("GET", '/meth', {'arg': 'v'}, None),
|
||||
("GET", '/meth/more', {'arg': '"v"'}, None)])
|
||||
|
||||
class API2:
|
||||
serialize_query_args = False
|
||||
def GET(arg: int) -> str: ...
|
||||
|
||||
with self.assertRaises(InvalidQueryArgs):
|
||||
api(API2)
|
||||
|
|
|
@ -22,7 +22,12 @@ import aiohttp
|
|||
from aiohttp import web
|
||||
|
||||
from subiquity.common.api.client import make_client
|
||||
from subiquity.common.api.defs import api, Payload
|
||||
from subiquity.common.api.defs import (
|
||||
api,
|
||||
MultiplePathParameters,
|
||||
path_parameter,
|
||||
Payload,
|
||||
)
|
||||
|
||||
from .test_server import (
|
||||
makeTestClient,
|
||||
|
@ -85,13 +90,44 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
|
|||
self.assertEqual(
|
||||
await client.GET(arg1="A", arg2="B"), 'A+B')
|
||||
|
||||
async def test_path_args(self):
|
||||
@api
|
||||
class API:
|
||||
@path_parameter
|
||||
class param1:
|
||||
@path_parameter
|
||||
class param2:
|
||||
def GET(arg1: str, arg2: str) -> str: ...
|
||||
|
||||
class Impl(ControllerBase):
|
||||
async def param1_param2_GET(self, param1: str, param2: str,
|
||||
arg1: str, arg2: str) -> str:
|
||||
return '{}+{}+{}+{}'.format(param1, param2, arg1, arg2)
|
||||
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(
|
||||
await client["1"]["2"].GET(arg1="A", arg2="B"), '1+2+A+B')
|
||||
|
||||
def test_only_one_path_parameter(self):
|
||||
class API:
|
||||
@path_parameter
|
||||
class param1:
|
||||
def GET(arg: str) -> str: ...
|
||||
|
||||
@path_parameter
|
||||
class param2:
|
||||
def GET(arg: str) -> str: ...
|
||||
|
||||
with self.assertRaises(MultiplePathParameters):
|
||||
api(API)
|
||||
|
||||
async def test_defaults(self):
|
||||
@api
|
||||
class API:
|
||||
def GET(arg1: str, arg2: str = "arg2") -> str: ...
|
||||
def GET(arg1: str = "arg1", arg2: str = "arg2") -> str: ...
|
||||
|
||||
class Impl(ControllerBase):
|
||||
async def GET(self, arg1: str, arg2: str = "arg2") -> str:
|
||||
async def GET(self, arg1: str = "arg1", arg2: str = "arg2") -> str:
|
||||
return '{}+{}'.format(arg1, arg2)
|
||||
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
|
@ -99,6 +135,8 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
|
|||
await client.GET(arg1="A", arg2="B"), 'A+B')
|
||||
self.assertEqual(
|
||||
await client.GET(arg1="A"), 'A+arg2')
|
||||
self.assertEqual(
|
||||
await client.GET(arg2="B"), 'arg1+B')
|
||||
|
||||
async def test_post(self):
|
||||
@api
|
||||
|
|
|
@ -21,7 +21,7 @@ from aiohttp import web
|
|||
|
||||
from subiquitycore.context import Context
|
||||
|
||||
from subiquity.common.api.defs import api, Payload
|
||||
from subiquity.common.api.defs import api, path_parameter, Payload
|
||||
from subiquity.common.api.server import (
|
||||
bind,
|
||||
controller_for_request,
|
||||
|
@ -233,3 +233,49 @@ class TestBind(unittest.IsolatedAsyncioTestCase):
|
|||
self.assertEqual(await resp.json(), '')
|
||||
|
||||
self.assertIs(impl, seen_controller)
|
||||
|
||||
async def test_serialize_query_args(self):
|
||||
@api
|
||||
class API:
|
||||
serialize_query_args = False
|
||||
def GET(arg: str) -> str: ...
|
||||
|
||||
class meth:
|
||||
def GET(arg: str) -> str: ...
|
||||
|
||||
class more:
|
||||
serialize_query_args = True
|
||||
def GET(arg: str) -> str: ...
|
||||
|
||||
class Impl(ControllerBase):
|
||||
async def GET(self, arg: str) -> str:
|
||||
return arg
|
||||
|
||||
async def meth_GET(self, arg: str) -> str:
|
||||
return arg
|
||||
|
||||
async def meth_more_GET(self, arg: str) -> str:
|
||||
return arg
|
||||
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
await self.assertResponse(
|
||||
client.get('/?arg=whut'), 'whut')
|
||||
await self.assertResponse(
|
||||
client.get('/meth?arg=whut'), 'whut')
|
||||
await self.assertResponse(
|
||||
client.get('/meth/more?arg="whut"'), 'whut')
|
||||
|
||||
async def test_path_parameters(self):
|
||||
@api
|
||||
class API:
|
||||
@path_parameter
|
||||
class param:
|
||||
def GET(arg: int): ...
|
||||
|
||||
class Impl(ControllerBase):
|
||||
async def param_GET(self, param: str, arg: int):
|
||||
return param + str(arg)
|
||||
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
await self.assertResponse(
|
||||
client.get('/value?arg=2'), 'value2')
|
||||
|
|
Loading…
Reference in New Issue