Merge pull request #823 from mwhudson/add-api
add api definition and support code for client server split
This commit is contained in:
commit
04e5cf95b0
2
Makefile
2
Makefile
|
@ -50,7 +50,7 @@ lint: flake8
|
|||
|
||||
flake8:
|
||||
@echo 'tox -e flake8' is preferred to 'make flake8'
|
||||
$(PYTHON) -m flake8 $(CHECK_DIRS) --exclude gettext38.py
|
||||
$(PYTHON) -m flake8 $(CHECK_DIRS) --exclude gettext38.py,contextlib38.py
|
||||
|
||||
unit:
|
||||
echo "Running unit tests..."
|
||||
|
|
|
@ -11,4 +11,5 @@ jsonschema
|
|||
pyudev
|
||||
requests
|
||||
requests-unixsocket
|
||||
aiohttp
|
||||
-e git+https://github.com/CanonicalLtd/probert@b697ab779e7e056301e779f4708a9f1ce51b0027#egg=probert
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2020 Canonical, Ltd.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright 2020 Canonical, Ltd.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
|
||||
from subiquitycore import contextlib38
|
||||
|
||||
from subiquity.common.serialize import Serializer
|
||||
from .defs import Payload
|
||||
|
||||
|
||||
def _wrap(make_request, path, meth, serializer):
|
||||
sig = inspect.signature(meth)
|
||||
meth_params = sig.parameters
|
||||
payload_arg = None
|
||||
for name, param in meth_params.items():
|
||||
if getattr(param.annotation, '__origin__', None) is Payload:
|
||||
payload_arg = name
|
||||
payload_ann = param.annotation.__args__[0]
|
||||
r_ann = sig.return_annotation
|
||||
|
||||
async def impl(*args, **kw):
|
||||
args = sig.bind(*args, **kw)
|
||||
query_args = {}
|
||||
data = None
|
||||
for arg_name, value in args.arguments.items():
|
||||
if arg_name == payload_arg:
|
||||
data = serializer.serialize(payload_ann, value)
|
||||
else:
|
||||
query_args[arg_name] = json.dumps(
|
||||
serializer.serialize(
|
||||
meth_params[arg_name].annotation, value))
|
||||
async with make_request(
|
||||
meth.__name__, path, json=data, params=query_args) as resp:
|
||||
resp.raise_for_status()
|
||||
return serializer.deserialize(r_ann, await resp.json())
|
||||
return impl
|
||||
|
||||
|
||||
def make_client(endpoint_cls, make_request, serializer=None):
|
||||
if serializer is None:
|
||||
serializer = Serializer()
|
||||
|
||||
class C:
|
||||
pass
|
||||
|
||||
for k, v in endpoint_cls.__dict__.items():
|
||||
if isinstance(v, type):
|
||||
setattr(C, k, make_client(v, make_request, serializer))
|
||||
elif callable(v):
|
||||
setattr(C, k, _wrap(
|
||||
make_request, endpoint_cls.fullpath, v, serializer))
|
||||
return C
|
||||
|
||||
|
||||
def make_client_for_conn(
|
||||
endpoint_cls, conn, resp_hook=lambda r: r, serializer=None):
|
||||
@contextlib38.asynccontextmanager
|
||||
async def make_request(method, path, *, params, json):
|
||||
async with aiohttp.ClientSession(
|
||||
connector=conn, connector_owner=False) as session:
|
||||
# session.request needs a full URL with scheme and host
|
||||
# even though that's in some ways a bit silly with a unix
|
||||
# socket, so we just hardcode something here (I guess the
|
||||
# "a" gets sent a long to the server inthe Host: header
|
||||
# and the server could in principle do something like
|
||||
# virtual host based selection but well....)
|
||||
url = 'http://a' + path
|
||||
async with session.request(
|
||||
method, url, json=json, params=params,
|
||||
timeout=0) as response:
|
||||
yield resp_hook(response)
|
||||
|
||||
return make_client(endpoint_cls, make_request, serializer)
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2020 Canonical, Ltd.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import typing
|
||||
|
||||
|
||||
def api(cls, prefix=(), foo=None):
|
||||
cls.fullpath = '/' + '/'.join(prefix)
|
||||
cls.fullname = prefix
|
||||
for k, v in cls.__dict__.items():
|
||||
if isinstance(v, type):
|
||||
v.__name__ = cls.__name__ + '.' + k
|
||||
api(v, prefix + (k,))
|
||||
if callable(v):
|
||||
v.__qualname__ = cls.__name__ + '.' + k
|
||||
return cls
|
||||
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
|
||||
|
||||
class Payload(typing.Generic[T]):
|
||||
pass
|
||||
|
||||
|
||||
def simple_endpoint(typ):
|
||||
class endpoint:
|
||||
def GET() -> typ: ...
|
||||
def POST(data: Payload[typ]): ...
|
||||
return endpoint
|
|
@ -0,0 +1,177 @@
|
|||
# Copyright 2020 Canonical, Ltd.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from subiquity.common.serialize import Serializer
|
||||
|
||||
from .defs import Payload
|
||||
|
||||
|
||||
class BindError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MissingImplementationError(BindError):
|
||||
def __init__(self, controller, methname):
|
||||
self.controller = controller
|
||||
self.methname = methname
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.controller} must have method called {self.methname}"
|
||||
|
||||
|
||||
class SignatureMisatchError(BindError):
|
||||
def __init__(self, methname, expected, actual):
|
||||
self.methname = methname
|
||||
self.expected = expected
|
||||
self.actual = actual
|
||||
|
||||
def __str__(self):
|
||||
return (f"implementation of {self.methname} has wrong signature, "
|
||||
f"should be {self.expected} but is {self.actual}")
|
||||
|
||||
|
||||
def trim(text):
|
||||
if text is None:
|
||||
return ''
|
||||
elif len(text) > 80:
|
||||
return text[:77] + '...'
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
def _make_handler(controller, definition, implementation, serializer):
|
||||
def_sig = inspect.signature(definition)
|
||||
def_ret_ann = def_sig.return_annotation
|
||||
def_params = def_sig.parameters
|
||||
|
||||
impl_sig = inspect.signature(implementation)
|
||||
impl_params = impl_sig.parameters
|
||||
|
||||
data_annotation = None
|
||||
data_arg = None
|
||||
query_args_anns = []
|
||||
|
||||
check_def_params = []
|
||||
|
||||
for param_name, param in def_params.items():
|
||||
if param_name in ('request', 'context'):
|
||||
raise Exception(
|
||||
"api method {} cannot have parameter called request or "
|
||||
"context".format(definition))
|
||||
if getattr(param.annotation, '__origin__', None) is Payload:
|
||||
data_arg = param_name
|
||||
data_annotation = param.annotation.__args__[0]
|
||||
check_def_params.append(param.replace(annotation=data_annotation))
|
||||
else:
|
||||
query_args_anns.append(
|
||||
(param_name, param.annotation, param.default))
|
||||
check_def_params.append(param)
|
||||
|
||||
check_impl_params = [
|
||||
p for p in impl_params.values()
|
||||
if p.name not in ('context', 'request')
|
||||
]
|
||||
check_impl_sig = impl_sig.replace(parameters=check_impl_params)
|
||||
|
||||
check_def_sig = def_sig.replace(parameters=check_def_params)
|
||||
|
||||
if check_impl_sig != check_def_sig:
|
||||
raise SignatureMisatchError(
|
||||
definition.__qualname__, check_def_sig, check_impl_sig)
|
||||
|
||||
async def handler(request):
|
||||
context = controller.context.child(
|
||||
implementation.__name__, trim(await request.text()))
|
||||
with context:
|
||||
context.set('request', request)
|
||||
args = {}
|
||||
try:
|
||||
if data_annotation is not None:
|
||||
payload = json.loads(await request.text())
|
||||
args[data_arg] = serializer.deserialize(
|
||||
data_annotation, payload)
|
||||
for arg, ann, default in query_args_anns:
|
||||
if arg in request.query:
|
||||
v = serializer.deserialize(
|
||||
ann, json.loads(request.query[arg]))
|
||||
elif default != inspect._empty:
|
||||
v = default
|
||||
else:
|
||||
raise TypeError(
|
||||
'missing required argument "{}"'.format(arg))
|
||||
args[arg] = v
|
||||
if 'context' in impl_params:
|
||||
args['context'] = context
|
||||
if 'request' in impl_params:
|
||||
args['request'] = request
|
||||
result = await implementation(**args)
|
||||
resp = web.json_response(
|
||||
serializer.serialize(def_ret_ann, result),
|
||||
headers={'x-status': 'ok'})
|
||||
except Exception as exc:
|
||||
resp = web.Response(
|
||||
status=500,
|
||||
headers={
|
||||
'x-status': 'error',
|
||||
'x-error-type': type(exc).__name__,
|
||||
'x-error-msg': str(exc),
|
||||
})
|
||||
resp['exception'] = exc
|
||||
context.description = '{} {}'.format(resp.status, trim(resp.text))
|
||||
return resp
|
||||
handler.controller = controller
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
async def controller_for_request(request):
|
||||
match_info = await request.app.router.resolve(request)
|
||||
return getattr(match_info.handler, 'controller', None)
|
||||
|
||||
|
||||
def bind(router, endpoint, controller, serializer=None, _depth=None):
|
||||
if serializer is None:
|
||||
serializer = Serializer()
|
||||
if _depth is None:
|
||||
_depth = len(endpoint.fullname)
|
||||
|
||||
for v in endpoint.__dict__.values():
|
||||
if isinstance(v, type):
|
||||
bind(router, v, controller, serializer, _depth)
|
||||
elif callable(v):
|
||||
method = v.__name__
|
||||
impl_name = "_".join(endpoint.fullname[_depth:] + (method,))
|
||||
if not hasattr(controller, impl_name):
|
||||
raise MissingImplementationError(controller, impl_name)
|
||||
impl = getattr(controller, impl_name)
|
||||
router.add_route(
|
||||
method=method,
|
||||
path=endpoint.fullpath,
|
||||
handler=_make_handler(controller, v, impl, serializer))
|
||||
|
||||
|
||||
async def make_server_at_path(socket_path, endpoint, controller):
|
||||
app = web.Application()
|
||||
bind(app.router, endpoint, controller)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.UnixSite(runner, socket_path)
|
||||
await site.start()
|
||||
return site
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2020 Canonical, Ltd.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright 2020 Canonical, Ltd.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import unittest
|
||||
|
||||
from subiquitycore import contextlib38
|
||||
|
||||
from subiquity.common.api.client import make_client
|
||||
from subiquity.common.api.defs import api, Payload
|
||||
|
||||
|
||||
def extract(c):
|
||||
try:
|
||||
c.__await__().send(None)
|
||||
except StopIteration as s:
|
||||
return s.value
|
||||
else:
|
||||
raise AssertionError("coroutine not done")
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
async def json(self):
|
||||
return self.data
|
||||
|
||||
|
||||
class TestClient(unittest.TestCase):
|
||||
|
||||
def test_simple(self):
|
||||
|
||||
@api
|
||||
class API:
|
||||
class endpoint:
|
||||
def GET() -> str: ...
|
||||
def POST(data: Payload[str]) -> None: ...
|
||||
|
||||
@contextlib38.asynccontextmanager
|
||||
async def make_request(method, path, *, params, json):
|
||||
requests.append((method, path, params, json))
|
||||
if method == "GET":
|
||||
v = 'value'
|
||||
else:
|
||||
v = None
|
||||
yield FakeResponse(v)
|
||||
|
||||
client = make_client(API, make_request)
|
||||
|
||||
requests = []
|
||||
r = extract(client.endpoint.GET())
|
||||
self.assertEqual(r, 'value')
|
||||
self.assertEqual(requests, [("GET", '/endpoint', {}, None)])
|
||||
|
||||
requests = []
|
||||
r = extract(client.endpoint.POST('value'))
|
||||
self.assertEqual(r, None)
|
||||
self.assertEqual(
|
||||
requests, [("POST", '/endpoint', {}, 'value')])
|
||||
|
||||
def test_args(self):
|
||||
|
||||
@api
|
||||
class API:
|
||||
def GET(arg: str) -> str: ...
|
||||
|
||||
@contextlib38.asynccontextmanager
|
||||
async def make_request(method, path, *, params, json):
|
||||
requests.append((method, path, params, json))
|
||||
yield FakeResponse(params['arg'])
|
||||
|
||||
client = make_client(API, make_request)
|
||||
|
||||
requests = []
|
||||
r = extract(client.GET(arg='v'))
|
||||
self.assertEqual(r, '"v"')
|
||||
self.assertEqual(requests, [("GET", '/', {'arg': '"v"'}, None)])
|
|
@ -0,0 +1,264 @@
|
|||
# Copyright 2020 Canonical, Ltd.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import attr
|
||||
import functools
|
||||
import unittest
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
from subiquitycore import contextlib38
|
||||
|
||||
from subiquity.common.api.client import make_client
|
||||
from subiquity.common.api.defs import api, Payload
|
||||
|
||||
from .test_server import (
|
||||
makeTestClient,
|
||||
run_coro,
|
||||
TestControllerBase,
|
||||
)
|
||||
|
||||
|
||||
def make_request(client, method, path, *, params, json):
|
||||
return client.request(
|
||||
method, path, params=params, json=json)
|
||||
|
||||
|
||||
@contextlib38.asynccontextmanager
|
||||
async def makeE2EClient(api, impl,
|
||||
*, middlewares=(), make_request=make_request):
|
||||
async with makeTestClient(
|
||||
api, impl, middlewares=middlewares) as client:
|
||||
mr = functools.partial(make_request, client)
|
||||
yield make_client(api, mr)
|
||||
|
||||
|
||||
class TestEndToEnd(unittest.TestCase):
|
||||
|
||||
def test_simple(self):
|
||||
@api
|
||||
class API:
|
||||
def GET() -> str: ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def GET(self) -> str:
|
||||
return 'value'
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(await client.GET(), 'value')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_nested(self):
|
||||
@api
|
||||
class API:
|
||||
class endpoint:
|
||||
class nested:
|
||||
def GET() -> str: ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def endpoint_nested_GET(self) -> str:
|
||||
return 'value'
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(await client.endpoint.nested.GET(), 'value')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_args(self):
|
||||
@api
|
||||
class API:
|
||||
def GET(arg1: str, arg2: str) -> str: ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def GET(self, arg1: str, arg2: str) -> str:
|
||||
return '{}+{}'.format(arg1, arg2)
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(
|
||||
await client.GET(arg1="A", arg2="B"), 'A+B')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_defaults(self):
|
||||
@api
|
||||
class API:
|
||||
def GET(arg1: str, arg2: str = "arg2") -> str: ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def GET(self, arg1: str, arg2: str = "arg2") -> str:
|
||||
return '{}+{}'.format(arg1, arg2)
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(
|
||||
await client.GET(arg1="A", arg2="B"), 'A+B')
|
||||
self.assertEqual(
|
||||
await client.GET(arg1="A"), 'A+arg2')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_post(self):
|
||||
@api
|
||||
class API:
|
||||
def POST(data: Payload[dict]) -> str: ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def POST(self, data: dict) -> str:
|
||||
return data['key']
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
self.assertEqual(
|
||||
await client.POST({'key': 'value'}), 'value')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_typed(self):
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class In:
|
||||
val: int
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class Out:
|
||||
doubled: int
|
||||
|
||||
@api
|
||||
class API:
|
||||
class doubler:
|
||||
def POST(data: In) -> Out: ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def doubler_POST(self, data: In) -> Out:
|
||||
return Out(doubled=data.val*2)
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
out = await client.doubler.POST(In(3))
|
||||
self.assertEqual(out.doubled, 6)
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_middleware(self):
|
||||
@api
|
||||
class API:
|
||||
def GET() -> int: ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def GET(self) -> int:
|
||||
return 1/0
|
||||
|
||||
@web.middleware
|
||||
async def middleware(request, handler):
|
||||
return web.Response(
|
||||
status=200,
|
||||
headers={'x-status': 'skip'})
|
||||
|
||||
class Skip(Exception):
|
||||
pass
|
||||
|
||||
@contextlib38.asynccontextmanager
|
||||
async def custom_make_request(client, method, path, *, params, json):
|
||||
async with make_request(
|
||||
client, method, path, params=params, json=json) as resp:
|
||||
if resp.headers.get('x-status') == 'skip':
|
||||
raise Skip
|
||||
yield resp
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(
|
||||
API, Impl(),
|
||||
middlewares=[middleware],
|
||||
make_request=custom_make_request) as client:
|
||||
with self.assertRaises(Skip):
|
||||
await client.GET()
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_error(self):
|
||||
@api
|
||||
class API:
|
||||
class good:
|
||||
def GET(x: int) -> int: ...
|
||||
|
||||
class bad:
|
||||
def GET(x: int) -> int: ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def good_GET(self, x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
async def bad_GET(self, x: int) -> int:
|
||||
raise Exception("baz")
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(API, Impl()) as client:
|
||||
r = await client.good.GET(2)
|
||||
self.assertEqual(r, 3)
|
||||
with self.assertRaises(aiohttp.ClientResponseError):
|
||||
await client.bad.GET(2)
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_error_middleware(self):
|
||||
@api
|
||||
class API:
|
||||
class good:
|
||||
def GET(x: int) -> int: ...
|
||||
|
||||
class bad:
|
||||
def GET(x: int) -> int: ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def good_GET(self, x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
async def bad_GET(self, x: int) -> int:
|
||||
1/0
|
||||
|
||||
@web.middleware
|
||||
async def middleware(request, handler):
|
||||
resp = await handler(request)
|
||||
if resp.get('exception'):
|
||||
resp.headers['x-status'] = 'ERROR'
|
||||
return resp
|
||||
|
||||
class Abort(Exception):
|
||||
pass
|
||||
|
||||
@contextlib38.asynccontextmanager
|
||||
async def custom_make_request(client, method, path, *, params, json):
|
||||
async with make_request(
|
||||
client, method, path, params=params, json=json) as resp:
|
||||
if resp.headers.get('x-status') == 'ERROR':
|
||||
raise Abort
|
||||
yield resp
|
||||
|
||||
async def run():
|
||||
async with makeE2EClient(
|
||||
API, Impl(),
|
||||
middlewares=[middleware],
|
||||
make_request=custom_make_request) as client:
|
||||
r = await client.good.GET(2)
|
||||
self.assertEqual(r, 3)
|
||||
with self.assertRaises(Abort):
|
||||
await client.bad.GET(2)
|
||||
|
||||
run_coro(run())
|
|
@ -0,0 +1,260 @@
|
|||
# Copyright 2020 Canonical, Ltd.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
from aiohttp import web
|
||||
|
||||
from subiquitycore.context import Context
|
||||
from subiquitycore import contextlib38
|
||||
|
||||
from subiquity.common.api.defs import api, Payload
|
||||
from subiquity.common.api.server import (
|
||||
bind,
|
||||
controller_for_request,
|
||||
MissingImplementationError,
|
||||
SignatureMisatchError,
|
||||
)
|
||||
|
||||
|
||||
def run_coro(coro):
|
||||
asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
|
||||
class TestApp:
|
||||
|
||||
def report_start_event(self, context, description):
|
||||
pass
|
||||
|
||||
def report_finish_event(self, context, description, result):
|
||||
pass
|
||||
|
||||
project = 'test'
|
||||
|
||||
|
||||
class TestControllerBase:
|
||||
|
||||
def __init__(self):
|
||||
self.context = Context.new(TestApp())
|
||||
|
||||
|
||||
@contextlib38.asynccontextmanager
|
||||
async def makeTestClient(api, impl, middlewares=()):
|
||||
app = web.Application(middlewares=middlewares)
|
||||
bind(app.router, api, impl)
|
||||
async with TestClient(TestServer(app)) as client:
|
||||
yield client
|
||||
|
||||
|
||||
class TestBind(unittest.TestCase):
|
||||
|
||||
async def assertResponse(self, coro, value):
|
||||
resp = await coro
|
||||
self.assertEqual(resp.status, 200)
|
||||
self.assertEqual(await resp.json(), value)
|
||||
|
||||
def test_simple(self):
|
||||
@api
|
||||
class API:
|
||||
def GET() -> str: ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def GET(self) -> str:
|
||||
return 'value'
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
await self.assertResponse(
|
||||
client.get("/"), 'value')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_nested(self):
|
||||
@api
|
||||
class API:
|
||||
class endpoint:
|
||||
class nested:
|
||||
def get(): ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def nested_get(self, request, context):
|
||||
return 'nested'
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(API.endpoint, Impl()) as client:
|
||||
await self.assertResponse(
|
||||
client.get("/endpoint/nested"), 'nested')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_args(self):
|
||||
@api
|
||||
class API:
|
||||
def GET(arg: str): ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def GET(self, arg: str):
|
||||
return arg
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
await self.assertResponse(
|
||||
client.get('/?arg="whut"'), 'whut')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_missing_argument(self):
|
||||
@api
|
||||
class API:
|
||||
def GET(arg: str): ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def GET(self, arg: str):
|
||||
return arg
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
resp = await client.get('/')
|
||||
self.assertEqual(resp.status, 500)
|
||||
self.assertEqual(resp.headers['x-status'], 'error')
|
||||
self.assertEqual(resp.headers['x-error-type'], 'TypeError')
|
||||
self.assertEqual(
|
||||
resp.headers['x-error-msg'],
|
||||
'missing required argument "arg"')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_error(self):
|
||||
@api
|
||||
class API:
|
||||
def GET(): ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def GET(self):
|
||||
return 1/0
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
resp = await client.get('/')
|
||||
self.assertEqual(resp.status, 500)
|
||||
self.assertEqual(resp.headers['x-status'], 'error')
|
||||
self.assertEqual(
|
||||
resp.headers['x-error-type'], 'ZeroDivisionError')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_post(self):
|
||||
@api
|
||||
class API:
|
||||
def POST(data: Payload[str]) -> str: ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def POST(self, data: str) -> str:
|
||||
return data
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(API, Impl()) as client:
|
||||
await self.assertResponse(
|
||||
client.post("/", json='value'),
|
||||
'value')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_missing_method(self):
|
||||
@api
|
||||
class API:
|
||||
def GET(arg: str): ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
pass
|
||||
|
||||
app = web.Application()
|
||||
with self.assertRaises(MissingImplementationError) as cm:
|
||||
bind(app.router, API, Impl())
|
||||
self.assertEqual(cm.exception.methname, "GET")
|
||||
|
||||
def test_signature_checking(self):
|
||||
@api
|
||||
class API:
|
||||
def GET(arg: str): ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def GET(self, arg: int):
|
||||
return arg
|
||||
|
||||
app = web.Application()
|
||||
with self.assertRaises(SignatureMisatchError) as cm:
|
||||
bind(app.router, API, Impl())
|
||||
self.assertEqual(cm.exception.methname, "API.GET")
|
||||
|
||||
def test_middleware(self):
|
||||
|
||||
@web.middleware
|
||||
async def middleware(request, handler):
|
||||
resp = await handler(request)
|
||||
exc = resp.get('exception')
|
||||
if resp.get('exception') is not None:
|
||||
resp.headers['x-error-type'] = type(exc).__name__
|
||||
return resp
|
||||
|
||||
@api
|
||||
class API:
|
||||
def GET() -> str: ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def GET(self) -> str:
|
||||
return 1/0
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(
|
||||
API, Impl(), middlewares=[middleware]) as client:
|
||||
resp = await client.get("/")
|
||||
self.assertEqual(
|
||||
resp.headers['x-error-type'], 'ZeroDivisionError')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
def test_controller_for_request(self):
|
||||
|
||||
seen_controller = None
|
||||
|
||||
@web.middleware
|
||||
async def middleware(request, handler):
|
||||
nonlocal seen_controller
|
||||
seen_controller = await controller_for_request(request)
|
||||
return await handler(request)
|
||||
|
||||
@api
|
||||
class API:
|
||||
class meth:
|
||||
def GET() -> str: ...
|
||||
|
||||
class Impl(TestControllerBase):
|
||||
async def GET(self) -> str:
|
||||
return ''
|
||||
|
||||
impl = Impl()
|
||||
|
||||
async def run():
|
||||
async with makeTestClient(
|
||||
API.meth, impl, middlewares=[middleware]) as client:
|
||||
resp = await client.get("/meth")
|
||||
self.assertEqual(await resp.json(), '')
|
||||
|
||||
run_coro(run())
|
||||
|
||||
self.assertIs(impl, seen_controller)
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2020 Canonical, Ltd.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import datetime
|
||||
import enum
|
||||
import inspect
|
||||
import typing
|
||||
|
||||
import attr
|
||||
|
||||
# This is basically a half-assed version of # https://pypi.org/project/cattrs/
|
||||
# but that's not packaged and this is enough for our needs.
|
||||
|
||||
|
||||
class Serializer:
|
||||
|
||||
def __init__(self):
|
||||
self.typing_walkers = {
|
||||
typing.Union: self._walk_Union,
|
||||
list: self._walk_List,
|
||||
typing.List: self._walk_List,
|
||||
}
|
||||
self.type_serializers = {}
|
||||
self.type_deserializers = {}
|
||||
for typ in int, str, dict, bool, list, type(None):
|
||||
self.type_serializers[typ] = self._scalar
|
||||
self.type_deserializers[typ] = self._scalar
|
||||
self.type_serializers[datetime.datetime] = self._serialize_datetime
|
||||
self.type_deserializers[datetime.datetime] = self._deserialize_datetime
|
||||
|
||||
def _scalar(self, annotation, value, metadata):
|
||||
assert type(value) is annotation, "{} is not a {}".format(
|
||||
value, annotation)
|
||||
return value
|
||||
|
||||
def _walk_Union(self, meth, args, value, metadata):
|
||||
NoneType = type(None)
|
||||
assert NoneType in args, "can only serialize Optional"
|
||||
args = [a for a in args if a is not NoneType]
|
||||
assert len(args) == 1, "can only serialize Optional"
|
||||
if value is None:
|
||||
return value
|
||||
return meth(args[0], value, metadata)
|
||||
|
||||
def _walk_List(self, meth, args, value, metadata):
|
||||
return [meth(args[0], v, metadata) for v in value]
|
||||
|
||||
def _serialize_datetime(self, annotation, value, metadata):
|
||||
assert type(value) is annotation
|
||||
if metadata is not None and 'time_fmt' in metadata:
|
||||
return value.strftime(metadata['time_fmt'])
|
||||
else:
|
||||
return str(value)
|
||||
|
||||
def _serialize_field(self, field, value):
|
||||
return {field.name: self.serialize(field.type, value, field.metadata)}
|
||||
|
||||
def _serialize_attr(self, annotation, value, metadata):
|
||||
r = {}
|
||||
for field in attr.fields(annotation):
|
||||
r.update(self._serialize_field(field, getattr(value, field.name)))
|
||||
return r
|
||||
|
||||
def serialize(self, annotation, value, metadata=None):
|
||||
if annotation is None:
|
||||
assert value is None
|
||||
return None
|
||||
if annotation is inspect.Signature.empty:
|
||||
return value
|
||||
if attr.has(annotation):
|
||||
return self._serialize_attr(annotation, value, metadata)
|
||||
origin = getattr(annotation, '__origin__', None)
|
||||
if origin is not None:
|
||||
args = annotation.__args__
|
||||
return self.typing_walkers[origin](
|
||||
self.serialize, args, value, metadata)
|
||||
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
|
||||
return value.name
|
||||
return self.type_serializers[annotation](annotation, value, metadata)
|
||||
|
||||
def _deserialize_datetime(self, annotation, value, metadata):
|
||||
assert type(value) is str
|
||||
if metadata is not None and 'time_fmt' in metadata:
|
||||
return datetime.datetime.strptime(value, metadata['time_fmt'])
|
||||
else:
|
||||
1/0
|
||||
|
||||
def _deserialize_field(self, field, value):
|
||||
return {
|
||||
field.name: self.deserialize(field.type, value, field.metadata)
|
||||
}
|
||||
|
||||
def _deserialize_attr(self, annotation, value, metadata):
|
||||
args = {}
|
||||
for field in attr.fields(annotation):
|
||||
args.update(self._deserialize_field(field, value[field.name]))
|
||||
return annotation(**args)
|
||||
|
||||
def deserialize(self, annotation, value, metadata=None):
|
||||
if annotation is None:
|
||||
assert value is None
|
||||
return None
|
||||
if annotation is inspect.Signature.empty:
|
||||
return value
|
||||
if attr.has(annotation):
|
||||
return self._deserialize_attr(annotation, value, metadata)
|
||||
origin = getattr(annotation, '__origin__', None)
|
||||
if origin is not None:
|
||||
args = annotation.__args__
|
||||
return self.typing_walkers[origin](
|
||||
self.deserialize, args, value, metadata)
|
||||
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
|
||||
return getattr(annotation, value)
|
||||
return self.type_deserializers[annotation](annotation, value, metadata)
|
|
@ -0,0 +1,704 @@
|
|||
"""Utilities for with-statement contexts. See PEP 343."""
|
||||
import abc
|
||||
import sys
|
||||
import _collections_abc
|
||||
from collections import deque
|
||||
from functools import wraps
|
||||
from types import MethodType
|
||||
|
||||
__all__ = ["asynccontextmanager", "contextmanager", "closing", "nullcontext",
|
||||
"AbstractContextManager", "AbstractAsyncContextManager",
|
||||
"AsyncExitStack", "ContextDecorator", "ExitStack",
|
||||
"redirect_stdout", "redirect_stderr", "suppress"]
|
||||
|
||||
|
||||
class AbstractContextManager(abc.ABC):
|
||||
|
||||
"""An abstract base class for context managers."""
|
||||
|
||||
def __enter__(self):
|
||||
"""Return `self` upon entering the runtime context."""
|
||||
return self
|
||||
|
||||
@abc.abstractmethod
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
"""Raise any exception triggered within the runtime context."""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def __subclasshook__(cls, C):
|
||||
if cls is AbstractContextManager:
|
||||
return _collections_abc._check_methods(C, "__enter__", "__exit__")
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class AbstractAsyncContextManager(abc.ABC):
|
||||
|
||||
"""An abstract base class for asynchronous context managers."""
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Return `self` upon entering the runtime context."""
|
||||
return self
|
||||
|
||||
@abc.abstractmethod
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
"""Raise any exception triggered within the runtime context."""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def __subclasshook__(cls, C):
|
||||
if cls is AbstractAsyncContextManager:
|
||||
return _collections_abc._check_methods(C, "__aenter__",
|
||||
"__aexit__")
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class ContextDecorator(object):
|
||||
"A base class or mixin that enables context managers to work as decorators."
|
||||
|
||||
def _recreate_cm(self):
|
||||
"""Return a recreated instance of self.
|
||||
|
||||
Allows an otherwise one-shot context manager like
|
||||
_GeneratorContextManager to support use as
|
||||
a decorator via implicit recreation.
|
||||
|
||||
This is a private interface just for _GeneratorContextManager.
|
||||
See issue #11647 for details.
|
||||
"""
|
||||
return self
|
||||
|
||||
def __call__(self, func):
|
||||
@wraps(func)
|
||||
def inner(*args, **kwds):
|
||||
with self._recreate_cm():
|
||||
return func(*args, **kwds)
|
||||
return inner
|
||||
|
||||
|
||||
class _GeneratorContextManagerBase:
|
||||
"""Shared functionality for @contextmanager and @asynccontextmanager."""
|
||||
|
||||
def __init__(self, func, args, kwds):
|
||||
self.gen = func(*args, **kwds)
|
||||
self.func, self.args, self.kwds = func, args, kwds
|
||||
# Issue 19330: ensure context manager instances have good docstrings
|
||||
doc = getattr(func, "__doc__", None)
|
||||
if doc is None:
|
||||
doc = type(self).__doc__
|
||||
self.__doc__ = doc
|
||||
# Unfortunately, this still doesn't provide good help output when
|
||||
# inspecting the created context manager instances, since pydoc
|
||||
# currently bypasses the instance docstring and shows the docstring
|
||||
# for the class instead.
|
||||
# See http://bugs.python.org/issue19404 for more details.
|
||||
|
||||
|
||||
class _GeneratorContextManager(_GeneratorContextManagerBase,
|
||||
AbstractContextManager,
|
||||
ContextDecorator):
|
||||
"""Helper for @contextmanager decorator."""
|
||||
|
||||
def _recreate_cm(self):
|
||||
# _GCM instances are one-shot context managers, so the
|
||||
# CM must be recreated each time a decorated function is
|
||||
# called
|
||||
return self.__class__(self.func, self.args, self.kwds)
|
||||
|
||||
def __enter__(self):
|
||||
# do not keep args and kwds alive unnecessarily
|
||||
# they are only needed for recreation, which is not possible anymore
|
||||
del self.args, self.kwds, self.func
|
||||
try:
|
||||
return next(self.gen)
|
||||
except StopIteration:
|
||||
raise RuntimeError("generator didn't yield") from None
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
if type is None:
|
||||
try:
|
||||
next(self.gen)
|
||||
except StopIteration:
|
||||
return False
|
||||
else:
|
||||
raise RuntimeError("generator didn't stop")
|
||||
else:
|
||||
if value is None:
|
||||
# Need to force instantiation so we can reliably
|
||||
# tell if we get the same exception back
|
||||
value = type()
|
||||
try:
|
||||
self.gen.throw(type, value, traceback)
|
||||
except StopIteration as exc:
|
||||
# Suppress StopIteration *unless* it's the same exception that
|
||||
# was passed to throw(). This prevents a StopIteration
|
||||
# raised inside the "with" statement from being suppressed.
|
||||
return exc is not value
|
||||
except RuntimeError as exc:
|
||||
# Don't re-raise the passed in exception. (issue27122)
|
||||
if exc is value:
|
||||
return False
|
||||
# Likewise, avoid suppressing if a StopIteration exception
|
||||
# was passed to throw() and later wrapped into a RuntimeError
|
||||
# (see PEP 479).
|
||||
if type is StopIteration and exc.__cause__ is value:
|
||||
return False
|
||||
raise
|
||||
except:
|
||||
# only re-raise if it's *not* the exception that was
|
||||
# passed to throw(), because __exit__() must not raise
|
||||
# an exception unless __exit__() itself failed. But throw()
|
||||
# has to raise the exception to signal propagation, so this
|
||||
# fixes the impedance mismatch between the throw() protocol
|
||||
# and the __exit__() protocol.
|
||||
#
|
||||
# This cannot use 'except BaseException as exc' (as in the
|
||||
# async implementation) to maintain compatibility with
|
||||
# Python 2, where old-style class exceptions are not caught
|
||||
# by 'except BaseException'.
|
||||
if sys.exc_info()[1] is value:
|
||||
return False
|
||||
raise
|
||||
raise RuntimeError("generator didn't stop after throw()")
|
||||
|
||||
|
||||
class _AsyncGeneratorContextManager(_GeneratorContextManagerBase,
|
||||
AbstractAsyncContextManager):
|
||||
"""Helper for @asynccontextmanager."""
|
||||
|
||||
async def __aenter__(self):
|
||||
try:
|
||||
return await self.gen.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise RuntimeError("generator didn't yield") from None
|
||||
|
||||
async def __aexit__(self, typ, value, traceback):
|
||||
if typ is None:
|
||||
try:
|
||||
await self.gen.__anext__()
|
||||
except StopAsyncIteration:
|
||||
return
|
||||
else:
|
||||
raise RuntimeError("generator didn't stop")
|
||||
else:
|
||||
if value is None:
|
||||
value = typ()
|
||||
# See _GeneratorContextManager.__exit__ for comments on subtleties
|
||||
# in this implementation
|
||||
try:
|
||||
await self.gen.athrow(typ, value, traceback)
|
||||
raise RuntimeError("generator didn't stop after athrow()")
|
||||
except StopAsyncIteration as exc:
|
||||
return exc is not value
|
||||
except RuntimeError as exc:
|
||||
if exc is value:
|
||||
return False
|
||||
# Avoid suppressing if a StopIteration exception
|
||||
# was passed to throw() and later wrapped into a RuntimeError
|
||||
# (see PEP 479 for sync generators; async generators also
|
||||
# have this behavior). But do this only if the exception wrapped
|
||||
# by the RuntimeError is actully Stop(Async)Iteration (see
|
||||
# issue29692).
|
||||
if isinstance(value, (StopIteration, StopAsyncIteration)):
|
||||
if exc.__cause__ is value:
|
||||
return False
|
||||
raise
|
||||
except BaseException as exc:
|
||||
if exc is not value:
|
||||
raise
|
||||
|
||||
|
||||
def contextmanager(func):
|
||||
"""@contextmanager decorator.
|
||||
|
||||
Typical usage:
|
||||
|
||||
@contextmanager
|
||||
def some_generator(<arguments>):
|
||||
<setup>
|
||||
try:
|
||||
yield <value>
|
||||
finally:
|
||||
<cleanup>
|
||||
|
||||
This makes this:
|
||||
|
||||
with some_generator(<arguments>) as <variable>:
|
||||
<body>
|
||||
|
||||
equivalent to this:
|
||||
|
||||
<setup>
|
||||
try:
|
||||
<variable> = <value>
|
||||
<body>
|
||||
finally:
|
||||
<cleanup>
|
||||
"""
|
||||
@wraps(func)
|
||||
def helper(*args, **kwds):
|
||||
return _GeneratorContextManager(func, args, kwds)
|
||||
return helper
|
||||
|
||||
|
||||
def asynccontextmanager(func):
|
||||
"""@asynccontextmanager decorator.
|
||||
|
||||
Typical usage:
|
||||
|
||||
@asynccontextmanager
|
||||
async def some_async_generator(<arguments>):
|
||||
<setup>
|
||||
try:
|
||||
yield <value>
|
||||
finally:
|
||||
<cleanup>
|
||||
|
||||
This makes this:
|
||||
|
||||
async with some_async_generator(<arguments>) as <variable>:
|
||||
<body>
|
||||
|
||||
equivalent to this:
|
||||
|
||||
<setup>
|
||||
try:
|
||||
<variable> = <value>
|
||||
<body>
|
||||
finally:
|
||||
<cleanup>
|
||||
"""
|
||||
@wraps(func)
|
||||
def helper(*args, **kwds):
|
||||
return _AsyncGeneratorContextManager(func, args, kwds)
|
||||
return helper
|
||||
|
||||
|
||||
class closing(AbstractContextManager):
|
||||
"""Context to automatically close something at the end of a block.
|
||||
|
||||
Code like this:
|
||||
|
||||
with closing(<module>.open(<arguments>)) as f:
|
||||
<block>
|
||||
|
||||
is equivalent to this:
|
||||
|
||||
f = <module>.open(<arguments>)
|
||||
try:
|
||||
<block>
|
||||
finally:
|
||||
f.close()
|
||||
|
||||
"""
|
||||
def __init__(self, thing):
|
||||
self.thing = thing
|
||||
def __enter__(self):
|
||||
return self.thing
|
||||
def __exit__(self, *exc_info):
|
||||
self.thing.close()
|
||||
|
||||
|
||||
class _RedirectStream(AbstractContextManager):
|
||||
|
||||
_stream = None
|
||||
|
||||
def __init__(self, new_target):
|
||||
self._new_target = new_target
|
||||
# We use a list of old targets to make this CM re-entrant
|
||||
self._old_targets = []
|
||||
|
||||
def __enter__(self):
|
||||
self._old_targets.append(getattr(sys, self._stream))
|
||||
setattr(sys, self._stream, self._new_target)
|
||||
return self._new_target
|
||||
|
||||
def __exit__(self, exctype, excinst, exctb):
|
||||
setattr(sys, self._stream, self._old_targets.pop())
|
||||
|
||||
|
||||
class redirect_stdout(_RedirectStream):
|
||||
"""Context manager for temporarily redirecting stdout to another file.
|
||||
|
||||
# How to send help() to stderr
|
||||
with redirect_stdout(sys.stderr):
|
||||
help(dir)
|
||||
|
||||
# How to write help() to a file
|
||||
with open('help.txt', 'w') as f:
|
||||
with redirect_stdout(f):
|
||||
help(pow)
|
||||
"""
|
||||
|
||||
_stream = "stdout"
|
||||
|
||||
|
||||
class redirect_stderr(_RedirectStream):
|
||||
"""Context manager for temporarily redirecting stderr to another file."""
|
||||
|
||||
_stream = "stderr"
|
||||
|
||||
|
||||
class suppress(AbstractContextManager):
|
||||
"""Context manager to suppress specified exceptions
|
||||
|
||||
After the exception is suppressed, execution proceeds with the next
|
||||
statement following the with statement.
|
||||
|
||||
with suppress(FileNotFoundError):
|
||||
os.remove(somefile)
|
||||
# Execution still resumes here if the file was already removed
|
||||
"""
|
||||
|
||||
def __init__(self, *exceptions):
|
||||
self._exceptions = exceptions
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exctype, excinst, exctb):
|
||||
# Unlike isinstance and issubclass, CPython exception handling
|
||||
# currently only looks at the concrete type hierarchy (ignoring
|
||||
# the instance and subclass checking hooks). While Guido considers
|
||||
# that a bug rather than a feature, it's a fairly hard one to fix
|
||||
# due to various internal implementation details. suppress provides
|
||||
# the simpler issubclass based semantics, rather than trying to
|
||||
# exactly reproduce the limitations of the CPython interpreter.
|
||||
#
|
||||
# See http://bugs.python.org/issue12029 for more details
|
||||
return exctype is not None and issubclass(exctype, self._exceptions)
|
||||
|
||||
|
||||
class _BaseExitStack:
|
||||
"""A base class for ExitStack and AsyncExitStack."""
|
||||
|
||||
@staticmethod
|
||||
def _create_exit_wrapper(cm, cm_exit):
|
||||
return MethodType(cm_exit, cm)
|
||||
|
||||
@staticmethod
|
||||
def _create_cb_wrapper(callback, *args, **kwds):
|
||||
def _exit_wrapper(exc_type, exc, tb):
|
||||
callback(*args, **kwds)
|
||||
return _exit_wrapper
|
||||
|
||||
def __init__(self):
|
||||
self._exit_callbacks = deque()
|
||||
|
||||
def pop_all(self):
|
||||
"""Preserve the context stack by transferring it to a new instance."""
|
||||
new_stack = type(self)()
|
||||
new_stack._exit_callbacks = self._exit_callbacks
|
||||
self._exit_callbacks = deque()
|
||||
return new_stack
|
||||
|
||||
def push(self, exit):
|
||||
"""Registers a callback with the standard __exit__ method signature.
|
||||
|
||||
Can suppress exceptions the same way __exit__ method can.
|
||||
Also accepts any object with an __exit__ method (registering a call
|
||||
to the method instead of the object itself).
|
||||
"""
|
||||
# We use an unbound method rather than a bound method to follow
|
||||
# the standard lookup behaviour for special methods.
|
||||
_cb_type = type(exit)
|
||||
|
||||
try:
|
||||
exit_method = _cb_type.__exit__
|
||||
except AttributeError:
|
||||
# Not a context manager, so assume it's a callable.
|
||||
self._push_exit_callback(exit)
|
||||
else:
|
||||
self._push_cm_exit(exit, exit_method)
|
||||
return exit # Allow use as a decorator.
|
||||
|
||||
def enter_context(self, cm):
|
||||
"""Enters the supplied context manager.
|
||||
|
||||
If successful, also pushes its __exit__ method as a callback and
|
||||
returns the result of the __enter__ method.
|
||||
"""
|
||||
# We look up the special methods on the type to match the with
|
||||
# statement.
|
||||
_cm_type = type(cm)
|
||||
_exit = _cm_type.__exit__
|
||||
result = _cm_type.__enter__(cm)
|
||||
self._push_cm_exit(cm, _exit)
|
||||
return result
|
||||
|
||||
def callback(*args, **kwds):
|
||||
"""Registers an arbitrary callback and arguments.
|
||||
|
||||
Cannot suppress exceptions.
|
||||
"""
|
||||
if len(args) >= 2:
|
||||
self, callback, *args = args
|
||||
elif not args:
|
||||
raise TypeError("descriptor 'callback' of '_BaseExitStack' object "
|
||||
"needs an argument")
|
||||
elif 'callback' in kwds:
|
||||
callback = kwds.pop('callback')
|
||||
self, *args = args
|
||||
import warnings
|
||||
warnings.warn("Passing 'callback' as keyword argument is deprecated",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
else:
|
||||
raise TypeError('callback expected at least 1 positional argument, '
|
||||
'got %d' % (len(args)-1))
|
||||
|
||||
_exit_wrapper = self._create_cb_wrapper(callback, *args, **kwds)
|
||||
|
||||
# We changed the signature, so using @wraps is not appropriate, but
|
||||
# setting __wrapped__ may still help with introspection.
|
||||
_exit_wrapper.__wrapped__ = callback
|
||||
self._push_exit_callback(_exit_wrapper)
|
||||
return callback # Allow use as a decorator
|
||||
callback.__text_signature__ = '($self, callback, /, *args, **kwds)'
|
||||
|
||||
def _push_cm_exit(self, cm, cm_exit):
|
||||
"""Helper to correctly register callbacks to __exit__ methods."""
|
||||
_exit_wrapper = self._create_exit_wrapper(cm, cm_exit)
|
||||
self._push_exit_callback(_exit_wrapper, True)
|
||||
|
||||
def _push_exit_callback(self, callback, is_sync=True):
|
||||
self._exit_callbacks.append((is_sync, callback))
|
||||
|
||||
|
||||
# Inspired by discussions on http://bugs.python.org/issue13585
|
||||
class ExitStack(_BaseExitStack, AbstractContextManager):
|
||||
"""Context manager for dynamic management of a stack of exit callbacks.
|
||||
|
||||
For example:
|
||||
with ExitStack() as stack:
|
||||
files = [stack.enter_context(open(fname)) for fname in filenames]
|
||||
# All opened files will automatically be closed at the end of
|
||||
# the with statement, even if attempts to open files later
|
||||
# in the list raise an exception.
|
||||
"""
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_details):
|
||||
received_exc = exc_details[0] is not None
|
||||
|
||||
# We manipulate the exception state so it behaves as though
|
||||
# we were actually nesting multiple with statements
|
||||
frame_exc = sys.exc_info()[1]
|
||||
def _fix_exception_context(new_exc, old_exc):
|
||||
# Context may not be correct, so find the end of the chain
|
||||
while 1:
|
||||
exc_context = new_exc.__context__
|
||||
if exc_context is old_exc:
|
||||
# Context is already set correctly (see issue 20317)
|
||||
return
|
||||
if exc_context is None or exc_context is frame_exc:
|
||||
break
|
||||
new_exc = exc_context
|
||||
# Change the end of the chain to point to the exception
|
||||
# we expect it to reference
|
||||
new_exc.__context__ = old_exc
|
||||
|
||||
# Callbacks are invoked in LIFO order to match the behaviour of
|
||||
# nested context managers
|
||||
suppressed_exc = False
|
||||
pending_raise = False
|
||||
while self._exit_callbacks:
|
||||
is_sync, cb = self._exit_callbacks.pop()
|
||||
assert is_sync
|
||||
try:
|
||||
if cb(*exc_details):
|
||||
suppressed_exc = True
|
||||
pending_raise = False
|
||||
exc_details = (None, None, None)
|
||||
except:
|
||||
new_exc_details = sys.exc_info()
|
||||
# simulate the stack of exceptions by setting the context
|
||||
_fix_exception_context(new_exc_details[1], exc_details[1])
|
||||
pending_raise = True
|
||||
exc_details = new_exc_details
|
||||
if pending_raise:
|
||||
try:
|
||||
# bare "raise exc_details[1]" replaces our carefully
|
||||
# set-up context
|
||||
fixed_ctx = exc_details[1].__context__
|
||||
raise exc_details[1]
|
||||
except BaseException:
|
||||
exc_details[1].__context__ = fixed_ctx
|
||||
raise
|
||||
return received_exc and suppressed_exc
|
||||
|
||||
def close(self):
|
||||
"""Immediately unwind the context stack."""
|
||||
self.__exit__(None, None, None)
|
||||
|
||||
|
||||
# Inspired by discussions on https://bugs.python.org/issue29302
|
||||
class AsyncExitStack(_BaseExitStack, AbstractAsyncContextManager):
|
||||
"""Async context manager for dynamic management of a stack of exit
|
||||
callbacks.
|
||||
|
||||
For example:
|
||||
async with AsyncExitStack() as stack:
|
||||
connections = [await stack.enter_async_context(get_connection())
|
||||
for i in range(5)]
|
||||
# All opened connections will automatically be released at the
|
||||
# end of the async with statement, even if attempts to open a
|
||||
# connection later in the list raise an exception.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _create_async_exit_wrapper(cm, cm_exit):
|
||||
return MethodType(cm_exit, cm)
|
||||
|
||||
@staticmethod
|
||||
def _create_async_cb_wrapper(callback, *args, **kwds):
|
||||
async def _exit_wrapper(exc_type, exc, tb):
|
||||
await callback(*args, **kwds)
|
||||
return _exit_wrapper
|
||||
|
||||
async def enter_async_context(self, cm):
|
||||
"""Enters the supplied async context manager.
|
||||
|
||||
If successful, also pushes its __aexit__ method as a callback and
|
||||
returns the result of the __aenter__ method.
|
||||
"""
|
||||
_cm_type = type(cm)
|
||||
_exit = _cm_type.__aexit__
|
||||
result = await _cm_type.__aenter__(cm)
|
||||
self._push_async_cm_exit(cm, _exit)
|
||||
return result
|
||||
|
||||
def push_async_exit(self, exit):
|
||||
"""Registers a coroutine function with the standard __aexit__ method
|
||||
signature.
|
||||
|
||||
Can suppress exceptions the same way __aexit__ method can.
|
||||
Also accepts any object with an __aexit__ method (registering a call
|
||||
to the method instead of the object itself).
|
||||
"""
|
||||
_cb_type = type(exit)
|
||||
try:
|
||||
exit_method = _cb_type.__aexit__
|
||||
except AttributeError:
|
||||
# Not an async context manager, so assume it's a coroutine function
|
||||
self._push_exit_callback(exit, False)
|
||||
else:
|
||||
self._push_async_cm_exit(exit, exit_method)
|
||||
return exit # Allow use as a decorator
|
||||
|
||||
def push_async_callback(*args, **kwds):
|
||||
"""Registers an arbitrary coroutine function and arguments.
|
||||
|
||||
Cannot suppress exceptions.
|
||||
"""
|
||||
if len(args) >= 2:
|
||||
self, callback, *args = args
|
||||
elif not args:
|
||||
raise TypeError("descriptor 'push_async_callback' of "
|
||||
"'AsyncExitStack' object needs an argument")
|
||||
elif 'callback' in kwds:
|
||||
callback = kwds.pop('callback')
|
||||
self, *args = args
|
||||
import warnings
|
||||
warnings.warn("Passing 'callback' as keyword argument is deprecated",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
else:
|
||||
raise TypeError('push_async_callback expected at least 1 '
|
||||
'positional argument, got %d' % (len(args)-1))
|
||||
|
||||
_exit_wrapper = self._create_async_cb_wrapper(callback, *args, **kwds)
|
||||
|
||||
# We changed the signature, so using @wraps is not appropriate, but
|
||||
# setting __wrapped__ may still help with introspection.
|
||||
_exit_wrapper.__wrapped__ = callback
|
||||
self._push_exit_callback(_exit_wrapper, False)
|
||||
return callback # Allow use as a decorator
|
||||
push_async_callback.__text_signature__ = '($self, callback, /, *args, **kwds)'
|
||||
|
||||
async def aclose(self):
|
||||
"""Immediately unwind the context stack."""
|
||||
await self.__aexit__(None, None, None)
|
||||
|
||||
def _push_async_cm_exit(self, cm, cm_exit):
|
||||
"""Helper to correctly register coroutine function to __aexit__
|
||||
method."""
|
||||
_exit_wrapper = self._create_async_exit_wrapper(cm, cm_exit)
|
||||
self._push_exit_callback(_exit_wrapper, False)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc_details):
|
||||
received_exc = exc_details[0] is not None
|
||||
|
||||
# We manipulate the exception state so it behaves as though
|
||||
# we were actually nesting multiple with statements
|
||||
frame_exc = sys.exc_info()[1]
|
||||
def _fix_exception_context(new_exc, old_exc):
|
||||
# Context may not be correct, so find the end of the chain
|
||||
while 1:
|
||||
exc_context = new_exc.__context__
|
||||
if exc_context is old_exc:
|
||||
# Context is already set correctly (see issue 20317)
|
||||
return
|
||||
if exc_context is None or exc_context is frame_exc:
|
||||
break
|
||||
new_exc = exc_context
|
||||
# Change the end of the chain to point to the exception
|
||||
# we expect it to reference
|
||||
new_exc.__context__ = old_exc
|
||||
|
||||
# Callbacks are invoked in LIFO order to match the behaviour of
|
||||
# nested context managers
|
||||
suppressed_exc = False
|
||||
pending_raise = False
|
||||
while self._exit_callbacks:
|
||||
is_sync, cb = self._exit_callbacks.pop()
|
||||
try:
|
||||
if is_sync:
|
||||
cb_suppress = cb(*exc_details)
|
||||
else:
|
||||
cb_suppress = await cb(*exc_details)
|
||||
|
||||
if cb_suppress:
|
||||
suppressed_exc = True
|
||||
pending_raise = False
|
||||
exc_details = (None, None, None)
|
||||
except:
|
||||
new_exc_details = sys.exc_info()
|
||||
# simulate the stack of exceptions by setting the context
|
||||
_fix_exception_context(new_exc_details[1], exc_details[1])
|
||||
pending_raise = True
|
||||
exc_details = new_exc_details
|
||||
if pending_raise:
|
||||
try:
|
||||
# bare "raise exc_details[1]" replaces our carefully
|
||||
# set-up context
|
||||
fixed_ctx = exc_details[1].__context__
|
||||
raise exc_details[1]
|
||||
except BaseException:
|
||||
exc_details[1].__context__ = fixed_ctx
|
||||
raise
|
||||
return received_exc and suppressed_exc
|
||||
|
||||
|
||||
class nullcontext(AbstractContextManager):
|
||||
"""Context manager that does no additional processing.
|
||||
|
||||
Used as a stand-in for a normal context manager, when a particular
|
||||
block of code is only sometimes used with a normal context manager:
|
||||
|
||||
cm = optional_cm if condition else nullcontext()
|
||||
with cm:
|
||||
# Perform operation, using optional_cm if condition is True
|
||||
"""
|
||||
|
||||
def __init__(self, enter_result=None):
|
||||
self.enter_result = enter_result
|
||||
|
||||
def __enter__(self):
|
||||
return self.enter_result
|
||||
|
||||
def __exit__(self, *excinfo):
|
||||
pass
|
Loading…
Reference in New Issue