Skip to content

Commit 589b0e3

Browse files
fix(parsing): correctly handle nested discriminated unions
1 parent fe82bb4 commit 589b0e3

File tree

2 files changed

+52
-4
lines changed

2 files changed

+52
-4
lines changed

src/openai/_models.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast
66
from datetime import date, datetime
77
from typing_extensions import (
8+
List,
89
Unpack,
910
Literal,
1011
ClassVar,
@@ -391,7 +392,7 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:
391392
if type_ is None:
392393
raise RuntimeError(f"Unexpected field type is None for {key}")
393394

394-
return construct_type(value=value, type_=type_)
395+
return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None))
395396

396397

397398
def is_basemodel(type_: type) -> bool:
@@ -445,7 +446,7 @@ def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
445446
return cast(_T, construct_type(value=value, type_=type_))
446447

447448

448-
def construct_type(*, value: object, type_: object) -> object:
449+
def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]] = None) -> object:
449450
"""Loose coercion to the expected type with construction of nested values.
450451
451452
If the given value does not match the expected type then it is returned as-is.
@@ -463,8 +464,10 @@ def construct_type(*, value: object, type_: object) -> object:
463464
type_ = type_.__value__ # type: ignore[unreachable]
464465

465466
# unwrap `Annotated[T, ...]` -> `T`
466-
if is_annotated_type(type_):
467-
meta: tuple[Any, ...] = get_args(type_)[1:]
467+
if metadata is not None:
468+
meta: tuple[Any, ...] = tuple(metadata)
469+
elif is_annotated_type(type_):
470+
meta = get_args(type_)[1:]
468471
type_ = extract_type_arg(type_, 0)
469472
else:
470473
meta = tuple()

tests/test_models.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,3 +889,48 @@ class ModelB(BaseModel):
889889
)
890890

891891
assert isinstance(m, ModelB)
892+
893+
894+
def test_nested_discriminated_union() -> None:
895+
class InnerType1(BaseModel):
896+
type: Literal["type_1"]
897+
898+
class InnerModel(BaseModel):
899+
inner_value: str
900+
901+
class InnerType2(BaseModel):
902+
type: Literal["type_2"]
903+
some_inner_model: InnerModel
904+
905+
class Type1(BaseModel):
906+
base_type: Literal["base_type_1"]
907+
value: Annotated[
908+
Union[
909+
InnerType1,
910+
InnerType2,
911+
],
912+
PropertyInfo(discriminator="type"),
913+
]
914+
915+
class Type2(BaseModel):
916+
base_type: Literal["base_type_2"]
917+
918+
T = Annotated[
919+
Union[
920+
Type1,
921+
Type2,
922+
],
923+
PropertyInfo(discriminator="base_type"),
924+
]
925+
926+
model = construct_type(
927+
type_=T,
928+
value={
929+
"base_type": "base_type_1",
930+
"value": {
931+
"type": "type_2",
932+
},
933+
},
934+
)
935+
assert isinstance(model, Type1)
936+
assert isinstance(model.value, InnerType2)

0 commit comments

Comments
 (0)