Merge pull request #823 from mwhudson/add-api

add api definition and support code for client server split
This commit is contained in:
Michael Hudson-Doyle 2020-09-18 09:53:44 +12:00 committed by GitHub
commit 04e5cf95b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 1784 additions and 1 deletions

View File

@ -50,7 +50,7 @@ lint: flake8
flake8: flake8:
@echo 'tox -e flake8' is preferred to 'make flake8' @echo 'tox -e flake8' is preferred to 'make flake8'
$(PYTHON) -m flake8 $(CHECK_DIRS) --exclude gettext38.py $(PYTHON) -m flake8 $(CHECK_DIRS) --exclude gettext38.py,contextlib38.py
unit: unit:
echo "Running unit tests..." echo "Running unit tests..."

View File

@ -11,4 +11,5 @@ jsonschema
pyudev pyudev
requests requests
requests-unixsocket requests-unixsocket
aiohttp
-e git+https://github.com/CanonicalLtd/probert@b697ab779e7e056301e779f4708a9f1ce51b0027#egg=probert -e git+https://github.com/CanonicalLtd/probert@b697ab779e7e056301e779f4708a9f1ce51b0027#egg=probert

View File

@ -0,0 +1,14 @@
# Copyright 2020 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

View File

@ -0,0 +1,89 @@
# Copyright 2020 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import inspect
import json
import aiohttp
from subiquitycore import contextlib38
from subiquity.common.serialize import Serializer
from .defs import Payload
def _wrap(make_request, path, meth, serializer):
sig = inspect.signature(meth)
meth_params = sig.parameters
payload_arg = None
for name, param in meth_params.items():
if getattr(param.annotation, '__origin__', None) is Payload:
payload_arg = name
payload_ann = param.annotation.__args__[0]
r_ann = sig.return_annotation
async def impl(*args, **kw):
args = sig.bind(*args, **kw)
query_args = {}
data = None
for arg_name, value in args.arguments.items():
if arg_name == payload_arg:
data = serializer.serialize(payload_ann, value)
else:
query_args[arg_name] = json.dumps(
serializer.serialize(
meth_params[arg_name].annotation, value))
async with make_request(
meth.__name__, path, json=data, params=query_args) as resp:
resp.raise_for_status()
return serializer.deserialize(r_ann, await resp.json())
return impl
def make_client(endpoint_cls, make_request, serializer=None):
if serializer is None:
serializer = Serializer()
class C:
pass
for k, v in endpoint_cls.__dict__.items():
if isinstance(v, type):
setattr(C, k, make_client(v, make_request, serializer))
elif callable(v):
setattr(C, k, _wrap(
make_request, endpoint_cls.fullpath, v, serializer))
return C
def make_client_for_conn(
endpoint_cls, conn, resp_hook=lambda r: r, serializer=None):
@contextlib38.asynccontextmanager
async def make_request(method, path, *, params, json):
async with aiohttp.ClientSession(
connector=conn, connector_owner=False) as session:
# session.request needs a full URL with scheme and host
# even though that's in some ways a bit silly with a unix
# socket, so we just hardcode something here (I guess the
# "a" gets sent a long to the server inthe Host: header
# and the server could in principle do something like
# virtual host based selection but well....)
url = 'http://a' + path
async with session.request(
method, url, json=json, params=params,
timeout=0) as response:
yield resp_hook(response)
return make_client(endpoint_cls, make_request, serializer)

View File

@ -0,0 +1,42 @@
# Copyright 2020 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import typing
def api(cls, prefix=(), foo=None):
cls.fullpath = '/' + '/'.join(prefix)
cls.fullname = prefix
for k, v in cls.__dict__.items():
if isinstance(v, type):
v.__name__ = cls.__name__ + '.' + k
api(v, prefix + (k,))
if callable(v):
v.__qualname__ = cls.__name__ + '.' + k
return cls
T = typing.TypeVar("T")
class Payload(typing.Generic[T]):
pass
def simple_endpoint(typ):
class endpoint:
def GET() -> typ: ...
def POST(data: Payload[typ]): ...
return endpoint

View File

