diff --git a/subiquity/common/serialize.py b/subiquity/common/serialize.py index 483b0989..9bf5319d 100644 --- a/subiquity/common/serialize.py +++ b/subiquity/common/serialize.py @@ -43,6 +43,13 @@ class SerializationError(Exception): return f"processing {self.obj}: at {p}, {self.message}" +E = typing.TypeVar("E") + + +class NonExhaustive(typing.Generic[E]): + pass + + @attr.s(auto_attribs=True) class SerializationContext: obj: typing.Any @@ -71,6 +78,8 @@ class SerializationContext: # This is basically a half-assed version of # https://pypi.org/project/cattrs/ # but that's not packaged and this is enough for our needs. +_enum_has_str_values = {} + class Serializer: def __init__( @@ -86,6 +95,7 @@ class Serializer: typing.List: self._walk_List, dict: self._walk_Dict, typing.Dict: self._walk_Dict, + NonExhaustive: self._walk_NonExhaustive, } self.type_serializers = {} self.type_deserializers = {} @@ -97,6 +107,24 @@ class Serializer: self.type_serializers[datetime.datetime] = self._serialize_datetime self.type_deserializers[datetime.datetime] = self._deserialize_datetime + def _ann_ok_as_dict_key(self, annotation): + if annotation is str: + return True + origin = getattr(annotation, "__origin__", None) + if origin is NonExhaustive: + annotation = annotation.__args__[0] + if isinstance(annotation, type) and issubclass(annotation, enum.Enum): + if self.serialize_enums_by == "name": + return True + else: + if annotation in _enum_has_str_values: + return _enum_has_str_values[annotation] + ok = set(type(v.value) for v in annotation) == {str} + _enum_has_str_values[annotation] = ok + return ok + else: + return False + def _scalar(self, annotation, context): context.assert_type(annotation) return context.cur @@ -139,10 +167,12 @@ class Serializer: def _walk_Dict(self, meth, args, context): k_ann, v_ann = args - if not context.serializing and k_ann is not str: - input_items = context.cur - else: + if self._ann_ok_as_dict_key(k_ann): input_items = context.cur.items() + elif context.serializing: + input_items = context.cur.items() + else: + input_items = context.cur output_items = [ [ meth(k_ann, context.child(f"/{k}", k)), @@ -150,9 +180,25 @@ class Serializer: ] for k, v in input_items ] - if context.serializing and k_ann is not str: + if self._ann_ok_as_dict_key(k_ann): + return dict(output_items) + elif context.serializing: return output_items - return dict(output_items) + else: + return dict(output_items) + + def _walk_NonExhaustive(self, meth, args, context): + [enum_cls] = args + if context.serializing: + if isinstance(context.cur, enum_cls): + return meth(enum_cls, context) + else: + return context.cur + else: + if context.cur in (getattr(m, self.serialize_enums_by) for m in enum_cls): + return meth(enum_cls, context) + else: + return context.cur def _serialize_dict(self, annotation, context): context.assert_type(annotation) diff --git a/subiquity/common/tests/test_serialization.py b/subiquity/common/tests/test_serialization.py index 3ada5e65..8cb6513a 100644 --- a/subiquity/common/tests/test_serialization.py +++ b/subiquity/common/tests/test_serialization.py @@ -23,7 +23,12 @@ import unittest import attr -from subiquity.common.serialize import SerializationError, Serializer, named_field +from subiquity.common.serialize import ( + NonExhaustive, + SerializationError, + Serializer, + named_field, +) @attr.s(auto_attribs=True) @@ -61,6 +66,10 @@ class MyEnum(enum.Enum): name = "value" +class MyIntEnum(enum.Enum): + name = 1 + + class CommonSerializerTests: simple_examples = [ (int, 1), @@ -129,12 +138,24 @@ class CommonSerializerTests: def test_enums(self): self.assertSerialization(MyEnum, MyEnum.name, "name") + def test_non_exhaustive_enums(self): + self.serializer = type(self.serializer)(compact=self.serializer.compact) + self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "name") + self.assertSerialization(NonExhaustive[MyEnum], "name2", "name2") + def test_enums_by_value(self): self.serializer = type(self.serializer)( compact=self.serializer.compact, serialize_enums_by="value" ) self.assertSerialization(MyEnum, MyEnum.name, "value") + def test_non_exhaustive_enums_by_value(self): + self.serializer = type(self.serializer)( + compact=self.serializer.compact, serialize_enums_by="value" + ) + self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "value") + self.assertSerialization(NonExhaustive[MyEnum], "value2", "value2") + def test_serialize_any(self): o = object() self.assertSerialization(typing.Any, o, o) @@ -259,6 +280,29 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase): self.serializer.deserialize(Type, {"field-1": 1, "field2": 2}) self.assertEqual(catcher.exception.path, "['field-1']") + def test_serialize_dict_enumkeys_name(self): + self.assertSerialization( + typing.Dict[MyEnum, str], {MyEnum.name: "b"}, {"name": "b"} + ) + + def test_serialize_dict_enumkeys_str_value(self): + self.serializer = type(self.serializer)( + compact=self.serializer.compact, serialize_enums_by="value" + ) + self.assertSerialization( + typing.Dict[MyEnum, str], {MyEnum.name: "b"}, {"value": "b"} + ) + + def test_serialize_dict_enumkeys_notstr_value(self): + self.serializer = type(self.serializer)( + compact=self.serializer.compact, serialize_enums_by="value" + ) + self.assertSerialization( + typing.Dict[MyIntEnum, str], + {MyIntEnum.name: "b"}, + [[1, "b"]], + ) + class TestCompactSerializer(CommonSerializerTests, unittest.TestCase): serializer = Serializer(compact=True) diff --git a/subiquity/server/controllers/filesystem.py b/subiquity/server/controllers/filesystem.py index f210267c..fbdaeebe 100644 --- a/subiquity/server/controllers/filesystem.py +++ b/subiquity/server/controllers/filesystem.py @@ -275,7 +275,7 @@ class FilesystemController(SubiquityController, FilesystemManipulator): self._on_volume: Optional[snapdapi.OnVolume] = None self._source_handler: Optional[AbstractSourceHandler] = None self._system_mounter: Optional[Mounter] = None - self._role_to_device: Dict[str, _Device] = {} + self._role_to_device: Dict[Union[str, snapdapi.Role], _Device] = {} self._device_to_structure: Dict[_Device, snapdapi.OnVolume] = {} self._pyudev_context: Optional[pyudev.Context] = None self.use_tpm: bool = False @@ -949,9 +949,9 @@ class FilesystemController(SubiquityController, FilesystemManipulator): step=snapdapi.SystemActionStep.SETUP_STORAGE_ENCRYPTION, on_volumes=self._on_volumes(), ), + ann=snapdapi.SystemActionResponse, ) - role_to_encrypted_device = result["encrypted-devices"] - for role, enc_path in role_to_encrypted_device.items(): + for role, enc_path in result.encrypted_devices.items(): arb_device = ArbitraryDevice(m=self.model, path=enc_path) self.model._actions.append(arb_device) part = self._role_to_device[role] diff --git a/subiquity/server/controllers/tests/test_filesystem.py b/subiquity/server/controllers/tests/test_filesystem.py index ee2f4d85..3663f608 100644 --- a/subiquity/server/controllers/tests/test_filesystem.py +++ b/subiquity/server/controllers/tests/test_filesystem.py @@ -1537,11 +1537,11 @@ class TestCoreBootInstallMethods(IsolatedAsyncioTestCase): with mock.patch.object( snapdapi, "post_and_wait", new_callable=mock.AsyncMock ) as mocked: - mocked.return_value = { - "encrypted-devices": { + mocked.return_value = snapdapi.SystemActionResponse( + encrypted_devices={ snapdapi.Role.SYSTEM_DATA: "enc-system-data", }, - } + ) await self.fsc.setup_encryption(context=self.fsc.context) # setup_encryption mutates the filesystem model objects to diff --git a/subiquity/server/snapdapi.py b/subiquity/server/snapdapi.py index 4173c5aa..5ae38382 100644 --- a/subiquity/server/snapdapi.py +++ b/subiquity/server/snapdapi.py @@ -24,7 +24,7 @@ import attr from subiquity.common.api.client import make_client from subiquity.common.api.defs import Payload, api, path_parameter -from subiquity.common.serialize import Serializer, named_field +from subiquity.common.serialize import NonExhaustive, Serializer, named_field from subiquity.common.types import Change, TaskStatus log = logging.getLogger("subiquity.server.snapdapi") @@ -90,17 +90,11 @@ class Response: status: str -class Role: +class Role(enum.Enum): NONE = "" MBR = "mbr" SYSTEM_BOOT = "system-boot" - SYSTEM_BOOT_IMAGE = "system-boot-image" - SYSTEM_BOOT_SELECT = "system-boot-select" SYSTEM_DATA = "system-data" - SYSTEM_RECOVERY_SELECT = "system-recovery-select" - SYSTEM_SAVE = "system-save" - SYSTEM_SEED = "system-seed" - SYSTEM_SEED_NULL = "system-seed-null" @attr.s(auto_attribs=True) @@ -134,7 +128,7 @@ class VolumeStructure: offset_write: Optional[RelativeOffset] = named_field("offset-write", None) size: int = 0 type: str = "" - role: str = Role.NONE + role: NonExhaustive[Role] = Role.NONE id: Optional[str] = None filesystem: str = "" content: Optional[List[VolumeContent]] = None @@ -232,6 +226,13 @@ class SystemActionRequest: on_volumes: Dict[str, OnVolume] = named_field("on-volumes") +@attr.s(auto_attribs=True) +class SystemActionResponse: + encrypted_devices: Dict[NonExhaustive[Role], str] = named_field( + "encrypted-devices", default=attr.Factory(dict) + ) + + @api class SnapdAPI: serialize_query_args = False @@ -313,14 +314,17 @@ def make_api_client(async_snapd): snapd_serializer = Serializer(ignore_unknown_fields=True, serialize_enums_by="value") -async def post_and_wait(client, meth, *args, **kw): +async def post_and_wait(client, meth, *args, ann=None, **kw): change_id = await meth(*args, **kw) log.debug("post_and_wait %s", change_id) while True: result = await client.v2.changes[change_id].GET() if result.status == TaskStatus.DONE: - return result.data + data = result.data + if ann is not None: + data = snapd_serializer.deserialize(ann, data) + return data elif result.status == TaskStatus.ERROR: raise aiohttp.ClientError(result.err) await asyncio.sleep(0.1)