From a59ef9fba0bcef6c367ab6a7b398a66cccc25fa0 Mon Sep 17 00:00:00 2001 From: Michael Hudson-Doyle Date: Fri, 3 May 2024 17:19:36 +1200 Subject: [PATCH 1/4] serialize a dict with enum keys that serialize as strings as an object Rather than as a list of [key, value] pairs. --- subiquity/common/serialize.py | 32 +++++++++++++++++--- subiquity/common/tests/test_serialization.py | 27 +++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/subiquity/common/serialize.py b/subiquity/common/serialize.py index 483b0989..47cac2ba 100644 --- a/subiquity/common/serialize.py +++ b/subiquity/common/serialize.py @@ -71,6 +71,8 @@ class SerializationContext: # This is basically a half-assed version of # https://pypi.org/project/cattrs/ # but that's not packaged and this is enough for our needs. +_enum_has_str_values = {} + class Serializer: def __init__( @@ -97,6 +99,21 @@ class Serializer: self.type_serializers[datetime.datetime] = self._serialize_datetime self.type_deserializers[datetime.datetime] = self._deserialize_datetime + def _ann_ok_as_dict_key(self, annotation): + if annotation is str: + return True + if isinstance(annotation, type) and issubclass(annotation, enum.Enum): + if self.serialize_enums_by == "name": + return True + else: + if annotation in _enum_has_str_values: + return _enum_has_str_values[annotation] + ok = set(type(v.value) for v in annotation) == {str} + _enum_has_str_values[annotation] = ok + return ok + else: + return False + def _scalar(self, annotation, context): context.assert_type(annotation) return context.cur @@ -139,10 +156,12 @@ class Serializer: def _walk_Dict(self, meth, args, context): k_ann, v_ann = args - if not context.serializing and k_ann is not str: - input_items = context.cur - else: + if self._ann_ok_as_dict_key(k_ann): input_items = context.cur.items() + elif context.serializing: + input_items = context.cur.items() + else: + input_items = context.cur output_items = [ [ meth(k_ann, context.child(f"/{k}", k)), @@ -150,9 +169,12 @@ class Serializer: ] for k, v in input_items ] - if context.serializing and k_ann is not str: + if self._ann_ok_as_dict_key(k_ann): + return dict(output_items) + elif context.serializing: return output_items - return dict(output_items) + else: + return dict(output_items) def _serialize_dict(self, annotation, context): context.assert_type(annotation) diff --git a/subiquity/common/tests/test_serialization.py b/subiquity/common/tests/test_serialization.py index 3ada5e65..b08f067e 100644 --- a/subiquity/common/tests/test_serialization.py +++ b/subiquity/common/tests/test_serialization.py @@ -61,6 +61,10 @@ class MyEnum(enum.Enum): name = "value" +class MyIntEnum(enum.Enum): + name = 1 + + class CommonSerializerTests: simple_examples = [ (int, 1), @@ -259,6 +263,29 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase): self.serializer.deserialize(Type, {"field-1": 1, "field2": 2}) self.assertEqual(catcher.exception.path, "['field-1']") + def test_serialize_dict_enumkeys_name(self): + self.assertSerialization( + typing.Dict[MyEnum, str], {MyEnum.name: "b"}, {"name": "b"} + ) + + def test_serialize_dict_enumkeys_str_value(self): + self.serializer = type(self.serializer)( + compact=self.serializer.compact, serialize_enums_by="value" + ) + self.assertSerialization( + typing.Dict[MyEnum, str], {MyEnum.name: "b"}, {"value": "b"} + ) + + def test_serialize_dict_enumkeys_notstr_value(self): + self.serializer = type(self.serializer)( + compact=self.serializer.compact, serialize_enums_by="value" + ) + self.assertSerialization( + typing.Dict[MyIntEnum, str], + {MyIntEnum.name: "b"}, + [[1, "b"]], + ) + class TestCompactSerializer(CommonSerializerTests, unittest.TestCase): serializer = Serializer(compact=True) From b53c1fffb781af25fa6d702c95a3fc526121da49 Mon Sep 17 00:00:00 2001 From: Michael Hudson-Doyle Date: Fri, 26 Apr 2024 15:48:14 +1200 Subject: [PATCH 2/4] subiquity.common.serialize: add a way to have a "non-exhaustive" enum We use the subiquity.common.api stuff to talk to snapd, and use enumerations to describe several field names. This is a bit of a landmine in some situations because snapd can add a value to the set of possible return values and thus cause serialization to fail. Rather than just giving up on all of the typo-resistance of declaring enums for API types, this adds a way to mark an enumeration as "non-exhaustive": values that are part of the declared enum will be deserialized as values of the enum but values that are not will be passed straight through (as strings, in all the cases used in snapdapi.py), which conveniently will never compare equal to an enumeration. This should let us just declare the values of the enumerations we actually care about and not break if we don't declare every value snapd actually uses. --- subiquity/common/serialize.py | 24 ++++++++++++++++++++ subiquity/common/tests/test_serialization.py | 19 +++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/subiquity/common/serialize.py b/subiquity/common/serialize.py index 47cac2ba..9bf5319d 100644 --- a/subiquity/common/serialize.py +++ b/subiquity/common/serialize.py @@ -43,6 +43,13 @@ class SerializationError(Exception): return f"processing {self.obj}: at {p}, {self.message}" +E = typing.TypeVar("E") + + +class NonExhaustive(typing.Generic[E]): + pass + + @attr.s(auto_attribs=True) class SerializationContext: obj: typing.Any @@ -88,6 +95,7 @@ class Serializer: typing.List: self._walk_List, dict: self._walk_Dict, typing.Dict: self._walk_Dict, + NonExhaustive: self._walk_NonExhaustive, } self.type_serializers = {} self.type_deserializers = {} @@ -102,6 +110,9 @@ class Serializer: def _ann_ok_as_dict_key(self, annotation): if annotation is str: return True + origin = getattr(annotation, "__origin__", None) + if origin is NonExhaustive: + annotation = annotation.__args__[0] if isinstance(annotation, type) and issubclass(annotation, enum.Enum): if self.serialize_enums_by == "name": return True @@ -176,6 +187,19 @@ class Serializer: else: return dict(output_items) + def _walk_NonExhaustive(self, meth, args, context): + [enum_cls] = args + if context.serializing: + if isinstance(context.cur, enum_cls): + return meth(enum_cls, context) + else: + return context.cur + else: + if context.cur in (getattr(m, self.serialize_enums_by) for m in enum_cls): + return meth(enum_cls, context) + else: + return context.cur + def _serialize_dict(self, annotation, context): context.assert_type(annotation) for k in context.cur: diff --git a/subiquity/common/tests/test_serialization.py b/subiquity/common/tests/test_serialization.py index b08f067e..8cb6513a 100644 --- a/subiquity/common/tests/test_serialization.py +++ b/subiquity/common/tests/test_serialization.py @@ -23,7 +23,12 @@ import unittest import attr -from subiquity.common.serialize import SerializationError, Serializer, named_field +from subiquity.common.serialize import ( + NonExhaustive, + SerializationError, + Serializer, + named_field, +) @attr.s(auto_attribs=True) @@ -133,12 +138,24 @@ class CommonSerializerTests: def test_enums(self): self.assertSerialization(MyEnum, MyEnum.name, "name") + def test_non_exhaustive_enums(self): + self.serializer = type(self.serializer)(compact=self.serializer.compact) + self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "name") + self.assertSerialization(NonExhaustive[MyEnum], "name2", "name2") + def test_enums_by_value(self): self.serializer = type(self.serializer)( compact=self.serializer.compact, serialize_enums_by="value" ) self.assertSerialization(MyEnum, MyEnum.name, "value") + def test_non_exhaustive_enums_by_value(self): + self.serializer = type(self.serializer)( + compact=self.serializer.compact, serialize_enums_by="value" + ) + self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "value") + self.assertSerialization(NonExhaustive[MyEnum], "value2", "value2") + def test_serialize_any(self): o = object() self.assertSerialization(typing.Any, o, o) From 5605942e709a9e6d4773c3f0fe6ef2e55d78753d Mon Sep 17 00:00:00 2001 From: Michael Hudson-Doyle Date: Fri, 3 May 2024 17:23:53 +1200 Subject: [PATCH 3/4] Add a way to annotate the result of a asynchronous snapd operation --- subiquity/server/snapdapi.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/subiquity/server/snapdapi.py b/subiquity/server/snapdapi.py index 4173c5aa..8af125cc 100644 --- a/subiquity/server/snapdapi.py +++ b/subiquity/server/snapdapi.py @@ -313,14 +313,17 @@ def make_api_client(async_snapd): snapd_serializer = Serializer(ignore_unknown_fields=True, serialize_enums_by="value") -async def post_and_wait(client, meth, *args, **kw): +async def post_and_wait(client, meth, *args, ann=None, **kw): change_id = await meth(*args, **kw) log.debug("post_and_wait %s", change_id) while True: result = await client.v2.changes[change_id].GET() if result.status == TaskStatus.DONE: - return result.data + data = result.data + if ann is not None: + data = snapd_serializer.deserialize(ann, data) + return data elif result.status == TaskStatus.ERROR: raise aiohttp.ClientError(result.err) await asyncio.sleep(0.1) From e3f16aaded63c37cb92f2791625e6f35c3093f38 Mon Sep 17 00:00:00 2001 From: Michael Hudson-Doyle Date: Fri, 26 Apr 2024 16:05:00 +1200 Subject: [PATCH 4/4] use NonExhaustive to make snapdapi.Role an enum again --- subiquity/server/controllers/filesystem.py | 6 +++--- .../controllers/tests/test_filesystem.py | 6 +++--- subiquity/server/snapdapi.py | 19 ++++++++++--------- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/subiquity/server/controllers/filesystem.py b/subiquity/server/controllers/filesystem.py index f210267c..fbdaeebe 100644 --- a/subiquity/server/controllers/filesystem.py +++ b/subiquity/server/controllers/filesystem.py @@ -275,7 +275,7 @@ class FilesystemController(SubiquityController, FilesystemManipulator): self._on_volume: Optional[snapdapi.OnVolume] = None self._source_handler: Optional[AbstractSourceHandler] = None self._system_mounter: Optional[Mounter] = None - self._role_to_device: Dict[str, _Device] = {} + self._role_to_device: Dict[Union[str, snapdapi.Role], _Device] = {} self._device_to_structure: Dict[_Device, snapdapi.OnVolume] = {} self._pyudev_context: Optional[pyudev.Context] = None self.use_tpm: bool = False @@ -949,9 +949,9 @@ class FilesystemController(SubiquityController, FilesystemManipulator): step=snapdapi.SystemActionStep.SETUP_STORAGE_ENCRYPTION, on_volumes=self._on_volumes(), ), + ann=snapdapi.SystemActionResponse, ) - role_to_encrypted_device = result["encrypted-devices"] - for role, enc_path in role_to_encrypted_device.items(): + for role, enc_path in result.encrypted_devices.items(): arb_device = ArbitraryDevice(m=self.model, path=enc_path) self.model._actions.append(arb_device) part = self._role_to_device[role] diff --git a/subiquity/server/controllers/tests/test_filesystem.py b/subiquity/server/controllers/tests/test_filesystem.py index ee2f4d85..3663f608 100644 --- a/subiquity/server/controllers/tests/test_filesystem.py +++ b/subiquity/server/controllers/tests/test_filesystem.py @@ -1537,11 +1537,11 @@ class TestCoreBootInstallMethods(IsolatedAsyncioTestCase): with mock.patch.object( snapdapi, "post_and_wait", new_callable=mock.AsyncMock ) as mocked: - mocked.return_value = { - "encrypted-devices": { + mocked.return_value = snapdapi.SystemActionResponse( + encrypted_devices={ snapdapi.Role.SYSTEM_DATA: "enc-system-data", }, - } + ) await self.fsc.setup_encryption(context=self.fsc.context) # setup_encryption mutates the filesystem model objects to diff --git a/subiquity/server/snapdapi.py b/subiquity/server/snapdapi.py index 8af125cc..5ae38382 100644 --- a/subiquity/server/snapdapi.py +++ b/subiquity/server/snapdapi.py @@ -24,7 +24,7 @@ import attr from subiquity.common.api.client import make_client from subiquity.common.api.defs import Payload, api, path_parameter -from subiquity.common.serialize import Serializer, named_field +from subiquity.common.serialize import NonExhaustive, Serializer, named_field from subiquity.common.types import Change, TaskStatus log = logging.getLogger("subiquity.server.snapdapi") @@ -90,17 +90,11 @@ class Response: status: str -class Role: +class Role(enum.Enum): NONE = "" MBR = "mbr" SYSTEM_BOOT = "system-boot" - SYSTEM_BOOT_IMAGE = "system-boot-image" - SYSTEM_BOOT_SELECT = "system-boot-select" SYSTEM_DATA = "system-data" - SYSTEM_RECOVERY_SELECT = "system-recovery-select" - SYSTEM_SAVE = "system-save" - SYSTEM_SEED = "system-seed" - SYSTEM_SEED_NULL = "system-seed-null" @attr.s(auto_attribs=True) @@ -134,7 +128,7 @@ class VolumeStructure: offset_write: Optional[RelativeOffset] = named_field("offset-write", None) size: int = 0 type: str = "" - role: str = Role.NONE + role: NonExhaustive[Role] = Role.NONE id: Optional[str] = None filesystem: str = "" content: Optional[List[VolumeContent]] = None @@ -232,6 +226,13 @@ class SystemActionRequest: on_volumes: Dict[str, OnVolume] = named_field("on-volumes") +@attr.s(auto_attribs=True) +class SystemActionResponse: + encrypted_devices: Dict[NonExhaustive[Role], str] = named_field( + "encrypted-devices", default=attr.Factory(dict) + ) + + @api class SnapdAPI: serialize_query_args = False