@ -0,0 +1,177 @@
# Copyright 2020 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import inspect
import json
from aiohttp import web
from subiquity.common.serialize import Serializer
from .defs import Payload
class BindError(Exception):
pass
class MissingImplementationError(BindError):
def __init__(self, controller, methname):
self.controller = controller
self.methname = methname
def __str__(self):
return f"{self.controller} must have method called {self.methname}"
class SignatureMisatchError(BindError):
def __init__(self, methname, expected, actual):
self.methname = methname
self.expected = expected
self.actual = actual
def __str__(self):
return (f"implementation of {self.methname} has wrong signature, "
f"should be {self.expected} but is {self.actual}")
def trim(text):
if text is None:
return ''
elif len(text) > 80:
return text[:77] + '...'
else:
return text
def _make_handler(controller, definition, implementation, serializer):
def_sig = inspect.signature(definition)
def_ret_ann = def_sig.return_annotation
def_params = def_sig.parameters
impl_sig = inspect.signature(implementation)
impl_params = impl_sig.parameters
data_annotation = None
data_arg = None
query_args_anns = []
check_def_params = []
for param_name, param in def_params.items():
if param_name in ('request', 'context'):
raise Exception(
"api method {} cannot have parameter called request or "
"context".format(definition))
if getattr(param.annotation, '__origin__', None) is Payload:
data_arg = param_name
data_annotation = param.annotation.__args__[0]
check_def_params.append(param.replace(annotation=data_annotation))
else:
query_args_anns.append(
(param_name, param.annotation, param.default))
check_def_params.append(param)
check_impl_params = [
p for p in impl_params.values()
if p.name not in ('context', 'request')
]
check_impl_sig = impl_sig.replace(parameters=check_impl_params)
check_def_sig = def_sig.replace(parameters=check_def_params)
if check_impl_sig != check_def_sig:
raise SignatureMisatchError(
definition.__qualname__, check_def_sig, check_impl_sig)
async def handler(request):
context = controller.context.child(
implementation.__name__, trim(await request.text()))
with context:
context.set('request', request)
args = {}
try:
if data_annotation is not None:
payload = json.loads(await request.text())
args[data_arg] = serializer.deserialize(
data_annotation, payload)
for arg, ann, default in query_args_anns:
if arg in request.query:
v = serializer.deserialize(
ann, json.loads(request.query[arg]))
elif default != inspect._empty:
v = default
else:
raise TypeError(
'missing required argument "{}"'.format(arg))
args[arg] = v
if 'context' in impl_params:
args['context'] = context
if 'request' in impl_params:
args['request'] = request
result = await implementation(**args)
resp = web.json_response(
serializer.serialize(def_ret_ann, result),
headers={'x-status': 'ok'})
except Exception as exc:
resp = web.Response(
status=500,
headers={
'x-status': 'error',
'x-error-type': type(exc).__name__,
'x-error-msg': str(exc),
})
resp['exception'] = exc
context.description = '{} {}'.format(resp.status, trim(resp.text))
return resp
handler.controller = controller
return handler
async def controller_for_request(request):
match_info = await request.app.router.resolve(request)
return getattr(match_info.handler, 'controller', None)
def bind(router, endpoint, controller, serializer=None, _depth=None):
if serializer is None:
serializer = Serializer()
if _depth is None:
_depth = len(endpoint.fullname)
for v in endpoint.__dict__.values():
if isinstance(v, type):
bind(router, v, controller, serializer, _depth)
elif callable(v):
method = v.__name__
impl_name = "_".join(endpoint.fullname[_depth:] + (method,))
if not hasattr(controller, impl_name):
raise MissingImplementationError(controller, impl_name)
impl = getattr(controller, impl_name)
router.add_route(
method=method,
path=endpoint.fullpath,
handler=_make_handler(controller, v, impl, serializer))
async def make_server_at_path(socket_path, endpoint, controller):
app = web.Application()
bind(app.router, endpoint, controller)
runner = web.AppRunner(app)
await runner.setup()
site = web.UnixSite(runner, socket_path)
await site.start()
return site

View File

@ -0,0 +1,14 @@
# Copyright 2020 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

View File

@ -0,0 +1,92 @@
# Copyright 2020 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from subiquitycore import contextlib38
from subiquity.common.api.client import make_client
from subiquity.common.api.defs import api, Payload
def extract(c):
try:
c.__await__().send(None)
except StopIteration as s:
return s.value
else:
raise AssertionError("coroutine not done")
class FakeResponse:
def __init__(self, data):
self.data = data
def raise_for_status(self):
pass
async def json(self):
return self.data
class TestClient(unittest.TestCase):
def test_simple(self):
@api
class API:
class endpoint:
def GET() -> str: ...
def POST(data: Payload[str]) -> None: ...
@contextlib38.asynccontextmanager
async def make_request(method, path, *, params, json):
requests.append((method, path, params, json))
if method == "GET":
v = 'value'
else:
v = None
yield FakeResponse(v)
client = make_client(API, make_request)
requests = []
r = extract(client.endpoint.GET())
self.assertEqual(r, 'value')
self.assertEqual(requests, [("GET", '/endpoint', {}, None)])
requests = []
r = extract(client.endpoint.POST('value'))
self.assertEqual(r, None)
self.assertEqual(
requests, [("POST", '/endpoint', {}, 'value')])
def test_args(self):
@api
class API:
def GET(arg: str) -> str: ...
@contextlib38.asynccontextmanager
async def make_request(method, path, *, params, json):
requests.append((method, path, params, json))
yield FakeResponse(params['arg'])
client = make_client(API, make_request)
requests = []
r = extract(client.GET(arg='v'))
self.assertEqual(r, '"v"')
self.assertEqual(requests, [("GET", '/', {'arg': '"v"'}, None)])

View File

