Skip to content

Commit cd0346d

Browse files
More union serialization tidying (#1536)
1 parent cd270e4 commit cd0346d

File tree

3 files changed

+85
-10
lines changed

3 files changed

+85
-10
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ install:
1515
pip install -U pip wheel pre-commit
1616
pip install -r tests/requirements.txt
1717
pip install -r tests/requirements-linting.txt
18-
pip install -e .
18+
pip install -v -e .
1919
pre-commit install
2020

2121
.PHONY: install-rust-coverage

src/serializers/type_serializers/union.rs

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -422,19 +422,32 @@ impl TaggedUnionSerializer {
422422
fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option<Py<PyAny>> {
423423
let py = value.py();
424424
let discriminator_value = match &self.discriminator {
425-
Discriminator::LookupKey(lookup_key) => lookup_key
426-
.simple_py_get_attr(value)
427-
.ok()
428-
.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py))),
425+
Discriminator::LookupKey(lookup_key) => {
426+
// we're pretty lax here, we allow either dict[key] or object.key, as we very well could
427+
// be doing a discriminator lookup on a typed dict, and there's no good way to check that
428+
// at this point. we could be more strict and only do this in lax mode...
429+
let getattr_result = match value.is_instance_of::<PyDict>() {
430+
true => {
431+
let value_dict = value.downcast::<PyDict>().unwrap();
432+
lookup_key.py_get_dict_item(value_dict).ok()
433+
}
434+
false => lookup_key.simple_py_get_attr(value).ok(),
435+
};
436+
getattr_result.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py)))
437+
}
429438
Discriminator::Function(func) => func.call1(py, (value,)).ok(),
430439
};
431440
if discriminator_value.is_none() {
432441
let value_str = truncate_safe_repr(value, None);
433-
extra.warnings.custom_warning(
434-
format!(
435-
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
436-
)
437-
);
442+
443+
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise this warning
444+
if extra.check == SerCheck::None {
445+
extra.warnings.custom_warning(
446+
format!(
447+
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
448+
)
449+
);
450+
}
438451
}
439452
discriminator_value
440453
}

tests/serializers/test_union.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,43 @@ def test_union_of_unions_of_models_with_tagged_union_invalid_variant(
948948
assert m in str(w[0].message)
949949

950950

951+
def test_mixed_union_models_and_other_types() -> None:
952+
s = SchemaSerializer(
953+
core_schema.union_schema(
954+
[
955+
core_schema.tagged_union_schema(
956+
discriminator='type_',
957+
choices={
958+
'cat': core_schema.model_schema(
959+
cls=ModelCat,
960+
schema=core_schema.model_fields_schema(
961+
fields={
962+
'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),
963+
},
964+
),
965+
),
966+
'dog': core_schema.model_schema(
967+
cls=ModelDog,
968+
schema=core_schema.model_fields_schema(
969+
fields={
970+
'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),
971+
},
972+
),
973+
),
974+
},
975+
),
976+
core_schema.str_schema(),
977+
]
978+
)
979+
)
980+
981+
assert s.to_python(ModelCat(type_='cat'), warnings='error') == {'type_': 'cat'}
982+
assert s.to_python(ModelDog(type_='dog'), warnings='error') == {'type_': 'dog'}
983+
# note, this fails as ModelCat and ModelDog (discriminator warnings, etc), but the warnings
984+
# don't bubble up to this level :)
985+
assert s.to_python('a string', warnings='error') == 'a string'
986+
987+
951988
@pytest.mark.parametrize(
952989
'input,expected',
953990
[
@@ -1000,3 +1037,28 @@ def test_union_of_unions_of_models_with_tagged_union_json_serialization(
10001037
)
10011038

10021039
assert s.to_json(input, warnings='error') == expected
1040+
1041+
1042+
def test_discriminated_union_ser_with_typed_dict() -> None:
1043+
v = SchemaSerializer(
1044+
core_schema.tagged_union_schema(
1045+
{
1046+
'a': core_schema.typed_dict_schema(
1047+
{
1048+
'type': core_schema.typed_dict_field(core_schema.literal_schema(['a'])),
1049+
'a': core_schema.typed_dict_field(core_schema.int_schema()),
1050+
}
1051+
),
1052+
'b': core_schema.typed_dict_schema(
1053+
{
1054+
'type': core_schema.typed_dict_field(core_schema.literal_schema(['b'])),
1055+
'b': core_schema.typed_dict_field(core_schema.str_schema()),
1056+
}
1057+
),
1058+
},
1059+
discriminator='type',
1060+
)
1061+
)
1062+
1063+
assert v.to_python({'type': 'a', 'a': 1}, warnings='error') == {'type': 'a', 'a': 1}
1064+
assert v.to_python({'type': 'b', 'b': 'foo'}, warnings='error') == {'type': 'b', 'b': 'foo'}

0 commit comments

Comments
 (0)