subiquity.common.serialize: add a way to have a "non-exhaustive" enum

We use the subiquity.common.api stuff to talk to snapd, and use
enumerations to describe several field names. This is a bit of a
landmine in some situations because snapd can add a value to the set of
possible return values and thus cause serialization to fail. Rather than
just giving up on all of the typo-resistance of declaring enums for API
types, this adds a way to mark an enumeration as "non-exhaustive":
values that are part of the declared enum will be deserialized as values
of the enum but values that are not will be passed straight through (as
strings, in all the cases used in snapdapi.py), which conveniently will
never compare equal to an enumeration. This should let us just declare
the values of the enumerations we actually care about and not break if
we don't declare every value snapd actually uses.
This commit is contained in:
Michael Hudson-Doyle 2024-04-26 15:48:14 +12:00
parent a59ef9fba0
commit b53c1fffb7
2 changed files with 42 additions and 1 deletions

View File

@ -43,6 +43,13 @@ class SerializationError(Exception):
return f"processing {self.obj}: at {p}, {self.message}"
E = typing.TypeVar("E")
class NonExhaustive(typing.Generic[E]):
pass
@attr.s(auto_attribs=True)
class SerializationContext:
obj: typing.Any
@ -88,6 +95,7 @@ class Serializer:
typing.List: self._walk_List,
dict: self._walk_Dict,
typing.Dict: self._walk_Dict,
NonExhaustive: self._walk_NonExhaustive,
}
self.type_serializers = {}
self.type_deserializers = {}
@ -102,6 +110,9 @@ class Serializer:
def _ann_ok_as_dict_key(self, annotation):
if annotation is str:
return True
origin = getattr(annotation, "__origin__", None)
if origin is NonExhaustive:
annotation = annotation.__args__[0]
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
if self.serialize_enums_by == "name":
return True
@ -176,6 +187,19 @@ class Serializer:
else:
return dict(output_items)
def _walk_NonExhaustive(self, meth, args, context):
[enum_cls] = args
if context.serializing:
if isinstance(context.cur, enum_cls):
return meth(enum_cls, context)
else:
return context.cur
else:
if context.cur in (getattr(m, self.serialize_enums_by) for m in enum_cls):
return meth(enum_cls, context)
else:
return context.cur
def _serialize_dict(self, annotation, context):
context.assert_type(annotation)
for k in context.cur:

View File

@ -23,7 +23,12 @@ import unittest
import attr
from subiquity.common.serialize import SerializationError, Serializer, named_field
from subiquity.common.serialize import (
NonExhaustive,
SerializationError,
Serializer,
named_field,
)
@attr.s(auto_attribs=True)
@ -133,12 +138,24 @@ class CommonSerializerTests:
def test_enums(self):
self.assertSerialization(MyEnum, MyEnum.name, "name")
def test_non_exhaustive_enums(self):
self.serializer = type(self.serializer)(compact=self.serializer.compact)
self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "name")
self.assertSerialization(NonExhaustive[MyEnum], "name2", "name2")
def test_enums_by_value(self):
self.serializer = type(self.serializer)(
compact=self.serializer.compact, serialize_enums_by="value"
)
self.assertSerialization(MyEnum, MyEnum.name, "value")
def test_non_exhaustive_enums_by_value(self):
self.serializer = type(self.serializer)(
compact=self.serializer.compact, serialize_enums_by="value"
)
self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "value")
self.assertSerialization(NonExhaustive[MyEnum], "value2", "value2")
def test_serialize_any(self):
o = object()
self.assertSerialization(typing.Any, o, o)