From b53c1fffb781af25fa6d702c95a3fc526121da49 Mon Sep 17 00:00:00 2001 From: Michael Hudson-Doyle Date: Fri, 26 Apr 2024 15:48:14 +1200 Subject: [PATCH] subiquity.common.serialize: add a way to have a "non-exhaustive" enum We use the subiquity.common.api stuff to talk to snapd, and use enumerations to describe several field names. This is a bit of a landmine in some situations because snapd can add a value to the set of possible return values and thus cause serialization to fail. Rather than just giving up on all of the typo-resistance of declaring enums for API types, this adds a way to mark an enumeration as "non-exhaustive": values that are part of the declared enum will be deserialized as values of the enum but values that are not will be passed straight through (as strings, in all the cases used in snapdapi.py), which conveniently will never compare equal to an enumeration. This should let us just declare the values of the enumerations we actually care about and not break if we don't declare every value snapd actually uses. --- subiquity/common/serialize.py | 24 ++++++++++++++++++++ subiquity/common/tests/test_serialization.py | 19 +++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/subiquity/common/serialize.py b/subiquity/common/serialize.py index 47cac2ba..9bf5319d 100644 --- a/subiquity/common/serialize.py +++ b/subiquity/common/serialize.py @@ -43,6 +43,13 @@ class SerializationError(Exception): return f"processing {self.obj}: at {p}, {self.message}" +E = typing.TypeVar("E") + + +class NonExhaustive(typing.Generic[E]): + pass + + @attr.s(auto_attribs=True) class SerializationContext: obj: typing.Any @@ -88,6 +95,7 @@ class Serializer: typing.List: self._walk_List, dict: self._walk_Dict, typing.Dict: self._walk_Dict, + NonExhaustive: self._walk_NonExhaustive, } self.type_serializers = {} self.type_deserializers = {} @@ -102,6 +110,9 @@ class Serializer: def _ann_ok_as_dict_key(self, annotation): if annotation is str: return True + origin = getattr(annotation, "__origin__", None) + if origin is NonExhaustive: + annotation = annotation.__args__[0] if isinstance(annotation, type) and issubclass(annotation, enum.Enum): if self.serialize_enums_by == "name": return True @@ -176,6 +187,19 @@ class Serializer: else: return dict(output_items) + def _walk_NonExhaustive(self, meth, args, context): + [enum_cls] = args + if context.serializing: + if isinstance(context.cur, enum_cls): + return meth(enum_cls, context) + else: + return context.cur + else: + if context.cur in (getattr(m, self.serialize_enums_by) for m in enum_cls): + return meth(enum_cls, context) + else: + return context.cur + def _serialize_dict(self, annotation, context): context.assert_type(annotation) for k in context.cur: diff --git a/subiquity/common/tests/test_serialization.py b/subiquity/common/tests/test_serialization.py index b08f067e..8cb6513a 100644 --- a/subiquity/common/tests/test_serialization.py +++ b/subiquity/common/tests/test_serialization.py @@ -23,7 +23,12 @@ import unittest import attr -from subiquity.common.serialize import SerializationError, Serializer, named_field +from subiquity.common.serialize import ( + NonExhaustive, + SerializationError, + Serializer, + named_field, +) @attr.s(auto_attribs=True) @@ -133,12 +138,24 @@ class CommonSerializerTests: def test_enums(self): self.assertSerialization(MyEnum, MyEnum.name, "name") + def test_non_exhaustive_enums(self): + self.serializer = type(self.serializer)(compact=self.serializer.compact) + self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "name") + self.assertSerialization(NonExhaustive[MyEnum], "name2", "name2") + def test_enums_by_value(self): self.serializer = type(self.serializer)( compact=self.serializer.compact, serialize_enums_by="value" ) self.assertSerialization(MyEnum, MyEnum.name, "value") + def test_non_exhaustive_enums_by_value(self): + self.serializer = type(self.serializer)( + compact=self.serializer.compact, serialize_enums_by="value" + ) + self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "value") + self.assertSerialization(NonExhaustive[MyEnum], "value2", "value2") + def test_serialize_any(self): o = object() self.assertSerialization(typing.Any, o, o)