Merge pull request #823 from mwhudson/add-api
add api definition and support code for client server split
This commit is contained in:
commit
04e5cf95b0
2
Makefile
2
Makefile
|
@ -50,7 +50,7 @@ lint: flake8
|
||||||
|
|
||||||
flake8:
|
flake8:
|
||||||
@echo 'tox -e flake8' is preferred to 'make flake8'
|
@echo 'tox -e flake8' is preferred to 'make flake8'
|
||||||
$(PYTHON) -m flake8 $(CHECK_DIRS) --exclude gettext38.py
|
$(PYTHON) -m flake8 $(CHECK_DIRS) --exclude gettext38.py,contextlib38.py
|
||||||
|
|
||||||
unit:
|
unit:
|
||||||
echo "Running unit tests..."
|
echo "Running unit tests..."
|
||||||
|
|
|
@ -11,4 +11,5 @@ jsonschema
|
||||||
pyudev
|
pyudev
|
||||||
requests
|
requests
|
||||||
requests-unixsocket
|
requests-unixsocket
|
||||||
|
aiohttp
|
||||||
-e git+https://github.com/CanonicalLtd/probert@b697ab779e7e056301e779f4708a9f1ce51b0027#egg=probert
|
-e git+https://github.com/CanonicalLtd/probert@b697ab779e7e056301e779f4708a9f1ce51b0027#egg=probert
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
# Copyright 2020 Canonical, Ltd.
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
@ -0,0 +1,89 @@
|
||||||
|
# Copyright 2020 Canonical, Ltd.
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from subiquitycore import contextlib38
|
||||||
|
|
||||||
|
from subiquity.common.serialize import Serializer
|
||||||
|
from .defs import Payload
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap(make_request, path, meth, serializer):
|
||||||
|
sig = inspect.signature(meth)
|
||||||
|
meth_params = sig.parameters
|
||||||
|
payload_arg = None
|
||||||
|
for name, param in meth_params.items():
|
||||||
|
if getattr(param.annotation, '__origin__', None) is Payload:
|
||||||
|
payload_arg = name
|
||||||
|
payload_ann = param.annotation.__args__[0]
|
||||||
|
r_ann = sig.return_annotation
|
||||||
|
|
||||||
|
async def impl(*args, **kw):
|
||||||
|
args = sig.bind(*args, **kw)
|
||||||
|
query_args = {}
|
||||||
|
data = None
|
||||||
|
for arg_name, value in args.arguments.items():
|
||||||
|
if arg_name == payload_arg:
|
||||||
|
data = serializer.serialize(payload_ann, value)
|
||||||
|
else:
|
||||||
|
query_args[arg_name] = json.dumps(
|
||||||
|
serializer.serialize(
|
||||||
|
meth_params[arg_name].annotation, value))
|
||||||
|
async with make_request(
|
||||||
|
meth.__name__, path, json=data, params=query_args) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return serializer.deserialize(r_ann, await resp.json())
|
||||||
|
return impl
|
||||||
|
|
||||||
|
|
||||||
|
def make_client(endpoint_cls, make_request, serializer=None):
|
||||||
|
if serializer is None:
|
||||||
|
serializer = Serializer()
|
||||||
|
|
||||||
|
class C:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for k, v in endpoint_cls.__dict__.items():
|
||||||
|
if isinstance(v, type):
|
||||||
|
setattr(C, k, make_client(v, make_request, serializer))
|
||||||
|
elif callable(v):
|
||||||
|
setattr(C, k, _wrap(
|
||||||
|
make_request, endpoint_cls.fullpath, v, serializer))
|
||||||
|
return C
|
||||||
|
|
||||||
|
|
||||||
|
def make_client_for_conn(
|
||||||
|
endpoint_cls, conn, resp_hook=lambda r: r, serializer=None):
|
||||||
|
@contextlib38.asynccontextmanager
|
||||||
|
async def make_request(method, path, *, params, json):
|
||||||
|
async with aiohttp.ClientSession(
|
||||||
|
connector=conn, connector_owner=False) as session:
|
||||||
|
# session.request needs a full URL with scheme and host
|
||||||
|
# even though that's in some ways a bit silly with a unix
|
||||||
|
# socket, so we just hardcode something here (I guess the
|
||||||
|
# "a" gets sent a long to the server inthe Host: header
|
||||||
|
# and the server could in principle do something like
|
||||||
|
# virtual host based selection but well....)
|
||||||
|
url = 'http://a' + path
|
||||||
|
async with session.request(
|
||||||
|
method, url, json=json, params=params,
|
||||||
|
timeout=0) as response:
|
||||||
|
yield resp_hook(response)
|
||||||
|
|
||||||
|
return make_client(endpoint_cls, make_request, serializer)
|
|
@ -0,0 +1,42 @@
|
||||||
|
# Copyright 2020 Canonical, Ltd.
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
|
||||||
|
def api(cls, prefix=(), foo=None):
|
||||||
|
cls.fullpath = '/' + '/'.join(prefix)
|
||||||
|
cls.fullname = prefix
|
||||||
|
for k, v in cls.__dict__.items():
|
||||||
|
if isinstance(v, type):
|
||||||
|
v.__name__ = cls.__name__ + '.' + k
|
||||||
|
api(v, prefix + (k,))
|
||||||
|
if callable(v):
|
||||||
|
v.__qualname__ = cls.__name__ + '.' + k
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
T = typing.TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class Payload(typing.Generic[T]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def simple_endpoint(typ):
|
||||||
|
class endpoint:
|
||||||
|
def GET() -> typ: ...
|
||||||
|
def POST(data: Payload[typ]): ...
|
||||||
|
return endpoint
|
|
@ -0,0 +1,177 @@
|
||||||
|
# Copyright 2020 Canonical, Ltd.
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from subiquity.common.serialize import Serializer
|
||||||
|
|
||||||
|
from .defs import Payload
|
||||||
|
|
||||||
|
|
||||||
|
class BindError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MissingImplementationError(BindError):
|
||||||
|
def __init__(self, controller, methname):
|
||||||
|
self.controller = controller
|
||||||
|
self.methname = methname
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.controller} must have method called {self.methname}"
|
||||||
|
|
||||||
|
|
||||||
|
class SignatureMisatchError(BindError):
|
||||||
|
def __init__(self, methname, expected, actual):
|
||||||
|
self.methname = methname
|
||||||
|
self.expected = expected
|
||||||
|
self.actual = actual
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return (f"implementation of {self.methname} has wrong signature, "
|
||||||
|
f"should be {self.expected} but is {self.actual}")
|
||||||
|
|
||||||
|
|
||||||
|
def trim(text):
|
||||||
|
if text is None:
|
||||||
|
return ''
|
||||||
|
elif len(text) > 80:
|
||||||
|
return text[:77] + '...'
|
||||||
|
else:
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _make_handler(controller, definition, implementation, serializer):
|
||||||
|
def_sig = inspect.signature(definition)
|
||||||
|
def_ret_ann = def_sig.return_annotation
|
||||||
|
def_params = def_sig.parameters
|
||||||
|
|
||||||
|
impl_sig = inspect.signature(implementation)
|
||||||
|
impl_params = impl_sig.parameters
|
||||||
|
|
||||||
|
data_annotation = None
|
||||||
|
data_arg = None
|
||||||
|
query_args_anns = []
|
||||||
|
|
||||||
|
check_def_params = []
|
||||||
|
|
||||||
|
for param_name, param in def_params.items():
|
||||||
|
if param_name in ('request', 'context'):
|
||||||
|
raise Exception(
|
||||||
|
"api method {} cannot have parameter called request or "
|
||||||
|
"context".format(definition))
|
||||||
|
if getattr(param.annotation, '__origin__', None) is Payload:
|
||||||
|
data_arg = param_name
|
||||||
|
data_annotation = param.annotation.__args__[0]
|
||||||
|
check_def_params.append(param.replace(annotation=data_annotation))
|
||||||
|
else:
|
||||||
|
query_args_anns.append(
|
||||||
|
(param_name, param.annotation, param.default))
|
||||||
|
check_def_params.append(param)
|
||||||
|
|
||||||
|
check_impl_params = [
|
||||||
|
p for p in impl_params.values()
|
||||||
|
if p.name not in ('context', 'request')
|
||||||
|
]
|
||||||
|
check_impl_sig = impl_sig.replace(parameters=check_impl_params)
|
||||||
|
|
||||||
|
check_def_sig = def_sig.replace(parameters=check_def_params)
|
||||||
|
|
||||||
|
if check_impl_sig != check_def_sig:
|
||||||
|
raise SignatureMisatchError(
|
||||||
|
definition.__qualname__, check_def_sig, check_impl_sig)
|
||||||
|
|
||||||
|
async def handler(request):
|
||||||
|
context = controller.context.child(
|
||||||
|
implementation.__name__, trim(await request.text()))
|
||||||
|
with context:
|
||||||
|
context.set('request', request)
|
||||||
|
args = {}
|
||||||
|
try:
|
||||||
|
if data_annotation is not None:
|
||||||
|
payload = json.loads(await request.text())
|
||||||
|
args[data_arg] = serializer.deserialize(
|
||||||
|
data_annotation, payload)
|
||||||
|
for arg, ann, default in query_args_anns:
|
||||||
|
if arg in request.query:
|
||||||
|
v = serializer.deserialize(
|
||||||
|
ann, json.loads(request.query[arg]))
|
||||||
|
elif default != inspect._empty:
|
||||||
|
v = default
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
'missing required argument "{}"'.format(arg))
|
||||||
|
args[arg] = v
|
||||||
|
if 'context' in impl_params:
|
||||||
|
args['context'] = context
|
||||||
|
if 'request' in impl_params:
|
||||||
|
args['request'] = request
|
||||||
|
result = await implementation(**args)
|
||||||
|
resp = web.json_response(
|
||||||
|
serializer.serialize(def_ret_ann, result),
|
||||||
|
headers={'x-status': 'ok'})
|
||||||
|
except Exception as exc:
|
||||||
|
resp = web.Response(
|
||||||
|
status=500,
|
||||||
|
headers={
|
||||||
|
'x-status': 'error',
|
||||||
|
'x-error-type': type(exc).__name__,
|
||||||
|
'x-error-msg': str(exc),
|
||||||
|
})
|
||||||
|
resp['exception'] = exc
|
||||||
|
context.description = '{} {}'.format(resp.status, trim(resp.text))
|
||||||
|
return resp
|
||||||
|
handler.controller = controller
|
||||||
|
|
||||||
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
async def controller_for_request(request):
|
||||||
|
match_info = await request.app.router.resolve(request)
|
||||||
|
return getattr(match_info.handler, 'controller', None)
|
||||||
|
|
||||||
|
|
||||||
|
def bind(router, endpoint, controller, serializer=None, _depth=None):
|
||||||
|
if serializer is None:
|
||||||
|
serializer = Serializer()
|
||||||
|
if _depth is None:
|
||||||
|
_depth = len(endpoint.fullname)
|
||||||
|
|
||||||
|
for v in endpoint.__dict__.values():
|
||||||
|
if isinstance(v, type):
|
||||||
|
bind(router, v, controller, serializer, _depth)
|
||||||
|
elif callable(v):
|
||||||
|
method = v.__name__
|
||||||
|
impl_name = "_".join(endpoint.fullname[_depth:] + (method,))
|
||||||
|
if not hasattr(controller, impl_name):
|
||||||
|
raise MissingImplementationError(controller, impl_name)
|
||||||
|
impl = getattr(controller, impl_name)
|
||||||
|
router.add_route(
|
||||||
|
method=method,
|
||||||
|
path=endpoint.fullpath,
|
||||||
|
handler=_make_handler(controller, v, impl, serializer))
|
||||||
|
|
||||||
|
|
||||||
|
async def make_server_at_path(socket_path, endpoint, controller):
|
||||||
|
app = web.Application()
|
||||||
|
bind(app.router, endpoint, controller)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.UnixSite(runner, socket_path)
|
||||||
|
await site.start()
|
||||||
|
return site
|
|
@ -0,0 +1,14 @@
|
||||||
|
# Copyright 2020 Canonical, Ltd.
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
@ -0,0 +1,92 @@
|
||||||
|
# Copyright 2020 Canonical, Ltd.
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from subiquitycore import contextlib38
|
||||||
|
|
||||||
|
from subiquity.common.api.client import make_client
|
||||||
|
from subiquity.common.api.defs import api, Payload
|
||||||
|
|
||||||
|
|
||||||
|
def extract(c):
|
||||||
|
try:
|
||||||
|
c.__await__().send(None)
|
||||||
|
except StopIteration as s:
|
||||||
|
return s.value
|
||||||
|
else:
|
||||||
|
raise AssertionError("coroutine not done")
|
||||||
|
|
||||||
|
|
||||||
|
class FakeResponse:
|
||||||
|
def __init__(self, data):
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
def raise_for_status(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def json(self):
|
||||||
|
return self.data
|
||||||
|
|
||||||
|
|
||||||
|
class TestClient(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
class endpoint:
|
||||||
|
def GET() -> str: ...
|
||||||
|
def POST(data: Payload[str]) -> None: ...
|
||||||
|
|
||||||
|
@contextlib38.asynccontextmanager
|
||||||
|
async def make_request(method, path, *, params, json):
|
||||||
|
requests.append((method, path, params, json))
|
||||||
|
if method == "GET":
|
||||||
|
v = 'value'
|
||||||
|
else:
|
||||||
|
v = None
|
||||||
|
yield FakeResponse(v)
|
||||||
|
|
||||||
|
client = make_client(API, make_request)
|
||||||
|
|
||||||
|
requests = []
|
||||||
|
r = extract(client.endpoint.GET())
|
||||||
|
self.assertEqual(r, 'value')
|
||||||
|
self.assertEqual(requests, [("GET", '/endpoint', {}, None)])
|
||||||
|
|
||||||
|
requests = []
|
||||||
|
r = extract(client.endpoint.POST('value'))
|
||||||
|
self.assertEqual(r, None)
|
||||||
|
self.assertEqual(
|
||||||
|
requests, [("POST", '/endpoint', {}, 'value')])
|
||||||
|
|
||||||
|
def test_args(self):
|
||||||
|
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
def GET(arg: str) -> str: ...
|
||||||
|
|
||||||
|
@contextlib38.asynccontextmanager
|
||||||
|
async def make_request(method, path, *, params, json):
|
||||||
|
requests.append((method, path, params, json))
|
||||||
|
yield FakeResponse(params['arg'])
|
||||||
|
|
||||||
|
client = make_client(API, make_request)
|
||||||
|
|
||||||
|
requests = []
|
||||||
|
r = extract(client.GET(arg='v'))
|
||||||
|
self.assertEqual(r, '"v"')
|
||||||
|
self.assertEqual(requests, [("GET", '/', {'arg': '"v"'}, None)])
|
|
@ -0,0 +1,264 @@
|
||||||
|
# Copyright 2020 Canonical, Ltd.
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import attr
|
||||||
|
import functools
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from subiquitycore import contextlib38
|
||||||
|
|
||||||
|
from subiquity.common.api.client import make_client
|
||||||
|
from subiquity.common.api.defs import api, Payload
|
||||||
|
|
||||||
|
from .test_server import (
|
||||||
|
makeTestClient,
|
||||||
|
run_coro,
|
||||||
|
TestControllerBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_request(client, method, path, *, params, json):
|
||||||
|
return client.request(
|
||||||
|
method, path, params=params, json=json)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib38.asynccontextmanager
|
||||||
|
async def makeE2EClient(api, impl,
|
||||||
|
*, middlewares=(), make_request=make_request):
|
||||||
|
async with makeTestClient(
|
||||||
|
api, impl, middlewares=middlewares) as client:
|
||||||
|
mr = functools.partial(make_request, client)
|
||||||
|
yield make_client(api, mr)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEndToEnd(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
def GET() -> str: ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def GET(self) -> str:
|
||||||
|
return 'value'
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeE2EClient(API, Impl()) as client:
|
||||||
|
self.assertEqual(await client.GET(), 'value')
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_nested(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
class endpoint:
|
||||||
|
class nested:
|
||||||
|
def GET() -> str: ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def endpoint_nested_GET(self) -> str:
|
||||||
|
return 'value'
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeE2EClient(API, Impl()) as client:
|
||||||
|
self.assertEqual(await client.endpoint.nested.GET(), 'value')
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_args(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
def GET(arg1: str, arg2: str) -> str: ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def GET(self, arg1: str, arg2: str) -> str:
|
||||||
|
return '{}+{}'.format(arg1, arg2)
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeE2EClient(API, Impl()) as client:
|
||||||
|
self.assertEqual(
|
||||||
|
await client.GET(arg1="A", arg2="B"), 'A+B')
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_defaults(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
def GET(arg1: str, arg2: str = "arg2") -> str: ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def GET(self, arg1: str, arg2: str = "arg2") -> str:
|
||||||
|
return '{}+{}'.format(arg1, arg2)
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeE2EClient(API, Impl()) as client:
|
||||||
|
self.assertEqual(
|
||||||
|
await client.GET(arg1="A", arg2="B"), 'A+B')
|
||||||
|
self.assertEqual(
|
||||||
|
await client.GET(arg1="A"), 'A+arg2')
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_post(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
def POST(data: Payload[dict]) -> str: ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def POST(self, data: dict) -> str:
|
||||||
|
return data['key']
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeE2EClient(API, Impl()) as client:
|
||||||
|
self.assertEqual(
|
||||||
|
await client.POST({'key': 'value'}), 'value')
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_typed(self):
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True)
|
||||||
|
class In:
|
||||||
|
val: int
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True)
|
||||||
|
class Out:
|
||||||
|
doubled: int
|
||||||
|
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
class doubler:
|
||||||
|
def POST(data: In) -> Out: ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def doubler_POST(self, data: In) -> Out:
|
||||||
|
return Out(doubled=data.val*2)
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeE2EClient(API, Impl()) as client:
|
||||||
|
out = await client.doubler.POST(In(3))
|
||||||
|
self.assertEqual(out.doubled, 6)
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_middleware(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
def GET() -> int: ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def GET(self) -> int:
|
||||||
|
return 1/0
|
||||||
|
|
||||||
|
@web.middleware
|
||||||
|
async def middleware(request, handler):
|
||||||
|
return web.Response(
|
||||||
|
status=200,
|
||||||
|
headers={'x-status': 'skip'})
|
||||||
|
|
||||||
|
class Skip(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@contextlib38.asynccontextmanager
|
||||||
|
async def custom_make_request(client, method, path, *, params, json):
|
||||||
|
async with make_request(
|
||||||
|
client, method, path, params=params, json=json) as resp:
|
||||||
|
if resp.headers.get('x-status') == 'skip':
|
||||||
|
raise Skip
|
||||||
|
yield resp
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeE2EClient(
|
||||||
|
API, Impl(),
|
||||||
|
middlewares=[middleware],
|
||||||
|
make_request=custom_make_request) as client:
|
||||||
|
with self.assertRaises(Skip):
|
||||||
|
await client.GET()
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_error(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
class good:
|
||||||
|
def GET(x: int) -> int: ...
|
||||||
|
|
||||||
|
class bad:
|
||||||
|
def GET(x: int) -> int: ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def good_GET(self, x: int) -> int:
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
async def bad_GET(self, x: int) -> int:
|
||||||
|
raise Exception("baz")
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeE2EClient(API, Impl()) as client:
|
||||||
|
r = await client.good.GET(2)
|
||||||
|
self.assertEqual(r, 3)
|
||||||
|
with self.assertRaises(aiohttp.ClientResponseError):
|
||||||
|
await client.bad.GET(2)
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_error_middleware(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
class good:
|
||||||
|
def GET(x: int) -> int: ...
|
||||||
|
|
||||||
|
class bad:
|
||||||
|
def GET(x: int) -> int: ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def good_GET(self, x: int) -> int:
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
async def bad_GET(self, x: int) -> int:
|
||||||
|
1/0
|
||||||
|
|
||||||
|
@web.middleware
|
||||||
|
async def middleware(request, handler):
|
||||||
|
resp = await handler(request)
|
||||||
|
if resp.get('exception'):
|
||||||
|
resp.headers['x-status'] = 'ERROR'
|
||||||
|
return resp
|
||||||
|
|
||||||
|
class Abort(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@contextlib38.asynccontextmanager
|
||||||
|
async def custom_make_request(client, method, path, *, params, json):
|
||||||
|
async with make_request(
|
||||||
|
client, method, path, params=params, json=json) as resp:
|
||||||
|
if resp.headers.get('x-status') == 'ERROR':
|
||||||
|
raise Abort
|
||||||
|
yield resp
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeE2EClient(
|
||||||
|
API, Impl(),
|
||||||
|
middlewares=[middleware],
|
||||||
|
make_request=custom_make_request) as client:
|
||||||
|
r = await client.good.GET(2)
|
||||||
|
self.assertEqual(r, 3)
|
||||||
|
with self.assertRaises(Abort):
|
||||||
|
await client.bad.GET(2)
|
||||||
|
|
||||||
|
run_coro(run())
|
|
@ -0,0 +1,260 @@
|
||||||
|
# Copyright 2020 Canonical, Ltd.
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from aiohttp.test_utils import TestClient, TestServer
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from subiquitycore.context import Context
|
||||||
|
from subiquitycore import contextlib38
|
||||||
|
|
||||||
|
from subiquity.common.api.defs import api, Payload
|
||||||
|
from subiquity.common.api.server import (
|
||||||
|
bind,
|
||||||
|
controller_for_request,
|
||||||
|
MissingImplementationError,
|
||||||
|
SignatureMisatchError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_coro(coro):
|
||||||
|
asyncio.get_event_loop().run_until_complete(coro)
|
||||||
|
|
||||||
|
|
||||||
|
class TestApp:
|
||||||
|
|
||||||
|
def report_start_event(self, context, description):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def report_finish_event(self, context, description, result):
|
||||||
|
pass
|
||||||
|
|
||||||
|
project = 'test'
|
||||||
|
|
||||||
|
|
||||||
|
class TestControllerBase:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.context = Context.new(TestApp())
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib38.asynccontextmanager
|
||||||
|
async def makeTestClient(api, impl, middlewares=()):
|
||||||
|
app = web.Application(middlewares=middlewares)
|
||||||
|
bind(app.router, api, impl)
|
||||||
|
async with TestClient(TestServer(app)) as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
class TestBind(unittest.TestCase):
|
||||||
|
|
||||||
|
async def assertResponse(self, coro, value):
|
||||||
|
resp = await coro
|
||||||
|
self.assertEqual(resp.status, 200)
|
||||||
|
self.assertEqual(await resp.json(), value)
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
def GET() -> str: ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def GET(self) -> str:
|
||||||
|
return 'value'
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeTestClient(API, Impl()) as client:
|
||||||
|
await self.assertResponse(
|
||||||
|
client.get("/"), 'value')
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_nested(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
class endpoint:
|
||||||
|
class nested:
|
||||||
|
def get(): ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def nested_get(self, request, context):
|
||||||
|
return 'nested'
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeTestClient(API.endpoint, Impl()) as client:
|
||||||
|
await self.assertResponse(
|
||||||
|
client.get("/endpoint/nested"), 'nested')
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_args(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
def GET(arg: str): ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def GET(self, arg: str):
|
||||||
|
return arg
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeTestClient(API, Impl()) as client:
|
||||||
|
await self.assertResponse(
|
||||||
|
client.get('/?arg="whut"'), 'whut')
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_missing_argument(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
def GET(arg: str): ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def GET(self, arg: str):
|
||||||
|
return arg
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeTestClient(API, Impl()) as client:
|
||||||
|
resp = await client.get('/')
|
||||||
|
self.assertEqual(resp.status, 500)
|
||||||
|
self.assertEqual(resp.headers['x-status'], 'error')
|
||||||
|
self.assertEqual(resp.headers['x-error-type'], 'TypeError')
|
||||||
|
self.assertEqual(
|
||||||
|
resp.headers['x-error-msg'],
|
||||||
|
'missing required argument "arg"')
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_error(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
def GET(): ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def GET(self):
|
||||||
|
return 1/0
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeTestClient(API, Impl()) as client:
|
||||||
|
resp = await client.get('/')
|
||||||
|
self.assertEqual(resp.status, 500)
|
||||||
|
self.assertEqual(resp.headers['x-status'], 'error')
|
||||||
|
self.assertEqual(
|
||||||
|
resp.headers['x-error-type'], 'ZeroDivisionError')
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_post(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
def POST(data: Payload[str]) -> str: ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def POST(self, data: str) -> str:
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeTestClient(API, Impl()) as client:
|
||||||
|
await self.assertResponse(
|
||||||
|
client.post("/", json='value'),
|
||||||
|
'value')
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_missing_method(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
def GET(arg: str): ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
with self.assertRaises(MissingImplementationError) as cm:
|
||||||
|
bind(app.router, API, Impl())
|
||||||
|
self.assertEqual(cm.exception.methname, "GET")
|
||||||
|
|
||||||
|
def test_signature_checking(self):
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
def GET(arg: str): ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def GET(self, arg: int):
|
||||||
|
return arg
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
with self.assertRaises(SignatureMisatchError) as cm:
|
||||||
|
bind(app.router, API, Impl())
|
||||||
|
self.assertEqual(cm.exception.methname, "API.GET")
|
||||||
|
|
||||||
|
def test_middleware(self):
|
||||||
|
|
||||||
|
@web.middleware
|
||||||
|
async def middleware(request, handler):
|
||||||
|
resp = await handler(request)
|
||||||
|
exc = resp.get('exception')
|
||||||
|
if resp.get('exception') is not None:
|
||||||
|
resp.headers['x-error-type'] = type(exc).__name__
|
||||||
|
return resp
|
||||||
|
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
def GET() -> str: ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def GET(self) -> str:
|
||||||
|
return 1/0
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeTestClient(
|
||||||
|
API, Impl(), middlewares=[middleware]) as client:
|
||||||
|
resp = await client.get("/")
|
||||||
|
self.assertEqual(
|
||||||
|
resp.headers['x-error-type'], 'ZeroDivisionError')
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
def test_controller_for_request(self):
|
||||||
|
|
||||||
|
seen_controller = None
|
||||||
|
|
||||||
|
@web.middleware
|
||||||
|
async def middleware(request, handler):
|
||||||
|
nonlocal seen_controller
|
||||||
|
seen_controller = await controller_for_request(request)
|
||||||
|
return await handler(request)
|
||||||
|
|
||||||
|
@api
|
||||||
|
class API:
|
||||||
|
class meth:
|
||||||
|
def GET() -> str: ...
|
||||||
|
|
||||||
|
class Impl(TestControllerBase):
|
||||||
|
async def GET(self) -> str:
|
||||||
|
return ''
|
||||||
|
|
||||||
|
impl = Impl()
|
||||||
|
|
||||||
|
async def run():
|
||||||
|
async with makeTestClient(
|
||||||
|
API.meth, impl, middlewares=[middleware]) as client:
|
||||||
|
resp = await client.get("/meth")
|
||||||
|
self.assertEqual(await resp.json(), '')
|
||||||
|
|
||||||
|
run_coro(run())
|
||||||
|
|
||||||
|
self.assertIs(impl, seen_controller)
|
|
@ -0,0 +1,126 @@
|
||||||
|
# Copyright 2020 Canonical, Ltd.
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import enum
|
||||||
|
import inspect
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
|
# This is basically a half-assed version of # https://pypi.org/project/cattrs/
|
||||||
|
# but that's not packaged and this is enough for our needs.
|
||||||
|
|
||||||
|
|
||||||
|
class Serializer:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.typing_walkers = {
|
||||||
|
typing.Union: self._walk_Union,
|
||||||
|
list: self._walk_List,
|
||||||
|
typing.List: self._walk_List,
|
||||||
|
}
|
||||||
|
self.type_serializers = {}
|
||||||
|
self.type_deserializers = {}
|
||||||
|
for typ in int, str, dict, bool, list, type(None):
|
||||||
|
self.type_serializers[typ] = self._scalar
|
||||||
|
self.type_deserializers[typ] = self._scalar
|
||||||
|
self.type_serializers[datetime.datetime] = self._serialize_datetime
|
||||||
|
self.type_deserializers[datetime.datetime] = self._deserialize_datetime
|
||||||
|
|
||||||
|
def _scalar(self, annotation, value, metadata):
|
||||||
|
assert type(value) is annotation, "{} is not a {}".format(
|
||||||
|
value, annotation)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _walk_Union(self, meth, args, value, metadata):
|
||||||
|
NoneType = type(None)
|
||||||
|
assert NoneType in args, "can only serialize Optional"
|
||||||
|
args = [a for a in args if a is not NoneType]
|
||||||
|
assert len(args) == 1, "can only serialize Optional"
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return meth(args[0], value, metadata)
|
||||||
|
|
||||||
|
def _walk_List(self, meth, args, value, metadata):
|
||||||
|
return [meth(args[0], v, metadata) for v in value]
|
||||||
|
|
||||||
|
def _serialize_datetime(self, annotation, value, metadata):
|
||||||
|
assert type(value) is annotation
|
||||||
|
if metadata is not None and 'time_fmt' in metadata:
|
||||||
|
return value.strftime(metadata['time_fmt'])
|
||||||
|
else:
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
def _serialize_field(self, field, value):
|
||||||
|
return {field.name: self.serialize(field.type, value, field.metadata)}
|
||||||
|
|
||||||
|
def _serialize_attr(self, annotation, value, metadata):
|
||||||
|
r = {}
|
||||||
|
for field in attr.fields(annotation):
|
||||||
|
r.update(self._serialize_field(field, getattr(value, field.name)))
|
||||||
|
return r
|
||||||
|
|
||||||
|
def serialize(self, annotation, value, metadata=None):
|
||||||
|
if annotation is None:
|
||||||
|
assert value is None
|
||||||
|
return None
|
||||||
|
if annotation is inspect.Signature.empty:
|
||||||
|
return value
|
||||||
|
if attr.has(annotation):
|
||||||
|
return self._serialize_attr(annotation, value, metadata)
|
||||||
|
origin = getattr(annotation, '__origin__', None)
|
||||||
|
if origin is not None:
|
||||||
|
args = annotation.__args__
|
||||||
|
return self.typing_walkers[origin](
|
||||||
|
self.serialize, args, value, metadata)
|
||||||
|
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
|
||||||
|
return value.name
|
||||||
|
return self.type_serializers[annotation](annotation, value, metadata)
|
||||||
|
|
||||||
|
def _deserialize_datetime(self, annotation, value, metadata):
|
||||||
|
assert type(value) is str
|
||||||
|
if metadata is not None and 'time_fmt' in metadata:
|
||||||
|
return datetime.datetime.strptime(value, metadata['time_fmt'])
|
||||||
|
else:
|
||||||
|
1/0
|
||||||
|
|
||||||
|
def _deserialize_field(self, field, value):
|
||||||
|
return {
|
||||||
|
field.name: self.deserialize(field.type, value, field.metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
def _deserialize_attr(self, annotation, value, metadata):
|
||||||
|
args = {}
|
||||||
|
for field in attr.fields(annotation):
|
||||||
|
args.update(self._deserialize_field(field, value[field.name]))
|
||||||
|
return annotation(**args)
|
||||||
|
|
||||||
|
def deserialize(self, annotation, value, metadata=None):
|
||||||
|
if annotation is None:
|
||||||
|
assert value is None
|
||||||
|
return None
|
||||||
|
if annotation is inspect.Signature.empty:
|
||||||
|
return value
|
||||||
|
if attr.has(annotation):
|
||||||
|
return self._deserialize_attr(annotation, value, metadata)
|
||||||
|
origin = getattr(annotation, '__origin__', None)
|
||||||
|
if origin is not None:
|
||||||
|
args = annotation.__args__
|
||||||
|
return self.typing_walkers[origin](
|
||||||
|
self.deserialize, args, value, metadata)
|
||||||
|
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
|
||||||
|
return getattr(annotation, value)
|
||||||
|
return self.type_deserializers[annotation](annotation, value, metadata)
|
|
@ -0,0 +1,704 @@
|
||||||
|
"""Utilities for with-statement contexts. See PEP 343."""
|
||||||
|
import abc
|
||||||
|
import sys
|
||||||
|
import _collections_abc
|
||||||
|
from collections import deque
|
||||||
|
from functools import wraps
|
||||||
|
from types import MethodType
|
||||||
|
|
||||||
|
__all__ = ["asynccontextmanager", "contextmanager", "closing", "nullcontext",
|
||||||
|
"AbstractContextManager", "AbstractAsyncContextManager",
|
||||||
|
"AsyncExitStack", "ContextDecorator", "ExitStack",
|
||||||
|
"redirect_stdout", "redirect_stderr", "suppress"]
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractContextManager(abc.ABC):
|
||||||
|
|
||||||
|
"""An abstract base class for context managers."""
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Return `self` upon entering the runtime context."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
"""Raise any exception triggered within the runtime context."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __subclasshook__(cls, C):
|
||||||
|
if cls is AbstractContextManager:
|
||||||
|
return _collections_abc._check_methods(C, "__enter__", "__exit__")
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractAsyncContextManager(abc.ABC):
|
||||||
|
|
||||||
|
"""An abstract base class for asynchronous context managers."""
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
"""Return `self` upon entering the runtime context."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
"""Raise any exception triggered within the runtime context."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __subclasshook__(cls, C):
|
||||||
|
if cls is AbstractAsyncContextManager:
|
||||||
|
return _collections_abc._check_methods(C, "__aenter__",
|
||||||
|
"__aexit__")
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
class ContextDecorator(object):
|
||||||
|
"A base class or mixin that enables context managers to work as decorators."
|
||||||
|
|
||||||
|
def _recreate_cm(self):
|
||||||
|
"""Return a recreated instance of self.
|
||||||
|
|
||||||
|
Allows an otherwise one-shot context manager like
|
||||||
|
_GeneratorContextManager to support use as
|
||||||
|
a decorator via implicit recreation.
|
||||||
|
|
||||||
|
This is a private interface just for _GeneratorContextManager.
|
||||||
|
See issue #11647 for details.
|
||||||
|
"""
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __call__(self, func):
|
||||||
|
@wraps(func)
|
||||||
|
def inner(*args, **kwds):
|
||||||
|
with self._recreate_cm():
|
||||||
|
return func(*args, **kwds)
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
class _GeneratorContextManagerBase:
|
||||||
|
"""Shared functionality for @contextmanager and @asynccontextmanager."""
|
||||||
|
|
||||||
|
def __init__(self, func, args, kwds):
|
||||||
|
self.gen = func(*args, **kwds)
|
||||||
|
self.func, self.args, self.kwds = func, args, kwds
|
||||||
|
# Issue 19330: ensure context manager instances have good docstrings
|
||||||
|
doc = getattr(func, "__doc__", None)
|
||||||
|
if doc is None:
|
||||||
|
doc = type(self).__doc__
|
||||||
|
self.__doc__ = doc
|
||||||
|
# Unfortunately, this still doesn't provide good help output when
|
||||||
|
# inspecting the created context manager instances, since pydoc
|
||||||
|
# currently bypasses the instance docstring and shows the docstring
|
||||||
|
# for the class instead.
|
||||||
|
# See http://bugs.python.org/issue19404 for more details.
|
||||||
|
|
||||||
|
|
||||||
|
class _GeneratorContextManager(_GeneratorContextManagerBase,
|
||||||
|
AbstractContextManager,
|
||||||
|
ContextDecorator):
|
||||||
|
"""Helper for @contextmanager decorator."""
|
||||||
|
|
||||||
|
def _recreate_cm(self):
|
||||||
|
# _GCM instances are one-shot context managers, so the
|
||||||
|
# CM must be recreated each time a decorated function is
|
||||||
|
# called
|
||||||
|
return self.__class__(self.func, self.args, self.kwds)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
# do not keep args and kwds alive unnecessarily
|
||||||
|
# they are only needed for recreation, which is not possible anymore
|
||||||
|
del self.args, self.kwds, self.func
|
||||||
|
try:
|
||||||
|
return next(self.gen)
|
||||||
|
except StopIteration:
|
||||||
|
raise RuntimeError("generator didn't yield") from None
|
||||||
|
|
||||||
|
def __exit__(self, type, value, traceback):
|
||||||
|
if type is None:
|
||||||
|
try:
|
||||||
|
next(self.gen)
|
||||||
|
except StopIteration:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
raise RuntimeError("generator didn't stop")
|
||||||
|
else:
|
||||||
|
if value is None:
|
||||||
|
# Need to force instantiation so we can reliably
|
||||||
|
# tell if we get the same exception back
|
||||||
|
value = type()
|
||||||
|
try:
|
||||||
|
self.gen.throw(type, value, traceback)
|
||||||
|
except StopIteration as exc:
|
||||||
|
# Suppress StopIteration *unless* it's the same exception that
|
||||||
|
# was passed to throw(). This prevents a StopIteration
|
||||||
|
# raised inside the "with" statement from being suppressed.
|
||||||
|
return exc is not value
|
||||||
|
except RuntimeError as exc:
|
||||||
|
# Don't re-raise the passed in exception. (issue27122)
|
||||||
|
if exc is value:
|
||||||
|
return False
|
||||||
|
# Likewise, avoid suppressing if a StopIteration exception
|
||||||
|
# was passed to throw() and later wrapped into a RuntimeError
|
||||||
|
# (see PEP 479).
|
||||||
|
if type is StopIteration and exc.__cause__ is value:
|
||||||
|
return False
|
||||||
|
raise
|
||||||
|
except:
|
||||||
|
# only re-raise if it's *not* the exception that was
|
||||||
|
# passed to throw(), because __exit__() must not raise
|
||||||
|
# an exception unless __exit__() itself failed. But throw()
|
||||||
|
# has to raise the exception to signal propagation, so this
|
||||||
|
# fixes the impedance mismatch between the throw() protocol
|
||||||
|
# and the __exit__() protocol.
|
||||||
|
#
|
||||||
|
# This cannot use 'except BaseException as exc' (as in the
|
||||||
|
# async implementation) to maintain compatibility with
|
||||||
|
# Python 2, where old-style class exceptions are not caught
|
||||||
|
# by 'except BaseException'.
|
||||||
|
if sys.exc_info()[1] is value:
|
||||||
|
return False
|
||||||
|
raise
|
||||||
|
raise RuntimeError("generator didn't stop after throw()")
|
||||||
|
|
||||||
|
|
||||||
|
class _AsyncGeneratorContextManager(_GeneratorContextManagerBase,
|
||||||
|
AbstractAsyncContextManager):
|
||||||
|
"""Helper for @asynccontextmanager."""
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
try:
|
||||||
|
return await self.gen.__anext__()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
raise RuntimeError("generator didn't yield") from None
|
||||||
|
|
||||||
|
async def __aexit__(self, typ, value, traceback):
|
||||||
|
if typ is None:
|
||||||
|
try:
|
||||||
|
await self.gen.__anext__()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
raise RuntimeError("generator didn't stop")
|
||||||
|
else:
|
||||||
|
if value is None:
|
||||||
|
value = typ()
|
||||||
|
# See _GeneratorContextManager.__exit__ for comments on subtleties
|
||||||
|
# in this implementation
|
||||||
|
try:
|
||||||
|
await self.gen.athrow(typ, value, traceback)
|
||||||
|
raise RuntimeError("generator didn't stop after athrow()")
|
||||||
|
except StopAsyncIteration as exc:
|
||||||
|
return exc is not value
|
||||||
|
except RuntimeError as exc:
|
||||||
|
if exc is value:
|
||||||
|
return False
|
||||||
|
# Avoid suppressing if a StopIteration exception
|
||||||
|
# was passed to throw() and later wrapped into a RuntimeError
|
||||||
|
# (see PEP 479 for sync generators; async generators also
|
||||||
|
# have this behavior). But do this only if the exception wrapped
|
||||||
|
# by the RuntimeError is actully Stop(Async)Iteration (see
|
||||||
|
# issue29692).
|
||||||
|
if isinstance(value, (StopIteration, StopAsyncIteration)):
|
||||||
|
if exc.__cause__ is value:
|
||||||
|
return False
|
||||||
|
raise
|
||||||
|
except BaseException as exc:
|
||||||
|
if exc is not value:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def contextmanager(func):
|
||||||
|
"""@contextmanager decorator.
|
||||||
|
|
||||||
|
Typical usage:
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def some_generator(<arguments>):
|
||||||
|
<setup>
|
||||||
|
try:
|
||||||
|
yield <value>
|
||||||
|
finally:
|
||||||
|
<cleanup>
|
||||||
|
|
||||||
|
This makes this:
|
||||||
|
|
||||||
|
with some_generator(<arguments>) as <variable>:
|
||||||
|
<body>
|
||||||
|
|
||||||
|
equivalent to this:
|
||||||
|
|
||||||
|
<setup>
|
||||||
|
try:
|
||||||
|
<variable> = <value>
|
||||||
|
<body>
|
||||||
|
finally:
|
||||||
|
<cleanup>
|
||||||
|
"""
|
||||||
|
@wraps(func)
|
||||||
|
def helper(*args, **kwds):
|
||||||
|
return _GeneratorContextManager(func, args, kwds)
|
||||||
|
return helper
|
||||||
|
|
||||||
|
|
||||||
|
def asynccontextmanager(func):
|
||||||
|
"""@asynccontextmanager decorator.
|
||||||
|
|
||||||
|
Typical usage:
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def some_async_generator(<arguments>):
|
||||||
|
<setup>
|
||||||
|
try:
|
||||||
|
yield <value>
|
||||||
|
finally:
|
||||||
|
<cleanup>
|
||||||
|
|
||||||
|
This makes this:
|
||||||
|
|
||||||
|
async with some_async_generator(<arguments>) as <variable>:
|
||||||
|
<body>
|
||||||
|
|
||||||
|
equivalent to this:
|
||||||
|
|
||||||
|
<setup>
|
||||||
|
try:
|
||||||
|
<variable> = <value>
|
||||||
|
<body>
|
||||||
|
finally:
|
||||||
|
<cleanup>
|
||||||
|
"""
|
||||||
|
@wraps(func)
|
||||||
|
def helper(*args, **kwds):
|
||||||
|
return _AsyncGeneratorContextManager(func, args, kwds)
|
||||||
|
return helper
|
||||||
|
|
||||||
|
|
||||||
|
class closing(AbstractContextManager):
|
||||||
|
"""Context to automatically close something at the end of a block.
|
||||||
|
|
||||||
|
Code like this:
|
||||||
|
|
||||||
|
with closing(<module>.open(<arguments>)) as f:
|
||||||
|
<block>
|
||||||
|
|
||||||
|
is equivalent to this:
|
||||||
|
|
||||||
|
f = <module>.open(<arguments>)
|
||||||
|
try:
|
||||||
|
<block>
|
||||||
|
finally:
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, thing):
|
||||||
|
self.thing = thing
|
||||||
|
def __enter__(self):
|
||||||
|
return self.thing
|
||||||
|
def __exit__(self, *exc_info):
|
||||||
|
self.thing.close()
|
||||||
|
|
||||||
|
|
||||||
|
class _RedirectStream(AbstractContextManager):
|
||||||
|
|
||||||
|
_stream = None
|
||||||
|
|
||||||
|
def __init__(self, new_target):
|
||||||
|
self._new_target = new_target
|
||||||
|
# We use a list of old targets to make this CM re-entrant
|
||||||
|
self._old_targets = []
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self._old_targets.append(getattr(sys, self._stream))
|
||||||
|
setattr(sys, self._stream, self._new_target)
|
||||||
|
return self._new_target
|
||||||
|
|
||||||
|
def __exit__(self, exctype, excinst, exctb):
|
||||||
|
setattr(sys, self._stream, self._old_targets.pop())
|
||||||
|
|
||||||
|
|
||||||
|
class redirect_stdout(_RedirectStream):
|
||||||
|
"""Context manager for temporarily redirecting stdout to another file.
|
||||||
|
|
||||||
|
# How to send help() to stderr
|
||||||
|
with redirect_stdout(sys.stderr):
|
||||||
|
help(dir)
|
||||||
|
|
||||||
|
# How to write help() to a file
|
||||||
|
with open('help.txt', 'w') as f:
|
||||||
|
with redirect_stdout(f):
|
||||||
|
help(pow)
|
||||||
|
"""
|
||||||
|
|
||||||
|
_stream = "stdout"
|
||||||
|
|
||||||
|
|
||||||
|
class redirect_stderr(_RedirectStream):
|
||||||
|
"""Context manager for temporarily redirecting stderr to another file."""
|
||||||
|
|
||||||
|
_stream = "stderr"
|
||||||
|
|
||||||
|
|
||||||
|
class suppress(AbstractContextManager):
|
||||||
|
"""Context manager to suppress specified exceptions
|
||||||
|
|
||||||
|
After the exception is suppressed, execution proceeds with the next
|
||||||
|
statement following the with statement.
|
||||||
|
|
||||||
|
with suppress(FileNotFoundError):
|
||||||
|
os.remove(somefile)
|
||||||
|
# Execution still resumes here if the file was already removed
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *exceptions):
|
||||||
|
self._exceptions = exceptions
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, exctype, excinst, exctb):
|
||||||
|
# Unlike isinstance and issubclass, CPython exception handling
|
||||||
|
# currently only looks at the concrete type hierarchy (ignoring
|
||||||
|
# the instance and subclass checking hooks). While Guido considers
|
||||||
|
# that a bug rather than a feature, it's a fairly hard one to fix
|
||||||
|
# due to various internal implementation details. suppress provides
|
||||||
|
# the simpler issubclass based semantics, rather than trying to
|
||||||
|
# exactly reproduce the limitations of the CPython interpreter.
|
||||||
|
#
|
||||||
|
# See http://bugs.python.org/issue12029 for more details
|
||||||
|
return exctype is not None and issubclass(exctype, self._exceptions)
|
||||||
|
|
||||||
|
|
||||||
|
class _BaseExitStack:
|
||||||
|
"""A base class for ExitStack and AsyncExitStack."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_exit_wrapper(cm, cm_exit):
|
||||||
|
return MethodType(cm_exit, cm)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_cb_wrapper(callback, *args, **kwds):
|
||||||
|
def _exit_wrapper(exc_type, exc, tb):
|
||||||
|
callback(*args, **kwds)
|
||||||
|
return _exit_wrapper
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._exit_callbacks = deque()
|
||||||
|
|
||||||
|
def pop_all(self):
|
||||||
|
"""Preserve the context stack by transferring it to a new instance."""
|
||||||
|
new_stack = type(self)()
|
||||||
|
new_stack._exit_callbacks = self._exit_callbacks
|
||||||
|
self._exit_callbacks = deque()
|
||||||
|
return new_stack
|
||||||
|
|
||||||
|
def push(self, exit):
|
||||||
|
"""Registers a callback with the standard __exit__ method signature.
|
||||||
|
|
||||||
|
Can suppress exceptions the same way __exit__ method can.
|
||||||
|
Also accepts any object with an __exit__ method (registering a call
|
||||||
|
to the method instead of the object itself).
|
||||||
|
"""
|
||||||
|
# We use an unbound method rather than a bound method to follow
|
||||||
|
# the standard lookup behaviour for special methods.
|
||||||
|
_cb_type = type(exit)
|
||||||
|
|
||||||
|
try:
|
||||||
|
exit_method = _cb_type.__exit__
|
||||||
|
except AttributeError:
|
||||||
|
# Not a context manager, so assume it's a callable.
|
||||||
|
self._push_exit_callback(exit)
|
||||||
|
else:
|
||||||
|
self._push_cm_exit(exit, exit_method)
|
||||||
|
return exit # Allow use as a decorator.
|
||||||
|
|
||||||
|
def enter_context(self, cm):
|
||||||
|
"""Enters the supplied context manager.
|
||||||
|
|
||||||
|
If successful, also pushes its __exit__ method as a callback and
|
||||||
|
returns the result of the __enter__ method.
|
||||||
|
"""
|
||||||
|
# We look up the special methods on the type to match the with
|
||||||
|
# statement.
|
||||||
|
_cm_type = type(cm)
|
||||||
|
_exit = _cm_type.__exit__
|
||||||
|
result = _cm_type.__enter__(cm)
|
||||||
|
self._push_cm_exit(cm, _exit)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def callback(*args, **kwds):
|
||||||
|
"""Registers an arbitrary callback and arguments.
|
||||||
|
|
||||||
|
Cannot suppress exceptions.
|
||||||
|
"""
|
||||||
|
if len(args) >= 2:
|
||||||
|
self, callback, *args = args
|
||||||
|
elif not args:
|
||||||
|
raise TypeError("descriptor 'callback' of '_BaseExitStack' object "
|
||||||
|
"needs an argument")
|
||||||
|
elif 'callback' in kwds:
|
||||||
|
callback = kwds.pop('callback')
|
||||||
|
self, *args = args
|
||||||
|
import warnings
|
||||||
|
warnings.warn("Passing 'callback' as keyword argument is deprecated",
|
||||||
|
DeprecationWarning, stacklevel=2)
|
||||||
|
else:
|
||||||
|
raise TypeError('callback expected at least 1 positional argument, '
|
||||||
|
'got %d' % (len(args)-1))
|
||||||
|
|
||||||
|
_exit_wrapper = self._create_cb_wrapper(callback, *args, **kwds)
|
||||||
|
|
||||||
|
# We changed the signature, so using @wraps is not appropriate, but
|
||||||
|
# setting __wrapped__ may still help with introspection.
|
||||||
|
_exit_wrapper.__wrapped__ = callback
|
||||||
|
self._push_exit_callback(_exit_wrapper)
|
||||||
|
return callback # Allow use as a decorator
|
||||||
|
callback.__text_signature__ = '($self, callback, /, *args, **kwds)'
|
||||||
|
|
||||||
|
def _push_cm_exit(self, cm, cm_exit):
|
||||||
|
"""Helper to correctly register callbacks to __exit__ methods."""
|
||||||
|
_exit_wrapper = self._create_exit_wrapper(cm, cm_exit)
|
||||||
|
self._push_exit_callback(_exit_wrapper, True)
|
||||||
|
|
||||||
|
def _push_exit_callback(self, callback, is_sync=True):
|
||||||
|
self._exit_callbacks.append((is_sync, callback))
|
||||||
|
|
||||||
|
|
||||||
|
# Inspired by discussions on http://bugs.python.org/issue13585
|
||||||
|
class ExitStack(_BaseExitStack, AbstractContextManager):
|
||||||
|
"""Context manager for dynamic management of a stack of exit callbacks.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
with ExitStack() as stack:
|
||||||
|
files = [stack.enter_context(open(fname)) for fname in filenames]
|
||||||
|
# All opened files will automatically be closed at the end of
|
||||||
|
# the with statement, even if attempts to open files later
|
||||||
|
# in the list raise an exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *exc_details):
|
||||||
|
received_exc = exc_details[0] is not None
|
||||||
|
|
||||||
|
# We manipulate the exception state so it behaves as though
|
||||||
|
# we were actually nesting multiple with statements
|
||||||
|
frame_exc = sys.exc_info()[1]
|
||||||
|
def _fix_exception_context(new_exc, old_exc):
|
||||||
|
# Context may not be correct, so find the end of the chain
|
||||||
|
while 1:
|
||||||
|
exc_context = new_exc.__context__
|
||||||
|
if exc_context is old_exc:
|
||||||
|
# Context is already set correctly (see issue 20317)
|
||||||
|
return
|
||||||
|
if exc_context is None or exc_context is frame_exc:
|
||||||
|
break
|
||||||
|
new_exc = exc_context
|
||||||
|
# Change the end of the chain to point to the exception
|
||||||
|
# we expect it to reference
|
||||||
|
new_exc.__context__ = old_exc
|
||||||
|
|
||||||
|
# Callbacks are invoked in LIFO order to match the behaviour of
|
||||||
|
# nested context managers
|
||||||
|
suppressed_exc = False
|
||||||
|
pending_raise = False
|
||||||
|
while self._exit_callbacks:
|
||||||
|
is_sync, cb = self._exit_callbacks.pop()
|
||||||
|
assert is_sync
|
||||||
|
try:
|
||||||
|
if cb(*exc_details):
|
||||||
|
suppressed_exc = True
|
||||||
|
pending_raise = False
|
||||||
|
exc_details = (None, None, None)
|
||||||
|
except:
|
||||||
|
new_exc_details = sys.exc_info()
|
||||||
|
# simulate the stack of exceptions by setting the context
|
||||||
|
_fix_exception_context(new_exc_details[1], exc_details[1])
|
||||||
|
pending_raise = True
|
||||||
|
exc_details = new_exc_details
|
||||||
|
if pending_raise:
|
||||||
|
try:
|
||||||
|
# bare "raise exc_details[1]" replaces our carefully
|
||||||
|
# set-up context
|
||||||
|
fixed_ctx = exc_details[1].__context__
|
||||||
|
raise exc_details[1]
|
||||||
|
except BaseException:
|
||||||
|
exc_details[1].__context__ = fixed_ctx
|
||||||
|
raise
|
||||||
|
return received_exc and suppressed_exc
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Immediately unwind the context stack."""
|
||||||
|
self.__exit__(None, None, None)
|
||||||
|
|
||||||
|
|
||||||
|
# Inspired by discussions on https://bugs.python.org/issue29302
|
||||||
|
class AsyncExitStack(_BaseExitStack, AbstractAsyncContextManager):
|
||||||
|
"""Async context manager for dynamic management of a stack of exit
|
||||||
|
callbacks.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
async with AsyncExitStack() as stack:
|
||||||
|
connections = [await stack.enter_async_context(get_connection())
|
||||||
|
for i in range(5)]
|
||||||
|
# All opened connections will automatically be released at the
|
||||||
|
# end of the async with statement, even if attempts to open a
|
||||||
|
# connection later in the list raise an exception.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_async_exit_wrapper(cm, cm_exit):
|
||||||
|
return MethodType(cm_exit, cm)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_async_cb_wrapper(callback, *args, **kwds):
|
||||||
|
async def _exit_wrapper(exc_type, exc, tb):
|
||||||
|
await callback(*args, **kwds)
|
||||||
|
return _exit_wrapper
|
||||||
|
|
||||||
|
async def enter_async_context(self, cm):
|
||||||
|
"""Enters the supplied async context manager.
|
||||||
|
|
||||||
|
If successful, also pushes its __aexit__ method as a callback and
|
||||||
|
returns the result of the __aenter__ method.
|
||||||
|
"""
|
||||||
|
_cm_type = type(cm)
|
||||||
|
_exit = _cm_type.__aexit__
|
||||||
|
result = await _cm_type.__aenter__(cm)
|
||||||
|
self._push_async_cm_exit(cm, _exit)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def push_async_exit(self, exit):
|
||||||
|
"""Registers a coroutine function with the standard __aexit__ method
|
||||||
|
signature.
|
||||||
|
|
||||||
|
Can suppress exceptions the same way __aexit__ method can.
|
||||||
|
Also accepts any object with an __aexit__ method (registering a call
|
||||||
|
to the method instead of the object itself).
|
||||||
|
"""
|
||||||
|
_cb_type = type(exit)
|
||||||
|
try:
|
||||||
|
exit_method = _cb_type.__aexit__
|
||||||
|
except AttributeError:
|
||||||
|
# Not an async context manager, so assume it's a coroutine function
|
||||||
|
self._push_exit_callback(exit, False)
|
||||||
|
else:
|
||||||
|
self._push_async_cm_exit(exit, exit_method)
|
||||||
|
return exit # Allow use as a decorator
|
||||||
|
|
||||||
|
def push_async_callback(*args, **kwds):
|
||||||
|
"""Registers an arbitrary coroutine function and arguments.
|
||||||
|
|
||||||
|
Cannot suppress exceptions.
|
||||||
|
"""
|
||||||
|
if len(args) >= 2:
|
||||||
|
self, callback, *args = args
|
||||||
|
elif not args:
|
||||||
|
raise TypeError("descriptor 'push_async_callback' of "
|
||||||
|
"'AsyncExitStack' object needs an argument")
|
||||||
|
elif 'callback' in kwds:
|
||||||
|
callback = kwds.pop('callback')
|
||||||
|
self, *args = args
|
||||||
|
import warnings
|
||||||
|
warnings.warn("Passing 'callback' as keyword argument is deprecated",
|
||||||
|
DeprecationWarning, stacklevel=2)
|
||||||
|
else:
|
||||||
|
raise TypeError('push_async_callback expected at least 1 '
|
||||||
|
'positional argument, got %d' % (len(args)-1))
|
||||||
|
|
||||||
|
_exit_wrapper = self._create_async_cb_wrapper(callback, *args, **kwds)
|
||||||
|
|
||||||
|
# We changed the signature, so using @wraps is not appropriate, but
|
||||||
|
# setting __wrapped__ may still help with introspection.
|
||||||
|
_exit_wrapper.__wrapped__ = callback
|
||||||
|
self._push_exit_callback(_exit_wrapper, False)
|
||||||
|
return callback # Allow use as a decorator
|
||||||
|
push_async_callback.__text_signature__ = '($self, callback, /, *args, **kwds)'
|
||||||
|
|
||||||
|
async def aclose(self):
|
||||||
|
"""Immediately unwind the context stack."""
|
||||||
|
await self.__aexit__(None, None, None)
|
||||||
|
|
||||||
|
def _push_async_cm_exit(self, cm, cm_exit):
|
||||||
|
"""Helper to correctly register coroutine function to __aexit__
|
||||||
|
method."""
|
||||||
|
_exit_wrapper = self._create_async_exit_wrapper(cm, cm_exit)
|
||||||
|
self._push_exit_callback(_exit_wrapper, False)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *exc_details):
|
||||||
|
received_exc = exc_details[0] is not None
|
||||||
|
|
||||||
|
# We manipulate the exception state so it behaves as though
|
||||||
|
# we were actually nesting multiple with statements
|
||||||
|
frame_exc = sys.exc_info()[1]
|
||||||
|
def _fix_exception_context(new_exc, old_exc):
|
||||||
|
# Context may not be correct, so find the end of the chain
|
||||||
|
while 1:
|
||||||
|
exc_context = new_exc.__context__
|
||||||
|
if exc_context is old_exc:
|
||||||
|
# Context is already set correctly (see issue 20317)
|
||||||
|
return
|
||||||
|
if exc_context is None or exc_context is frame_exc:
|
||||||
|
break
|
||||||
|
new_exc = exc_context
|
||||||
|
# Change the end of the chain to point to the exception
|
||||||
|
# we expect it to reference
|
||||||
|
new_exc.__context__ = old_exc
|
||||||
|
|
||||||
|
# Callbacks are invoked in LIFO order to match the behaviour of
|
||||||
|
# nested context managers
|
||||||
|
suppressed_exc = False
|
||||||
|
pending_raise = False
|
||||||
|
while self._exit_callbacks:
|
||||||
|
is_sync, cb = self._exit_callbacks.pop()
|
||||||
|
try:
|
||||||
|
if is_sync:
|
||||||
|
cb_suppress = cb(*exc_details)
|
||||||
|
else:
|
||||||
|
cb_suppress = await cb(*exc_details)
|
||||||
|
|
||||||
|
if cb_suppress:
|
||||||
|
suppressed_exc = True
|
||||||
|
pending_raise = False
|
||||||
|
exc_details = (None, None, None)
|
||||||
|
except:
|
||||||
|
new_exc_details = sys.exc_info()
|
||||||
|
# simulate the stack of exceptions by setting the context
|
||||||
|
_fix_exception_context(new_exc_details[1], exc_details[1])
|
||||||
|
pending_raise = True
|
||||||
|
exc_details = new_exc_details
|
||||||
|
if pending_raise:
|
||||||
|
try:
|
||||||
|
# bare "raise exc_details[1]" replaces our carefully
|
||||||
|
# set-up context
|
||||||
|
fixed_ctx = exc_details[1].__context__
|
||||||
|
raise exc_details[1]
|
||||||
|
except BaseException:
|
||||||
|
exc_details[1].__context__ = fixed_ctx
|
||||||
|
raise
|
||||||
|
return received_exc and suppressed_exc
|
||||||
|
|
||||||
|
|
||||||
|
class nullcontext(AbstractContextManager):
|
||||||
|
"""Context manager that does no additional processing.
|
||||||
|
|
||||||
|
Used as a stand-in for a normal context manager, when a particular
|
||||||
|
block of code is only sometimes used with a normal context manager:
|
||||||
|
|
||||||
|
cm = optional_cm if condition else nullcontext()
|
||||||
|
with cm:
|
||||||
|
# Perform operation, using optional_cm if condition is True
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, enter_result=None):
|
||||||
|
self.enter_result = enter_result
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self.enter_result
|
||||||
|
|
||||||
|
def __exit__(self, *excinfo):
|
||||||
|
pass
|
Loading…
Reference in New Issue