add serialization support for unions of attr classes

unions are serialized as a tagged union by adding a '$type' field
indicating which member of the set of types is present.
This commit is contained in:
Michael Hudson-Doyle 2021-03-16 09:46:44 +13:00
parent 85ed9826e4
commit da14af1c65
2 changed files with 48 additions and 6 deletions

View File

@ -53,12 +53,35 @@ class Serializer:
def _walk_Union(self, meth, args, value, metadata, path, serializing):
NoneType = type(None)
assert NoneType in args, "at {}, can only serialize Optional"
args = [a for a in args if a is not NoneType]
assert len(args) == 1, "at {}, can only serialize Optional"
if value is None:
return value
return meth(args[0], value, metadata, path)
if NoneType in args:
args = [a for a in args if a is not NoneType]
if len(args) == 1:
# I.e. Optional[thing]
if value is None:
return value
return meth(args[0], value, metadata, path)
if all(attr.has(a) for a in args):
if serializing:
for a in args:
if isinstance(value, a):
r = meth(a, value, metadata, path)
if self.compact:
r.insert(0, a.__name__)
else:
r['$type'] = a.__name__
return r
raise Exception(
f"at {path}, type of {value} not found in {args}")
else:
if self.compact:
n = value.pop(0)
else:
n = value.pop('$type')
for a in args:
if a.__name__ == n:
return meth(a, value, metadata, path)
raise Exception(f"at {path}, type {n} not found in {args}")
raise Exception(f"at {path}, cannot serialize Union[{args}]")
def _walk_List(self, meth, args, value, metadata, path, serializing):
return [

View File

@ -112,6 +112,11 @@ class CommonSerializerTests:
def test_serialize_dict_strkeys(self):
self.assertSerialization(typing.Dict[str, str], {"a": "b"}, {"a": "b"})
def test_rountrip_union(self):
ann = typing.Union[Data, Container]
self.assertRoundtrips(ann, Data.make_random())
self.assertRoundtrips(ann, Container.make_random())
class TestSerializer(CommonSerializerTests, unittest.TestCase):
@ -134,6 +139,15 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
}
self.assertSerialization(Container, container, expected)
def test_serialize_union(self):
data = Data.make_random()
expected = {
'$type': 'Data',
'field1': data.field1,
'field2': data.field2,
}
self.assertSerialization(typing.Union[Data, Container], data, expected)
class TestCompactSerializer(CommonSerializerTests, unittest.TestCase):
@ -153,3 +167,8 @@ class TestCompactSerializer(CommonSerializerTests, unittest.TestCase):
[[data2.field1, data2.field2]],
]
self.assertSerialization(Container, container, expected)
def test_serialize_union(self):
data = Data.make_random()
expected = ['Data', data.field1, data.field2]
self.assertSerialization(typing.Union[Data, Container], data, expected)