add serialization support for unions of attr classes
unions are serialized as a tagged union by adding a '$type' field indicating which member of the set of types is present.
This commit is contained in:
parent
85ed9826e4
commit
da14af1c65
|
@ -53,12 +53,35 @@ class Serializer:
|
|||
|
||||
def _walk_Union(self, meth, args, value, metadata, path, serializing):
|
||||
NoneType = type(None)
|
||||
assert NoneType in args, "at {}, can only serialize Optional"
|
||||
args = [a for a in args if a is not NoneType]
|
||||
assert len(args) == 1, "at {}, can only serialize Optional"
|
||||
if value is None:
|
||||
return value
|
||||
return meth(args[0], value, metadata, path)
|
||||
if NoneType in args:
|
||||
args = [a for a in args if a is not NoneType]
|
||||
if len(args) == 1:
|
||||
# I.e. Optional[thing]
|
||||
if value is None:
|
||||
return value
|
||||
return meth(args[0], value, metadata, path)
|
||||
if all(attr.has(a) for a in args):
|
||||
if serializing:
|
||||
for a in args:
|
||||
if isinstance(value, a):
|
||||
r = meth(a, value, metadata, path)
|
||||
if self.compact:
|
||||
r.insert(0, a.__name__)
|
||||
else:
|
||||
r['$type'] = a.__name__
|
||||
return r
|
||||
raise Exception(
|
||||
f"at {path}, type of {value} not found in {args}")
|
||||
else:
|
||||
if self.compact:
|
||||
n = value.pop(0)
|
||||
else:
|
||||
n = value.pop('$type')
|
||||
for a in args:
|
||||
if a.__name__ == n:
|
||||
return meth(a, value, metadata, path)
|
||||
raise Exception(f"at {path}, type {n} not found in {args}")
|
||||
raise Exception(f"at {path}, cannot serialize Union[{args}]")
|
||||
|
||||
def _walk_List(self, meth, args, value, metadata, path, serializing):
|
||||
return [
|
||||
|
|
|
@ -112,6 +112,11 @@ class CommonSerializerTests:
|
|||
def test_serialize_dict_strkeys(self):
|
||||
self.assertSerialization(typing.Dict[str, str], {"a": "b"}, {"a": "b"})
|
||||
|
||||
def test_rountrip_union(self):
|
||||
ann = typing.Union[Data, Container]
|
||||
self.assertRoundtrips(ann, Data.make_random())
|
||||
self.assertRoundtrips(ann, Container.make_random())
|
||||
|
||||
|
||||
class TestSerializer(CommonSerializerTests, unittest.TestCase):
|
||||
|
||||
|
@ -134,6 +139,15 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
|
|||
}
|
||||
self.assertSerialization(Container, container, expected)
|
||||
|
||||
def test_serialize_union(self):
|
||||
data = Data.make_random()
|
||||
expected = {
|
||||
'$type': 'Data',
|
||||
'field1': data.field1,
|
||||
'field2': data.field2,
|
||||
}
|
||||
self.assertSerialization(typing.Union[Data, Container], data, expected)
|
||||
|
||||
|
||||
class TestCompactSerializer(CommonSerializerTests, unittest.TestCase):
|
||||
|
||||
|
@ -153,3 +167,8 @@ class TestCompactSerializer(CommonSerializerTests, unittest.TestCase):
|
|||
[[data2.field1, data2.field2]],
|
||||
]
|
||||
self.assertSerialization(Container, container, expected)
|
||||
|
||||
def test_serialize_union(self):
|
||||
data = Data.make_random()
|
||||
expected = ['Data', data.field1, data.field2]
|
||||
self.assertSerialization(typing.Union[Data, Container], data, expected)
|
||||
|
|
Loading…
Reference in New Issue