subiquity.common.serialize: add a way to have a "non-exhaustive" enum
We use the subiquity.common.api stuff to talk to snapd, and use enumerations to describe several field names. This is a bit of a landmine in some situations because snapd can add a value to the set of possible return values and thus cause serialization to fail. Rather than just giving up on all of the typo-resistance of declaring enums for API types, this adds a way to mark an enumeration as "non-exhaustive": values that are part of the declared enum will be deserialized as values of the enum but values that are not will be passed straight through (as strings, in all the cases used in snapdapi.py), which conveniently will never compare equal to an enumeration. This should let us just declare the values of the enumerations we actually care about and not break if we don't declare every value snapd actually uses.
This commit is contained in:
parent
a59ef9fba0
commit
b53c1fffb7
|
@ -43,6 +43,13 @@ class SerializationError(Exception):
|
||||||
return f"processing {self.obj}: at {p}, {self.message}"
|
return f"processing {self.obj}: at {p}, {self.message}"
|
||||||
|
|
||||||
|
|
||||||
|
E = typing.TypeVar("E")
|
||||||
|
|
||||||
|
|
||||||
|
class NonExhaustive(typing.Generic[E]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True)
|
@attr.s(auto_attribs=True)
|
||||||
class SerializationContext:
|
class SerializationContext:
|
||||||
obj: typing.Any
|
obj: typing.Any
|
||||||
|
@ -88,6 +95,7 @@ class Serializer:
|
||||||
typing.List: self._walk_List,
|
typing.List: self._walk_List,
|
||||||
dict: self._walk_Dict,
|
dict: self._walk_Dict,
|
||||||
typing.Dict: self._walk_Dict,
|
typing.Dict: self._walk_Dict,
|
||||||
|
NonExhaustive: self._walk_NonExhaustive,
|
||||||
}
|
}
|
||||||
self.type_serializers = {}
|
self.type_serializers = {}
|
||||||
self.type_deserializers = {}
|
self.type_deserializers = {}
|
||||||
|
@ -102,6 +110,9 @@ class Serializer:
|
||||||
def _ann_ok_as_dict_key(self, annotation):
|
def _ann_ok_as_dict_key(self, annotation):
|
||||||
if annotation is str:
|
if annotation is str:
|
||||||
return True
|
return True
|
||||||
|
origin = getattr(annotation, "__origin__", None)
|
||||||
|
if origin is NonExhaustive:
|
||||||
|
annotation = annotation.__args__[0]
|
||||||
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
|
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
|
||||||
if self.serialize_enums_by == "name":
|
if self.serialize_enums_by == "name":
|
||||||
return True
|
return True
|
||||||
|
@ -176,6 +187,19 @@ class Serializer:
|
||||||
else:
|
else:
|
||||||
return dict(output_items)
|
return dict(output_items)
|
||||||
|
|
||||||
|
def _walk_NonExhaustive(self, meth, args, context):
|
||||||
|
[enum_cls] = args
|
||||||
|
if context.serializing:
|
||||||
|
if isinstance(context.cur, enum_cls):
|
||||||
|
return meth(enum_cls, context)
|
||||||
|
else:
|
||||||
|
return context.cur
|
||||||
|
else:
|
||||||
|
if context.cur in (getattr(m, self.serialize_enums_by) for m in enum_cls):
|
||||||
|
return meth(enum_cls, context)
|
||||||
|
else:
|
||||||
|
return context.cur
|
||||||
|
|
||||||
def _serialize_dict(self, annotation, context):
|
def _serialize_dict(self, annotation, context):
|
||||||
context.assert_type(annotation)
|
context.assert_type(annotation)
|
||||||
for k in context.cur:
|
for k in context.cur:
|
||||||
|
|
|
@ -23,7 +23,12 @@ import unittest
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from subiquity.common.serialize import SerializationError, Serializer, named_field
|
from subiquity.common.serialize import (
|
||||||
|
NonExhaustive,
|
||||||
|
SerializationError,
|
||||||
|
Serializer,
|
||||||
|
named_field,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True)
|
@attr.s(auto_attribs=True)
|
||||||
|
@ -133,12 +138,24 @@ class CommonSerializerTests:
|
||||||
def test_enums(self):
|
def test_enums(self):
|
||||||
self.assertSerialization(MyEnum, MyEnum.name, "name")
|
self.assertSerialization(MyEnum, MyEnum.name, "name")
|
||||||
|
|
||||||
|
def test_non_exhaustive_enums(self):
|
||||||
|
self.serializer = type(self.serializer)(compact=self.serializer.compact)
|
||||||
|
self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "name")
|
||||||
|
self.assertSerialization(NonExhaustive[MyEnum], "name2", "name2")
|
||||||
|
|
||||||
def test_enums_by_value(self):
|
def test_enums_by_value(self):
|
||||||
self.serializer = type(self.serializer)(
|
self.serializer = type(self.serializer)(
|
||||||
compact=self.serializer.compact, serialize_enums_by="value"
|
compact=self.serializer.compact, serialize_enums_by="value"
|
||||||
)
|
)
|
||||||
self.assertSerialization(MyEnum, MyEnum.name, "value")
|
self.assertSerialization(MyEnum, MyEnum.name, "value")
|
||||||
|
|
||||||
|
def test_non_exhaustive_enums_by_value(self):
|
||||||
|
self.serializer = type(self.serializer)(
|
||||||
|
compact=self.serializer.compact, serialize_enums_by="value"
|
||||||
|
)
|
||||||
|
self.assertSerialization(NonExhaustive[MyEnum], MyEnum.name, "value")
|
||||||
|
self.assertSerialization(NonExhaustive[MyEnum], "value2", "value2")
|
||||||
|
|
||||||
def test_serialize_any(self):
|
def test_serialize_any(self):
|
||||||
o = object()
|
o = object()
|
||||||
self.assertSerialization(typing.Any, o, o)
|
self.assertSerialization(typing.Any, o, o)
|
||||||
|
|
Loading…
Reference in New Issue