Merge pull request #1984 from mwhudson/non-exhaustive-enums
subiquity.common.serialize: add a way to have a "non-exhaustive" enum
This commit is contained in:
commit
771fc1355b
|
@ -43,6 +43,13 @@ class SerializationError(Exception):
|
||||||
return f"processing {self.obj}: at {p}, {self.message}"
|
return f"processing {self.obj}: at {p}, {self.message}"
|
||||||
|
|
||||||
|
|
||||||
|
E = typing.TypeVar("E")
|
||||||
|
|
||||||
|
|
||||||
|
class NonExhaustive(typing.Generic[E]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True)
|
@attr.s(auto_attribs=True)
|
||||||
class SerializationContext:
|
class SerializationContext:
|
||||||
obj: typing.Any
|
obj: typing.Any
|
||||||
|
@ -71,6 +78,8 @@ class SerializationContext:
|
||||||
# This is basically a half-assed version of # https://pypi.org/project/cattrs/
|
# This is basically a half-assed version of # https://pypi.org/project/cattrs/
|
||||||
# but that's not packaged and this is enough for our needs.
|
# but that's not packaged and this is enough for our needs.
|
||||||
|
|
||||||
|
_enum_has_str_values = {}
|
||||||
|
|
||||||
|
|
||||||
class Serializer:
|
class Serializer:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -86,6 +95,7 @@ class Serializer:
|
||||||
typing.List: self._walk_List,
|
typing.List: self._walk_List,
|
||||||
dict: self._walk_Dict,
|
dict: self._walk_Dict,
|
||||||
typing.Dict: self._walk_Dict,
|
typing.Dict: self._walk_Dict,
|
||||||
|
NonExhaustive: self._walk_NonExhaustive,
|
||||||
}
|
}
|
||||||
self.type_serializers = {}
|
self.type_serializers = {}
|
||||||
self.type_deserializers = {}
|
self.type_deserializers = {}
|
||||||
|
@ -97,6 +107,24 @@ class Serializer:
|
||||||
self.type_serializers[datetime.datetime] = self._serialize_datetime
|
self.type_serializers[datetime.datetime] = self._serialize_datetime
|
||||||
self.type_deserializers[datetime.datetime] = self._deserialize_datetime
|
self.type_deserializers[datetime.datetime] = self._deserialize_datetime
|
||||||
|
|
||||||
|
def _ann_ok_as_dict_key(self, annotation):
|
||||||
|
if annotation is str:
|
||||||
|
return True
|
||||||
|
origin = getattr(annotation, "__origin__", None)
|
||||||
|
if origin is NonExhaustive:
|
||||||
|
annotation = annotation.__args__[0]
|
||||||
|
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
|
||||||
|
if self.serialize_enums_by == "name":
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if annotation in _enum_has_str_values:
|
||||||
|
return _enum_has_str_values[annotation]
|
||||||
|
ok = set(type(v.value) for v in annotation) == {str}
|
||||||
|
_enum_has_str_values[annotation] = ok
|
||||||
|
return ok
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
def _scalar(self, annotation, context):
|
def _scalar(self, annotation, context):
|
||||||
context.assert_type(annotation)
|
context.assert_type(annotation)
|
||||||
return context.cur
|
return context.cur
|
||||||
|
@ -139,10 +167,12 @@ class Serializer:
|
||||||
|
|
||||||
def _walk_Dict(self, meth, args, context):
|
def _walk_Dict(self, meth, args, context):
|
||||||
k_ann, v_ann = args
|
k_ann, v_ann = args
|
||||||
if not context.serializing and k_ann is not str:
|
if self._ann_ok_as_dict_key(k_ann):
|
||||||
input_items = context.cur
|
|
||||||
else:
|
|
||||||
input_items = context.cur.items()
|
input_items = context.cur.items()
|
||||||
|
elif context.serializing:
|
||||||
|
input_items = context.cur.items()
|
||||||
|
else:
|
||||||
|
input_items = context.cur
|
||||||
output_items = [
|
output_items = [
|
||||||
[
|
[
|
||||||
meth(k_ann, context.child(f"/{k}", k)),
|
meth(k_ann, context.child(f"/{k}", k)),
|
||||||
|
@ -150,9 +180,25 @@ class Serializer:
|
||||||
]
|
]
|
||||||
for k, v in input_items
|
for k, v in input_items
|
||||||
]
|
]
|
||||||
if context.serializing and k_ann is not str:
|
if self._ann_ok_as_dict_key(k_ann):
|
||||||
|
return dict(output_items)
|
||||||
|
elif context.serializing:
|
||||||
return output_items
|
return output_items
|
||||||
return dict(output_items)
|
else:
|
||||||
|
return dict(output_items)
|
||||||
|
|
||||||
|
def _walk_NonExhaustive(self, meth, args, context):
|
||||||
|
[enum_cls] = args
|
||||||
|
if context.serializing:
|
||||||
|
if isinstance(context.cur, enum_cls):
|
||||||
|
return meth(enum_cls, context)
|
||||||
|
else:
|
||||||
|
return context.cur
|
||||||
|
else:
|
||||||
|
if context.cur in (getattr(m, self.serialize_enums_by) for m in enum_cls):
|
||||||
|
return meth(enum_cls, context)
|
||||||
|
else:
|
||||||
|
return context.cur
|
||||||
|
|
||||||
def _serialize_dict(self, annotation, context):
|
def _serialize_dict(self, annotation, context):
|
||||||
context.assert_type(annotation)
|
context.assert_type(annotation)
|
||||||
|
|
|
@ -23,7 +23,12 @@ import unittest
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from subiquity.common.serialize import SerializationError, Serializer, named_field
|
from subiquity.common.serialize import (
|
||||||
|
NonExhaustive,
|
||||||
|
SerializationError,
|
||||||
|
Serializer,
|
||||||
|
named_field,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True)
|
@attr.s(auto_attribs=True)
|
||||||
|
@ -61,6 +66,10 @@ class MyEnum(enum.Enum):
|
||||||
name = "value"
|
name = "value"
|
||||||
|
|
||||||
|
|
||||||
|
class MyIntEnum(enum.Enum):
|
||||||
|
name = 1
|
||||||
|
|
||||||
|
|
||||||
class CommonSerializerTests:
|
class CommonSerializerTests:
|
||||||
simple_examples = [
|
simple_examples = [
|
||||||
(int, 1),
|
(int, 1),
|
||||||
|
@ -129,12 +138,24 @@ class CommonSerializerTests:
|
||||||
def test_enums(self):
|
def test_enums(self):
|
||||||
self.assertSerialization(MyEnum, MyEnum.name, "name")
|
self.assertSerialization(MyEnum, MyEnum.name, "name")
|
||||||
|
|
||||||
|
def test_non_exhaustive_enums(self):
|
||||||
|
self.serializer = type(self.serializer)(compact=self.serializer.compact)
|
||||||
|
self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "name")
|
||||||
|
self.assertSerialization(NonExhaustive[MyEnum], "name2", "name2")
|
||||||
|
|
||||||
def test_enums_by_value(self):
|
def test_enums_by_value(self):
|
||||||
self.serializer = type(self.serializer)(
|
self.serializer = type(self.serializer)(
|
||||||
compact=self.serializer.compact, serialize_enums_by="value"
|
compact=self.serializer.compact, serialize_enums_by="value"
|
||||||
)
|
)
|
||||||
self.assertSerialization(MyEnum, MyEnum.name, "value")
|
self.assertSerialization(MyEnum, MyEnum.name, "value")
|
||||||
|
|
||||||
|
def test_non_exhaustive_enums_by_value(self):
|
||||||
|
self.serializer = type(self.serializer)(
|
||||||
|
compact=self.serializer.compact, serialize_enums_by="value"
|
||||||
|
)
|
||||||
|
self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "value")
|
||||||
|
self.assertSerialization(NonExhaustive[MyEnum], "value2", "value2")
|
||||||
|
|
||||||
def test_serialize_any(self):
|
def test_serialize_any(self):
|
||||||
o = object()
|
o = object()
|
||||||
self.assertSerialization(typing.Any, o, o)
|
self.assertSerialization(typing.Any, o, o)
|
||||||
|
@ -259,6 +280,29 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
|
||||||
self.serializer.deserialize(Type, {"field-1": 1, "field2": 2})
|
self.serializer.deserialize(Type, {"field-1": 1, "field2": 2})
|
||||||
self.assertEqual(catcher.exception.path, "['field-1']")
|
self.assertEqual(catcher.exception.path, "['field-1']")
|
||||||
|
|
||||||
|
def test_serialize_dict_enumkeys_name(self):
|
||||||
|
self.assertSerialization(
|
||||||
|
typing.Dict[MyEnum, str], {MyEnum.name: "b"}, {"name": "b"}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_serialize_dict_enumkeys_str_value(self):
|
||||||
|
self.serializer = type(self.serializer)(
|
||||||
|
compact=self.serializer.compact, serialize_enums_by="value"
|
||||||
|
)
|
||||||
|
self.assertSerialization(
|
||||||
|
typing.Dict[MyEnum, str], {MyEnum.name: "b"}, {"value": "b"}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_serialize_dict_enumkeys_notstr_value(self):
|
||||||
|
self.serializer = type(self.serializer)(
|
||||||
|
compact=self.serializer.compact, serialize_enums_by="value"
|
||||||
|
)
|
||||||
|
self.assertSerialization(
|
||||||
|
typing.Dict[MyIntEnum, str],
|
||||||
|
{MyIntEnum.name: "b"},
|
||||||
|
[[1, "b"]],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestCompactSerializer(CommonSerializerTests, unittest.TestCase):
|
class TestCompactSerializer(CommonSerializerTests, unittest.TestCase):
|
||||||
serializer = Serializer(compact=True)
|
serializer = Serializer(compact=True)
|
||||||
|
|
|
@ -275,7 +275,7 @@ class FilesystemController(SubiquityController, FilesystemManipulator):
|
||||||
self._on_volume: Optional[snapdapi.OnVolume] = None
|
self._on_volume: Optional[snapdapi.OnVolume] = None
|
||||||
self._source_handler: Optional[AbstractSourceHandler] = None
|
self._source_handler: Optional[AbstractSourceHandler] = None
|
||||||
self._system_mounter: Optional[Mounter] = None
|
self._system_mounter: Optional[Mounter] = None
|
||||||
self._role_to_device: Dict[str, _Device] = {}
|
self._role_to_device: Dict[Union[str, snapdapi.Role], _Device] = {}
|
||||||
self._device_to_structure: Dict[_Device, snapdapi.OnVolume] = {}
|
self._device_to_structure: Dict[_Device, snapdapi.OnVolume] = {}
|
||||||
self._pyudev_context: Optional[pyudev.Context] = None
|
self._pyudev_context: Optional[pyudev.Context] = None
|
||||||
self.use_tpm: bool = False
|
self.use_tpm: bool = False
|
||||||
|
@ -949,9 +949,9 @@ class FilesystemController(SubiquityController, FilesystemManipulator):
|
||||||
step=snapdapi.SystemActionStep.SETUP_STORAGE_ENCRYPTION,
|
step=snapdapi.SystemActionStep.SETUP_STORAGE_ENCRYPTION,
|
||||||
on_volumes=self._on_volumes(),
|
on_volumes=self._on_volumes(),
|
||||||
),
|
),
|
||||||
|
ann=snapdapi.SystemActionResponse,
|
||||||
)
|
)
|
||||||
role_to_encrypted_device = result["encrypted-devices"]
|
for role, enc_path in result.encrypted_devices.items():
|
||||||
for role, enc_path in role_to_encrypted_device.items():
|
|
||||||
arb_device = ArbitraryDevice(m=self.model, path=enc_path)
|
arb_device = ArbitraryDevice(m=self.model, path=enc_path)
|
||||||
self.model._actions.append(arb_device)
|
self.model._actions.append(arb_device)
|
||||||
part = self._role_to_device[role]
|
part = self._role_to_device[role]
|
||||||
|
|
|
@ -1537,11 +1537,11 @@ class TestCoreBootInstallMethods(IsolatedAsyncioTestCase):
|
||||||
with mock.patch.object(
|
with mock.patch.object(
|
||||||
snapdapi, "post_and_wait", new_callable=mock.AsyncMock
|
snapdapi, "post_and_wait", new_callable=mock.AsyncMock
|
||||||
) as mocked:
|
) as mocked:
|
||||||
mocked.return_value = {
|
mocked.return_value = snapdapi.SystemActionResponse(
|
||||||
"encrypted-devices": {
|
encrypted_devices={
|
||||||
snapdapi.Role.SYSTEM_DATA: "enc-system-data",
|
snapdapi.Role.SYSTEM_DATA: "enc-system-data",
|
||||||
},
|
},
|
||||||
}
|
)
|
||||||
await self.fsc.setup_encryption(context=self.fsc.context)
|
await self.fsc.setup_encryption(context=self.fsc.context)
|
||||||
|
|
||||||
# setup_encryption mutates the filesystem model objects to
|
# setup_encryption mutates the filesystem model objects to
|
||||||
|
|
|
@ -24,7 +24,7 @@ import attr
|
||||||
|
|
||||||
from subiquity.common.api.client import make_client
|
from subiquity.common.api.client import make_client
|
||||||
from subiquity.common.api.defs import Payload, api, path_parameter
|
from subiquity.common.api.defs import Payload, api, path_parameter
|
||||||
from subiquity.common.serialize import Serializer, named_field
|
from subiquity.common.serialize import NonExhaustive, Serializer, named_field
|
||||||
from subiquity.common.types import Change, TaskStatus
|
from subiquity.common.types import Change, TaskStatus
|
||||||
|
|
||||||
log = logging.getLogger("subiquity.server.snapdapi")
|
log = logging.getLogger("subiquity.server.snapdapi")
|
||||||
|
@ -90,17 +90,11 @@ class Response:
|
||||||
status: str
|
status: str
|
||||||
|
|
||||||
|
|
||||||
class Role:
|
class Role(enum.Enum):
|
||||||
NONE = ""
|
NONE = ""
|
||||||
MBR = "mbr"
|
MBR = "mbr"
|
||||||
SYSTEM_BOOT = "system-boot"
|
SYSTEM_BOOT = "system-boot"
|
||||||
SYSTEM_BOOT_IMAGE = "system-boot-image"
|
|
||||||
SYSTEM_BOOT_SELECT = "system-boot-select"
|
|
||||||
SYSTEM_DATA = "system-data"
|
SYSTEM_DATA = "system-data"
|
||||||
SYSTEM_RECOVERY_SELECT = "system-recovery-select"
|
|
||||||
SYSTEM_SAVE = "system-save"
|
|
||||||
SYSTEM_SEED = "system-seed"
|
|
||||||
SYSTEM_SEED_NULL = "system-seed-null"
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True)
|
@attr.s(auto_attribs=True)
|
||||||
|
@ -134,7 +128,7 @@ class VolumeStructure:
|
||||||
offset_write: Optional[RelativeOffset] = named_field("offset-write", None)
|
offset_write: Optional[RelativeOffset] = named_field("offset-write", None)
|
||||||
size: int = 0
|
size: int = 0
|
||||||
type: str = ""
|
type: str = ""
|
||||||
role: str = Role.NONE
|
role: NonExhaustive[Role] = Role.NONE
|
||||||
id: Optional[str] = None
|
id: Optional[str] = None
|
||||||
filesystem: str = ""
|
filesystem: str = ""
|
||||||
content: Optional[List[VolumeContent]] = None
|
content: Optional[List[VolumeContent]] = None
|
||||||
|
@ -232,6 +226,13 @@ class SystemActionRequest:
|
||||||
on_volumes: Dict[str, OnVolume] = named_field("on-volumes")
|
on_volumes: Dict[str, OnVolume] = named_field("on-volumes")
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True)
|
||||||
|
class SystemActionResponse:
|
||||||
|
encrypted_devices: Dict[NonExhaustive[Role], str] = named_field(
|
||||||
|
"encrypted-devices", default=attr.Factory(dict)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@api
|
@api
|
||||||
class SnapdAPI:
|
class SnapdAPI:
|
||||||
serialize_query_args = False
|
serialize_query_args = False
|
||||||
|
@ -313,14 +314,17 @@ def make_api_client(async_snapd):
|
||||||
snapd_serializer = Serializer(ignore_unknown_fields=True, serialize_enums_by="value")
|
snapd_serializer = Serializer(ignore_unknown_fields=True, serialize_enums_by="value")
|
||||||
|
|
||||||
|
|
||||||
async def post_and_wait(client, meth, *args, **kw):
|
async def post_and_wait(client, meth, *args, ann=None, **kw):
|
||||||
change_id = await meth(*args, **kw)
|
change_id = await meth(*args, **kw)
|
||||||
log.debug("post_and_wait %s", change_id)
|
log.debug("post_and_wait %s", change_id)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
result = await client.v2.changes[change_id].GET()
|
result = await client.v2.changes[change_id].GET()
|
||||||
if result.status == TaskStatus.DONE:
|
if result.status == TaskStatus.DONE:
|
||||||
return result.data
|
data = result.data
|
||||||
|
if ann is not None:
|
||||||
|
data = snapd_serializer.deserialize(ann, data)
|
||||||
|
return data
|
||||||
elif result.status == TaskStatus.ERROR:
|
elif result.status == TaskStatus.ERROR:
|
||||||
raise aiohttp.ClientError(result.err)
|
raise aiohttp.ClientError(result.err)
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
Loading…
Reference in New Issue