@ -0,0 +1,264 @@
# Copyright 2020 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import attr
import functools
import unittest
import aiohttp
from aiohttp import web
from subiquitycore import contextlib38
from subiquity.common.api.client import make_client
from subiquity.common.api.defs import api, Payload
from .test_server import (
makeTestClient,
run_coro,
TestControllerBase,
)
def make_request(client, method, path, *, params, json):
return client.request(
method, path, params=params, json=json)
@contextlib38.asynccontextmanager
async def makeE2EClient(api, impl,
*, middlewares=(), make_request=make_request):
async with makeTestClient(
api, impl, middlewares=middlewares) as client:
mr = functools.partial(make_request, client)
yield make_client(api, mr)
class TestEndToEnd(unittest.TestCase):
def test_simple(self):
@api
class API:
def GET() -> str: ...
class Impl(TestControllerBase):
async def GET(self) -> str:
return 'value'
async def run():
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(await client.GET(), 'value')
run_coro(run())
def test_nested(self):
@api
class API:
class endpoint:
class nested:
def GET() -> str: ...
class Impl(TestControllerBase):
async def endpoint_nested_GET(self) -> str:
return 'value'
async def run():
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(await client.endpoint.nested.GET(), 'value')
run_coro(run())
def test_args(self):
@api
class API:
def GET(arg1: str, arg2: str) -> str: ...
class Impl(TestControllerBase):
async def GET(self, arg1: str, arg2: str) -> str:
return '{}+{}'.format(arg1, arg2)
async def run():
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(
await client.GET(arg1="A", arg2="B"), 'A+B')
run_coro(run())
def test_defaults(self):
@api
class API:
def GET(arg1: str, arg2: str = "arg2") -> str: ...
class Impl(TestControllerBase):
async def GET(self, arg1: str, arg2: str = "arg2") -> str:
return '{}+{}'.format(arg1, arg2)
async def run():
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(
await client.GET(arg1="A", arg2="B"), 'A+B')
self.assertEqual(
await client.GET(arg1="A"), 'A+arg2')
run_coro(run())
def test_post(self):
@api
class API:
def POST(data: Payload[dict]) -> str: ...
class Impl(TestControllerBase):
async def POST(self, data: dict) -> str:
return data['key']
async def run():
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(
await client.POST({'key': 'value'}), 'value')
run_coro(run())
def test_typed(self):
@attr.s(auto_attribs=True)
class In:
val: int
@attr.s(auto_attribs=True)
class Out:
doubled: int
@api
class API:
class doubler:
def POST(data: In) -> Out: ...
class Impl(TestControllerBase):
async def doubler_POST(self, data: In) -> Out:
return Out(doubled=data.val*2)
async def run():
async with makeE2EClient(API, Impl()) as client:
out = await client.doubler.POST(In(3))
self.assertEqual(out.doubled, 6)
run_coro(run())
def test_middleware(self):
@api
class API:
def GET() -> int: ...
class Impl(TestControllerBase):
async def GET(self) -> int:
return 1/0
@web.middleware
async def middleware(request, handler):
return web.Response(
status=200,
headers={'x-status': 'skip'})
class Skip(Exception):
pass
@contextlib38.asynccontextmanager
async def custom_make_request(client, method, path, *, params, json):
async with make_request(
client, method, path, params=params, json=json) as resp:
if resp.headers.get('x-status') == 'skip':
raise Skip
yield resp
async def run():
async with makeE2EClient(
API, Impl(),
middlewares=[middleware],
make_request=custom_make_request) as client:
with self.assertRaises(Skip):
await client.GET()
run_coro(run())
def test_error(self):
@api
class API:
class good:
def GET(x: int) -> int: ...
class bad:
def GET(x: int) -> int: ...
class Impl(TestControllerBase):
async def good_GET(self, x: int) -> int:
return x + 1
async def bad_GET(self, x: int) -> int:
raise Exception("baz")
async def run():
async with makeE2EClient(API, Impl()) as client:
r = await client.good.GET(2)
self.assertEqual(r, 3)
with self.assertRaises(aiohttp.ClientResponseError):
await client.bad.GET(2)
run_coro(run())
def test_error_middleware(self):
@api
class API:
class good:
def GET(x: int) -> int: ...
class bad:
def GET(x: int) -> int: ...
class Impl(TestControllerBase):
async def good_GET(self, x: int) -> int:
return x + 1
async def bad_GET(self, x: int) -> int:
1/0
@web.middleware
async def middleware(request, handler):
resp = await handler(request)
if resp.get('exception'):
resp.headers['x-status'] = 'ERROR'
return resp
class Abort(Exception):
pass
@contextlib38.asynccontextmanager
async def custom_make_request(client, method, path, *, params, json):
async with make_request(
client, method, path, params=params, json=json) as resp:
if resp.headers.get('x-status') == 'ERROR':
raise Abort
yield resp
async def run():
async with makeE2EClient(
API, Impl(),
middlewares=[middleware],
make_request=custom_make_request) as client:
r = await client.good.GET(2)
self.assertEqual(r, 3)
with self.assertRaises(Abort):
await client.bad.GET(2)
run_coro(run())

View File

