diff --git a/subiquity/common/serialize.py b/subiquity/common/serialize.py index 47cac2ba..9bf5319d 100644 --- a/subiquity/common/serialize.py +++ b/subiquity/common/serialize.py @@ -43,6 +43,13 @@ class SerializationError(Exception): return f"processing {self.obj}: at {p}, {self.message}" +E = typing.TypeVar("E") + + +class NonExhaustive(typing.Generic[E]): + pass + + @attr.s(auto_attribs=True) class SerializationContext: obj: typing.Any @@ -88,6 +95,7 @@ class Serializer: typing.List: self._walk_List, dict: self._walk_Dict, typing.Dict: self._walk_Dict, + NonExhaustive: self._walk_NonExhaustive, } self.type_serializers = {} self.type_deserializers = {} @@ -102,6 +110,9 @@ class Serializer: def _ann_ok_as_dict_key(self, annotation): if annotation is str: return True + origin = getattr(annotation, "__origin__", None) + if origin is NonExhaustive: + annotation = annotation.__args__[0] if isinstance(annotation, type) and issubclass(annotation, enum.Enum): if self.serialize_enums_by == "name": return True @@ -176,6 +187,19 @@ class Serializer: else: return dict(output_items) + def _walk_NonExhaustive(self, meth, args, context): + [enum_cls] = args + if context.serializing: + if isinstance(context.cur, enum_cls): + return meth(enum_cls, context) + else: + return context.cur + else: + if context.cur in (getattr(m, self.serialize_enums_by) for m in enum_cls): + return meth(enum_cls, context) + else: + return context.cur + def _serialize_dict(self, annotation, context): context.assert_type(annotation) for k in context.cur: diff --git a/subiquity/common/tests/test_serialization.py b/subiquity/common/tests/test_serialization.py index b08f067e..8cb6513a 100644 --- a/subiquity/common/tests/test_serialization.py +++ b/subiquity/common/tests/test_serialization.py @@ -23,7 +23,12 @@ import unittest import attr -from subiquity.common.serialize import SerializationError, Serializer, named_field +from subiquity.common.serialize import ( + NonExhaustive, + SerializationError, + Serializer, + named_field, +) @attr.s(auto_attribs=True) @@ -133,12 +138,24 @@ class CommonSerializerTests: def test_enums(self): self.assertSerialization(MyEnum, MyEnum.name, "name") + def test_non_exhaustive_enums(self): + self.serializer = type(self.serializer)(compact=self.serializer.compact) + self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "name") + self.assertSerialization(NonExhaustive[MyEnum], "name2", "name2") + def test_enums_by_value(self): self.serializer = type(self.serializer)( compact=self.serializer.compact, serialize_enums_by="value" ) self.assertSerialization(MyEnum, MyEnum.name, "value") + def test_non_exhaustive_enums_by_value(self): + self.serializer = type(self.serializer)( + compact=self.serializer.compact, serialize_enums_by="value" + ) + self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "value") + self.assertSerialization(NonExhaustive[MyEnum], "value2", "value2") + def test_serialize_any(self): o = object() self.assertSerialization(typing.Any, o, o)