add serialization support for unions of attr classes
unions are serialized as a tagged union by adding a '$type' field indicating which member of the set of types is present.
This commit is contained in:
parent
85ed9826e4
commit
da14af1c65
|
@ -53,12 +53,35 @@ class Serializer:
|
||||||
|
|
||||||
def _walk_Union(self, meth, args, value, metadata, path, serializing):
|
def _walk_Union(self, meth, args, value, metadata, path, serializing):
|
||||||
NoneType = type(None)
|
NoneType = type(None)
|
||||||
assert NoneType in args, "at {}, can only serialize Optional"
|
if NoneType in args:
|
||||||
args = [a for a in args if a is not NoneType]
|
args = [a for a in args if a is not NoneType]
|
||||||
assert len(args) == 1, "at {}, can only serialize Optional"
|
if len(args) == 1:
|
||||||
if value is None:
|
# I.e. Optional[thing]
|
||||||
return value
|
if value is None:
|
||||||
return meth(args[0], value, metadata, path)
|
return value
|
||||||
|
return meth(args[0], value, metadata, path)
|
||||||
|
if all(attr.has(a) for a in args):
|
||||||
|
if serializing:
|
||||||
|
for a in args:
|
||||||
|
if isinstance(value, a):
|
||||||
|
r = meth(a, value, metadata, path)
|
||||||
|
if self.compact:
|
||||||
|
r.insert(0, a.__name__)
|
||||||
|
else:
|
||||||
|
r['$type'] = a.__name__
|
||||||
|
return r
|
||||||
|
raise Exception(
|
||||||
|
f"at {path}, type of {value} not found in {args}")
|
||||||
|
else:
|
||||||
|
if self.compact:
|
||||||
|
n = value.pop(0)
|
||||||
|
else:
|
||||||
|
n = value.pop('$type')
|
||||||
|
for a in args:
|
||||||
|
if a.__name__ == n:
|
||||||
|
return meth(a, value, metadata, path)
|
||||||
|
raise Exception(f"at {path}, type {n} not found in {args}")
|
||||||
|
raise Exception(f"at {path}, cannot serialize Union[{args}]")
|
||||||
|
|
||||||
def _walk_List(self, meth, args, value, metadata, path, serializing):
|
def _walk_List(self, meth, args, value, metadata, path, serializing):
|
||||||
return [
|
return [
|
||||||
|
|
|
@ -112,6 +112,11 @@ class CommonSerializerTests:
|
||||||
def test_serialize_dict_strkeys(self):
|
def test_serialize_dict_strkeys(self):
|
||||||
self.assertSerialization(typing.Dict[str, str], {"a": "b"}, {"a": "b"})
|
self.assertSerialization(typing.Dict[str, str], {"a": "b"}, {"a": "b"})
|
||||||
|
|
||||||
|
def test_rountrip_union(self):
|
||||||
|
ann = typing.Union[Data, Container]
|
||||||
|
self.assertRoundtrips(ann, Data.make_random())
|
||||||
|
self.assertRoundtrips(ann, Container.make_random())
|
||||||
|
|
||||||
|
|
||||||
class TestSerializer(CommonSerializerTests, unittest.TestCase):
|
class TestSerializer(CommonSerializerTests, unittest.TestCase):
|
||||||
|
|
||||||
|
@ -134,6 +139,15 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
|
||||||
}
|
}
|
||||||
self.assertSerialization(Container, container, expected)
|
self.assertSerialization(Container, container, expected)
|
||||||
|
|
||||||
|
def test_serialize_union(self):
|
||||||
|
data = Data.make_random()
|
||||||
|
expected = {
|
||||||
|
'$type': 'Data',
|
||||||
|
'field1': data.field1,
|
||||||
|
'field2': data.field2,
|
||||||
|
}
|
||||||
|
self.assertSerialization(typing.Union[Data, Container], data, expected)
|
||||||
|
|
||||||
|
|
||||||
class TestCompactSerializer(CommonSerializerTests, unittest.TestCase):
|
class TestCompactSerializer(CommonSerializerTests, unittest.TestCase):
|
||||||
|
|
||||||
|
@ -153,3 +167,8 @@ class TestCompactSerializer(CommonSerializerTests, unittest.TestCase):
|
||||||
[[data2.field1, data2.field2]],
|
[[data2.field1, data2.field2]],
|
||||||
]
|
]
|
||||||
self.assertSerialization(Container, container, expected)
|
self.assertSerialization(Container, container, expected)
|
||||||
|
|
||||||
|
def test_serialize_union(self):
|
||||||
|
data = Data.make_random()
|
||||||
|
expected = ['Data', data.field1, data.field2]
|
||||||
|
self.assertSerialization(typing.Union[Data, Container], data, expected)
|
||||||
|
|
Loading…
Reference in New Issue