@ -0,0 +1,260 @@
# Copyright 2020 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import unittest
from aiohttp.test_utils import TestClient, TestServer
from aiohttp import web
from subiquitycore.context import Context
from subiquitycore import contextlib38
from subiquity.common.api.defs import api, Payload
from subiquity.common.api.server import (
bind,
controller_for_request,
MissingImplementationError,
SignatureMisatchError,
)
def run_coro(coro):
asyncio.get_event_loop().run_until_complete(coro)
class TestApp:
def report_start_event(self, context, description):
pass
def report_finish_event(self, context, description, result):
pass
project = 'test'
class TestControllerBase:
def __init__(self):
self.context = Context.new(TestApp())
@contextlib38.asynccontextmanager
async def makeTestClient(api, impl, middlewares=()):
app = web.Application(middlewares=middlewares)
bind(app.router, api, impl)
async with TestClient(TestServer(app)) as client:
yield client
class TestBind(unittest.TestCase):
async def assertResponse(self, coro, value):
resp = await coro
self.assertEqual(resp.status, 200)
self.assertEqual(await resp.json(), value)
def test_simple(self):
@api
class API:
def GET() -> str: ...
class Impl(TestControllerBase):
async def GET(self) -> str:
return 'value'
async def run():
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.get("/"), 'value')
run_coro(run())
def test_nested(self):
@api
class API:
class endpoint:
class nested:
def get(): ...
class Impl(TestControllerBase):
async def nested_get(self, request, context):
return 'nested'
async def run():
async with makeTestClient(API.endpoint, Impl()) as client:
await self.assertResponse(
client.get("/endpoint/nested"), 'nested')
run_coro(run())
def test_args(self):
@api
class API:
def GET(arg: str): ...
class Impl(TestControllerBase):
async def GET(self, arg: str):
return arg
async def run():
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.get('/?arg="whut"'), 'whut')
run_coro(run())
def test_missing_argument(self):
@api
class API:
def GET(arg: str): ...
class Impl(TestControllerBase):
async def GET(self, arg: str):
return arg
async def run():
async with makeTestClient(API, Impl()) as client:
resp = await client.get('/')
self.assertEqual(resp.status, 500)
self.assertEqual(resp.headers['x-status'], 'error')
self.assertEqual(resp.headers['x-error-type'], 'TypeError')
self.assertEqual(
resp.headers['x-error-msg'],
'missing required argument "arg"')
run_coro(run())
def test_error(self):
@api
class API:
def GET(): ...
class Impl(TestControllerBase):
async def GET(self):
return 1/0
async def run():
async with makeTestClient(API, Impl()) as client:
resp = await client.get('/')
self.assertEqual(resp.status, 500)
self.assertEqual(resp.headers['x-status'], 'error')
self.assertEqual(
resp.headers['x-error-type'], 'ZeroDivisionError')
run_coro(run())
def test_post(self):
@api
class API:
def POST(data: Payload[str]) -> str: ...
class Impl(TestControllerBase):
async def POST(self, data: str) -> str:
return data
async def run():
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.post("/", json='value'),
'value')
run_coro(run())
def test_missing_method(self):
@api
class API:
def GET(arg: str): ...
class Impl(TestControllerBase):
pass
app = web.Application()
with self.assertRaises(MissingImplementationError) as cm:
bind(app.router, API, Impl())
self.assertEqual(cm.exception.methname, "GET")
def test_signature_checking(self):
@api
class API:
def GET(arg: str): ...
class Impl(TestControllerBase):
async def GET(self, arg: int):
return arg
app = web.Application()
with self.assertRaises(SignatureMisatchError) as cm:
bind(app.router, API, Impl())
self.assertEqual(cm.exception.methname, "API.GET")
def test_middleware(self):
@web.middleware
async def middleware(request, handler):
resp = await handler(request)
exc = resp.get('exception')
if resp.get('exception') is not None:
resp.headers['x-error-type'] = type(exc).__name__
return resp
@api
class API:
def GET() -> str: ...
class Impl(TestControllerBase):
async def GET(self) -> str:
return 1/0
async def run():
async with makeTestClient(
API, Impl(), middlewares=[middleware]) as client:
resp = await client.get("/")
self.assertEqual(
resp.headers['x-error-type'], 'ZeroDivisionError')
run_coro(run())
def test_controller_for_request(self):
seen_controller = None
@web.middleware
async def middleware(request, handler):
nonlocal seen_controller
seen_controller = await controller_for_request(request)
return await handler(request)
@api
class API:
class meth:
def GET() -> str: ...
class Impl(TestControllerBase):
async def GET(self) -> str:
return ''
impl = Impl()
async def run():
async with makeTestClient(
API.meth, impl, middlewares=[middleware]) as client:
resp = await client.get("/meth")
self.assertEqual(await resp.json(), '')
run_coro(run())
self.assertIs(impl, seen_controller)

View File

@ -0,0 +1,126 @@
# Copyright 2020 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import datetime
import enum
import inspect
import typing
import attr
# This is basically a half-assed version of # https://pypi.org/project/cattrs/
# but that's not packaged and this is enough for our needs.
class Serializer:
def __init__(self):
self.typing_walkers = {
typing.Union: self._walk_Union,
list: self._walk_List,
typing.List: self._walk_List,
}
self.type_serializers = {}
self.type_deserializers = {}
for typ in int, str, dict, bool, list, type(None):
self.type_serializers[typ] = self._scalar
self.type_deserializers[typ] = self._scalar
self.type_serializers[datetime.datetime] = self._serialize_datetime
self.type_deserializers[datetime.datetime] = self._deserialize_datetime
def _scalar(self, annotation, value, metadata):
assert type(value) is annotation, "{} is not a {}".format(
value, annotation)
return value
def _walk_Union(self, meth, args, value, metadata):
NoneType = type(None)
assert NoneType in args, "can only serialize Optional"
args = [a for a in args if a is not NoneType]
assert len(args) == 1, "can only serialize Optional"
if value is None:
return value
return meth(args[0], value, metadata)
def _walk_List(self, meth, args, value, metadata):
return [meth(args[0], v, metadata) for v in value]
def _serialize_datetime(self, annotation, value, metadata):
assert type(value) is annotation
if metadata is not None and 'time_fmt' in metadata:
return value.strftime(metadata['time_fmt'])
else:
return str(value)
def _serialize_field(self, field, value):
return {field.name: self.serialize(field.type, value, field.metadata)}
def _serialize_attr(self, annotation, value, metadata):
r = {}
for field in attr.fields(annotation):
r.update(self._serialize_field(field, getattr(value, field.name)))
return r
def serialize(self, annotation, value, metadata=None):
if annotation is None:
assert value is None
return None
if annotation is inspect.Signature.empty:
return value
if attr.has(annotation):
return self._serialize_attr(annotation, value, metadata)
origin = getattr(annotation, '__origin__', None)
if origin is not None:
args = annotation.__args__
return self.typing_walkers[origin](
self.serialize, args, value, metadata)
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
return value.name
return self.type_serializers[annotation](annotation, value, metadata)
def _deserialize_datetime(self, annotation, value, metadata):
assert type(value) is str
if metadata is not None and 'time_fmt' in metadata:
return datetime.datetime.strptime(value, metadata['time_fmt'])
else:
1/0
def _deserialize_field(self, field, value):
return {
field.name: self.deserialize(field.type, value, field.metadata)
}
def _deserialize_attr(self, annotation, value, metadata):
args = {}
for field in attr.fields(annotation):
args.update(self._deserialize_field(field, value[field.name]))
return annotation(**args)
def deserialize(self, annotation, value, metadata=None):
if annotation is None:
assert value is None
return None
if annotation is inspect.Signature.empty:
return value
if attr.has(annotation):
return self._deserialize_attr(annotation, value, metadata)
origin = getattr(annotation, '__origin__', None)
if origin is not None:
args = annotation.__args__
return self.typing_walkers[origin](
self.deserialize, args, value, metadata)
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
return getattr(annotation, value)
return self.type_deserializers[annotation](annotation, value, metadata)

View File

@ -0,0 +1,704 @@
"""Utilities for with-statement contexts. See PEP 343."""
import abc
import sys
import _collections_abc
from collections import deque
from functools import wraps
from types import MethodType
__all__ = ["asynccontextmanager", "contextmanager", "closing", "nullcontext",
"AbstractContextManager", "AbstractAsyncContextManager",
"AsyncExitStack", "ContextDecorator", "ExitStack",
"redirect_stdout", "redirect_stderr", "suppress"]
class AbstractContextManager(abc.ABC):
"""An abstract base class for context managers."""
def __enter__(self):
"""Return `self` upon entering the runtime context."""
return self
@abc.abstractmethod
def __exit__(self, exc_type, exc_value, traceback):
"""Raise any exception triggered within the runtime context."""
return None
@classmethod
def __subclasshook__(cls, C):
if cls is AbstractContextManager:
return _collections_abc._check_methods(C, "__enter__", "__exit__")
return NotImplemented
class AbstractAsyncContextManager(abc.ABC):
"""An abstract base class for asynchronous context managers."""
async def __aenter__(self):
"""Return `self` upon entering the runtime context."""
return self
@abc.abstractmethod
async def __aexit__(self, exc_type, exc_value, traceback):
"""Raise any exception triggered within the runtime context."""
return None
@classmethod
def __subclasshook__(cls, C):
if cls is AbstractAsyncContextManager:
return _collections_abc._check_methods(C, "__aenter__",
"__aexit__")
return NotImplemented
class ContextDecorator(object):
"A base class or mixin that enables context managers to work as decorators."
def _recreate_cm(self):
"""Return a recreated instance of self.
Allows an otherwise one-shot context manager like
_GeneratorContextManager to support use as
a decorator via implicit recreation.
This is a private interface just for _GeneratorContextManager.
See issue #11647 for details.
"""
return self
def __call__(self, func):
@wraps(func)
def inner(*args, **kwds):
with self._recreate_cm():
return func(*args, **kwds)
return inner
class _GeneratorContextManagerBase:
"""Shared functionality for @contextmanager and @asynccontextmanager."""
def __init__(self, func, args, kwds):
self.gen = func(*args, **kwds)
self.func, self.args, self.kwds = func, args, kwds
# Issue 19330: ensure context manager instances have good docstrings
doc = getattr(func, "__doc__", None)
if doc is None:
doc = type(self).__doc__
self.__doc__ = doc
# Unfortunately, this still doesn't provide good help output when
# inspecting the created context manager instances, since pydoc
# currently bypasses the instance docstring and shows the docstring
# for the class instead.
# See http://bugs.python.org/issue19404 for more details.
class _GeneratorContextManager(_GeneratorContextManagerBase,
AbstractContextManager,
ContextDecorator):
"""Helper for @contextmanager decorator."""
def _recreate_cm(self):
# _GCM instances are one-shot context managers, so the
# CM must be recreated each time a decorated function is
# called
return self.__class__(self.func, self.args, self.kwds)
def __enter__(self):
# do not keep args and kwds alive unnecessarily
# they are only needed for recreation, which is not possible anymore
del self.args, self.kwds, self.func
try:
return next(self.gen)
except StopIteration:
raise RuntimeError("generator didn't yield") from None
def __exit__(self, type, value, traceback):
if type is None:
try:
next(self.gen)
except StopIteration:
return False
else:
raise RuntimeError("generator didn't stop")
else:
if value is None:
# Need to force instantiation so we can reliably
# tell if we get the same exception back
value = type()
try:
self.gen.throw(type, value, traceback)
except StopIteration as exc:
# Suppress StopIteration *unless* it's the same exception that
# was passed to throw(). This prevents a StopIteration
# raised inside the "with" statement from being suppressed.
return exc is not value
except RuntimeError as exc:
# Don't re-raise the passed in exception. (issue27122)
if exc is value:
return False
# Likewise, avoid suppressing if a StopIteration exception
# was passed to throw() and later wrapped into a RuntimeError
# (see PEP 479).
if type is StopIteration and exc.__cause__ is value:
return False
raise
except:
# only re-raise if it's *not* the exception that was
# passed to throw(), because __exit__() must not raise
# an exception unless __exit__() itself failed. But throw()
# has to raise the exception to signal propagation, so this
# fixes the impedance mismatch between the throw() protocol
# and the __exit__() protocol.
#
# This cannot use 'except BaseException as exc' (as in the
# async implementation) to maintain compatibility with
# Python 2, where old-style class exceptions are not caught
# by 'except BaseException'.
if sys.exc_info()[1] is value:
return False
raise
raise RuntimeError("generator didn't stop after throw()")
class _AsyncGeneratorContextManager(_GeneratorContextManagerBase,
AbstractAsyncContextManager):
"""Helper for @asynccontextmanager."""
async def __aenter__(self):
try:
return await self.gen.__anext__()
except StopAsyncIteration:
raise RuntimeError("generator didn't yield") from None
async def __aexit__(self, typ, value, traceback):
if typ is None:
try:
await self.gen.__anext__()
except StopAsyncIteration:
return
else:
raise RuntimeError("generator didn't stop")
else:
if value is None:
value = typ()
# See _GeneratorContextManager.__exit__ for comments on subtleties
# in this implementation
try:
await self.gen.athrow(typ, value, traceback)
raise RuntimeError("generator didn't stop after athrow()")
except StopAsyncIteration as exc:
return exc is not value
except RuntimeError as exc:
if exc is value:
return False
# Avoid suppressing if a StopIteration exception
# was passed to throw() and later wrapped into a RuntimeError
# (see PEP 479 for sync generators; async generators also
# have this behavior). But do this only if the exception wrapped
# by the RuntimeError is actully Stop(Async)Iteration (see
# issue29692).
if isinstance(value, (StopIteration, StopAsyncIteration)):
if exc.__cause__ is value:
return False
raise
except BaseException as exc:
if exc is not value:
raise
def contextmanager(func):
"""@contextmanager decorator.
Typical usage:
@contextmanager
def some_generator(<arguments>):
<setup>
try:
yield <value>
finally:
<cleanup>
This makes this:
with some_generator(<arguments>) as <variable>:
<body>
equivalent to this:
<setup>
try:
<variable> = <value>
<body>
finally:
<cleanup>
"""
@wraps(func)
def helper(*args, **kwds):
return _GeneratorContextManager(func, args, kwds)
return helper
def asynccontextmanager(func):
"""@asynccontextmanager decorator.
Typical usage:
@asynccontextmanager
async def some_async_generator(<arguments>):
<setup>
try:
yield <value>
finally:
<cleanup>
This makes this:
async with some_async_generator(<arguments>) as <variable>:
<body>
equivalent to this:
<setup>
try:
<variable> = <value>
<body>
finally:
<cleanup>
"""
@wraps(func)
def helper(*args, **kwds):
return _AsyncGeneratorContextManager(func, args, kwds)
return helper
class closing(AbstractContextManager):
"""Context to automatically close something at the end of a block.
Code like this:
with closing(<module>.open(<arguments>)) as f:
<block>
is equivalent to this:
f = <module>.open(<arguments>)
try:
<block>
finally:
f.close()
"""
def __init__(self, thing):
self.thing = thing
def __enter__(self):
return self.thing
def __exit__(self, *exc_info):
self.thing.close()
class _RedirectStream(AbstractContextManager):
_stream = None
def __init__(self, new_target):
self._new_target = new_target
# We use a list of old targets to make this CM re-entrant
self._old_targets = []
def __enter__(self):
self._old_targets.append(getattr(sys, self._stream))
setattr(sys, self._stream, self._new_target)
return self._new_target
def __exit__(self, exctype, excinst, exctb):
setattr(sys, self._stream, self._old_targets.pop())
class redirect_stdout(_RedirectStream):
"""Context manager for temporarily redirecting stdout to another file.
# How to send help() to stderr
with redirect_stdout(sys.stderr):
help(dir)
# How to write help() to a file
with open('help.txt', 'w') as f:
with redirect_stdout(f):
help(pow)
"""
_stream = "stdout"
class redirect_stderr(_RedirectStream):
"""Context manager for temporarily redirecting stderr to another file."""
_stream = "stderr"
class suppress(AbstractContextManager):
"""Context manager to suppress specified exceptions
After the exception is suppressed, execution proceeds with the next
statement following the with statement.
with suppress(FileNotFoundError):
os.remove(somefile)
# Execution still resumes here if the file was already removed
"""
def __init__(self, *exceptions):
self._exceptions = exceptions
def __enter__(self):
pass
def __exit__(self, exctype, excinst, exctb):
# Unlike isinstance and issubclass, CPython exception handling
# currently only looks at the concrete type hierarchy (ignoring
# the instance and subclass checking hooks). While Guido considers
# that a bug rather than a feature, it's a fairly hard one to fix
# due to various internal implementation details. suppress provides
# the simpler issubclass based semantics, rather than trying to
# exactly reproduce the limitations of the CPython interpreter.
#
# See http://bugs.python.org/issue12029 for more details
return exctype is not None and issubclass(exctype, self._exceptions)
class _BaseExitStack:
"""A base class for ExitStack and AsyncExitStack."""
@staticmethod
def _create_exit_wrapper(cm, cm_exit):
return MethodType(cm_exit, cm)
@staticmethod
def _create_cb_wrapper(callback, *args, **kwds):
def _exit_wrapper(exc_type, exc, tb):
callback(*args, **kwds)
return _exit_wrapper
def __init__(self):
self._exit_callbacks = deque()
def pop_all(self):
"""Preserve the context stack by transferring it to a new instance."""
new_stack = type(self)()
new_stack._exit_callbacks = self._exit_callbacks
self._exit_callbacks = deque()
return new_stack
def push(self, exit):
"""Registers a callback with the standard __exit__ method signature.
Can suppress exceptions the same way __exit__ method can.
Also accepts any object with an __exit__ method (registering a call
to the method instead of the object itself).
"""
# We use an unbound method rather than a bound method to follow
# the standard lookup behaviour for special methods.
_cb_type = type(exit)
try:
exit_method = _cb_type.__exit__
except AttributeError:
# Not a context manager, so assume it's a callable.
self._push_exit_callback(exit)
else:
self._push_cm_exit(exit, exit_method)
return exit # Allow use as a decorator.
def enter_context(self, cm):
"""Enters the supplied context manager.
If successful, also pushes its __exit__ method as a callback and
returns the result of the __enter__ method.
"""
# We look up the special methods on the type to match the with
# statement.
_cm_type = type(cm)
_exit = _cm_type.__exit__
result = _cm_type.__enter__(cm)
self._push_cm_exit(cm, _exit)
return result
def callback(*args, **kwds):
"""Registers an arbitrary callback and arguments.
Cannot suppress exceptions.
"""
if len(args) >= 2:
self, callback, *args = args
elif not args:
raise TypeError("descriptor 'callback' of '_BaseExitStack' object "
"needs an argument")
elif 'callback' in kwds:
callback = kwds.pop('callback')
self, *args = args
import warnings
warnings.warn("Passing 'callback' as keyword argument is deprecated",
DeprecationWarning, stacklevel=2)
else:
raise TypeError('callback expected at least 1 positional argument, '
'got %d' % (len(args)-1))
_exit_wrapper = self._create_cb_wrapper(callback, *args, **kwds)
# We changed the signature, so using @wraps is not appropriate, but
# setting __wrapped__ may still help with introspection.
_exit_wrapper.__wrapped__ = callback
self._push_exit_callback(_exit_wrapper)
return callback # Allow use as a decorator
callback.__text_signature__ = '($self, callback, /, *args, **kwds)'
def _push_cm_exit(self, cm, cm_exit):
"""Helper to correctly register callbacks to __exit__ methods."""
_exit_wrapper = self._create_exit_wrapper(cm, cm_exit)
self._push_exit_callback(_exit_wrapper, True)
def _push_exit_callback(self, callback, is_sync=True):
self._exit_callbacks.append((is_sync, callback))
# Inspired by discussions on http://bugs.python.org/issue13585
class ExitStack(_BaseExitStack, AbstractContextManager):
"""Context manager for dynamic management of a stack of exit callbacks.
For example:
with ExitStack() as stack:
files = [stack.enter_context(open(fname)) for fname in filenames]
# All opened files will automatically be closed at the end of
# the with statement, even if attempts to open files later
# in the list raise an exception.
"""
def __enter__(self):
return self
def __exit__(self, *exc_details):
received_exc = exc_details[0] is not None
# We manipulate the exception state so it behaves as though
# we were actually nesting multiple with statements
frame_exc = sys.exc_info()[1]
def _fix_exception_context(new_exc, old_exc):
# Context may not be correct, so find the end of the chain
while 1:
exc_context = new_exc.__context__
if exc_context is old_exc:
# Context is already set correctly (see issue 20317)
return
if exc_context is None or exc_context is frame_exc:
break
new_exc = exc_context
# Change the end of the chain to point to the exception
# we expect it to reference
new_exc.__context__ = old_exc
# Callbacks are invoked in LIFO order to match the behaviour of
# nested context managers
suppressed_exc = False
pending_raise = False
while self._exit_callbacks:
is_sync, cb = self._exit_callbacks.pop()
assert is_sync
try:
if cb(*exc_details):
suppressed_exc = True
pending_raise = False
exc_details = (None, None, None)
except:
new_exc_details = sys.exc_info()
# simulate the stack of exceptions by setting the context
_fix_exception_context(new_exc_details[1], exc_details[1])
pending_raise = True
exc_details = new_exc_details
if pending_raise:
try:
# bare "raise exc_details[1]" replaces our carefully
# set-up context
fixed_ctx = exc_details[1].__context__
raise exc_details[1]
except BaseException:
exc_details[1].__context__ = fixed_ctx
raise
return received_exc and suppressed_exc
def close(self):
"""Immediately unwind the context stack."""
self.__exit__(None, None, None)
# Inspired by discussions on https://bugs.python.org/issue29302
class AsyncExitStack(_BaseExitStack, AbstractAsyncContextManager):
"""Async context manager for dynamic management of a stack of exit
callbacks.
For example:
async with AsyncExitStack() as stack:
connections = [await stack.enter_async_context(get_connection())
for i in range(5)]
# All opened connections will automatically be released at the
# end of the async with statement, even if attempts to open a
# connection later in the list raise an exception.
"""
@staticmethod
def _create_async_exit_wrapper(cm, cm_exit):
return MethodType(cm_exit, cm)
@staticmethod
def _create_async_cb_wrapper(callback, *args, **kwds):
async def _exit_wrapper(exc_type, exc, tb):
await callback(*args, **kwds)
return _exit_wrapper
async def enter_async_context(self, cm):
"""Enters the supplied async context manager.
If successful, also pushes its __aexit__ method as a callback and
returns the result of the __aenter__ method.
"""
_cm_type = type(cm)
_exit = _cm_type.__aexit__
result = await _cm_type.__aenter__(cm)
self._push_async_cm_exit(cm, _exit)
return result
def push_async_exit(self, exit):
"""Registers a coroutine function with the standard __aexit__ method
signature.
Can suppress exceptions the same way __aexit__ method can.
Also accepts any object with an __aexit__ method (registering a call
to the method instead of the object itself).
"""
_cb_type = type(exit)
try:
exit_method = _cb_type.__aexit__
except AttributeError:
# Not an async context manager, so assume it's a coroutine function
self._push_exit_callback(exit, False)
else:
self._push_async_cm_exit(exit, exit_method)
return exit # Allow use as a decorator
def push_async_callback(*args, **kwds):
"""Registers an arbitrary coroutine function and arguments.
Cannot suppress exceptions.
"""
if len(args) >= 2:
self, callback, *args = args
elif not args:
raise TypeError("descriptor 'push_async_callback' of "
"'AsyncExitStack' object needs an argument")
elif 'callback' in kwds:
callback = kwds.pop('callback')
self, *args = args
import warnings
warnings.warn("Passing 'callback' as keyword argument is deprecated",
DeprecationWarning, stacklevel=2)
else:
raise TypeError('push_async_callback expected at least 1 '
'positional argument, got %d' % (len(args)-1))
_exit_wrapper = self._create_async_cb_wrapper(callback, *args, **kwds)
# We changed the signature, so using @wraps is not appropriate, but
# setting __wrapped__ may still help with introspection.
_exit_wrapper.__wrapped__ = callback
self._push_exit_callback(_exit_wrapper, False)
return callback # Allow use as a decorator
push_async_callback.__text_signature__ = '($self, callback, /, *args, **kwds)'
async def aclose(self):
"""Immediately unwind the context stack."""
await self.__aexit__(None, None, None)
def _push_async_cm_exit(self, cm, cm_exit):
"""Helper to correctly register coroutine function to __aexit__
method."""
_exit_wrapper = self._create_async_exit_wrapper(cm, cm_exit)
self._push_exit_callback(_exit_wrapper, False)
async def __aenter__(self):
return self
async def __aexit__(self, *exc_details):
received_exc = exc_details[0] is not None
# We manipulate the exception state so it behaves as though
# we were actually nesting multiple with statements
frame_exc = sys.exc_info()[1]
def _fix_exception_context(new_exc, old_exc):
# Context may not be correct, so find the end of the chain
while 1:
exc_context = new_exc.__context__
if exc_context is old_exc:
# Context is already set correctly (see issue 20317)
return
if exc_context is None or exc_context is frame_exc:
break
new_exc = exc_context
# Change the end of the chain to point to the exception
# we expect it to reference
new_exc.__context__ = old_exc
# Callbacks are invoked in LIFO order to match the behaviour of
# nested context managers
suppressed_exc = False
pending_raise = False
while self._exit_callbacks:
is_sync, cb = self._exit_callbacks.pop()
try:
if is_sync:
cb_suppress = cb(*exc_details)
else:
cb_suppress = await cb(*exc_details)
if cb_suppress:
suppressed_exc = True
pending_raise = False
exc_details = (None, None, None)
except:
new_exc_details = sys.exc_info()
# simulate the stack of exceptions by setting the context
_fix_exception_context(new_exc_details[1], exc_details[1])
pending_raise = True
exc_details = new_exc_details
if pending_raise:
try:
# bare "raise exc_details[1]" replaces our carefully
# set-up context
fixed_ctx = exc_details[1].__context__
raise exc_details[1]
except BaseException:
exc_details[1].__context__ = fixed_ctx
raise
return received_exc and suppressed_exc
class nullcontext(AbstractContextManager):
"""Context manager that does no additional processing.
Used as a stand-in for a normal context manager, when a particular
block of code is only sometimes used with a normal context manager:
cm = optional_cm if condition else nullcontext()
with cm:
# Perform operation, using optional_cm if condition is True
"""
def __init__(self, enter_result=None):
self.enter_result = enter_result
def __enter__(self):
return self.enter_result
def __exit__(self, *excinfo):
pass