Skip to content

Commit b65d178

Browse files
committed
WIP
1 parent f4a0675 commit b65d178

File tree

7 files changed

+212
-1
lines changed

7 files changed

+212
-1
lines changed

python/pydantic_core/core_schema.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,6 +1437,24 @@ def uuid_schema(
14371437
)
14381438

14391439

1440+
class NestedModelSchema(TypedDict, total=False):
1441+
type: Required[Literal['nested-model']]
1442+
model: Required[Type[Any]]
1443+
metadata: Any
1444+
1445+
1446+
def nested_model_schema(
1447+
*,
1448+
model: Type[Any],
1449+
metadata: Any = None,
1450+
) -> NestedModelSchema:
1451+
return _dict_not_none(
1452+
type='nested-model',
1453+
model=model,
1454+
metadata=metadata,
1455+
)
1456+
1457+
14401458
class IncExSeqSerSchema(TypedDict, total=False):
14411459
type: Required[Literal['include-exclude-sequence']]
14421460
include: Set[int]
@@ -3866,6 +3884,7 @@ def definition_reference_schema(
38663884
DefinitionReferenceSchema,
38673885
UuidSchema,
38683886
ComplexSchema,
3887+
NestedModelSchema,
38693888
]
38703889
elif False:
38713890
CoreSchema: TypeAlias = Mapping[str, Any]
@@ -3922,6 +3941,7 @@ def definition_reference_schema(
39223941
'definition-ref',
39233942
'uuid',
39243943
'complex',
3944+
'nested-model',
39253945
]
39263946

39273947
CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field']

src/serializers/shared.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ combined_serializer! {
143143
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
144144
Tuple: super::type_serializers::tuple::TupleSerializer;
145145
Complex: super::type_serializers::complex::ComplexSerializer;
146+
NestedModel: super::type_serializers::nested_model::NestedModelSerializer;
146147
}
147148
}
148149

@@ -254,6 +255,7 @@ impl PyGcTraverse for CombinedSerializer {
254255
CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit),
255256
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
256257
CombinedSerializer::Complex(inner) => inner.py_gc_traverse(visit),
258+
CombinedSerializer::NestedModel(inner) => inner.py_gc_traverse(visit),
257259
}
258260
}
259261
}

src/serializers/type_serializers/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pub mod json_or_python;
1616
pub mod list;
1717
pub mod literal;
1818
pub mod model;
19+
pub mod nested_model;
1920
pub mod nullable;
2021
pub mod other;
2122
pub mod set_frozenset;
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
use std::borrow::Cow;
2+
3+
use pyo3::{
4+
intern,
5+
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
6+
Bound, Py, PyAny, PyObject, PyResult, Python,
7+
};
8+
9+
use crate::{
10+
definitions::DefinitionsBuilder,
11+
serializers::{
12+
shared::{BuildSerializer, TypeSerializer},
13+
CombinedSerializer, Extra,
14+
},
15+
SchemaSerializer,
16+
};
17+
18+
#[derive(Debug, Clone)]
19+
pub struct NestedModelSerializer {
20+
model: Py<PyType>,
21+
name: String,
22+
}
23+
24+
impl_py_gc_traverse!(NestedModelSerializer { model });
25+
26+
impl BuildSerializer for NestedModelSerializer {
27+
const EXPECTED_TYPE: &'static str = "nested-model";
28+
29+
fn build(
30+
schema: &Bound<'_, PyDict>,
31+
_config: Option<&Bound<'_, PyDict>>,
32+
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
33+
) -> PyResult<CombinedSerializer> {
34+
let py = schema.py();
35+
let model = schema
36+
.get_item(intern!(py, "model"))?
37+
.expect("Invalid core schema for `nested-model` type")
38+
.downcast::<PyType>()
39+
.expect("Invalid core schema for `nested-model` type")
40+
.clone();
41+
42+
let name = model.getattr(intern!(py, "__name__"))?.extract()?;
43+
44+
Ok(CombinedSerializer::NestedModel(NestedModelSerializer {
45+
model: model.clone().unbind(),
46+
name,
47+
}))
48+
}
49+
}
50+
51+
impl NestedModelSerializer {
52+
fn nested_serializer<'py>(&self, py: Python<'py>) -> Bound<'py, SchemaSerializer> {
53+
self.model
54+
.bind(py)
55+
.call_method(intern!(py, "model_rebuild"), (), None)
56+
.unwrap();
57+
58+
self.model
59+
.getattr(py, intern!(py, "__pydantic_serializer__"))
60+
.unwrap()
61+
.downcast_bound::<SchemaSerializer>(py)
62+
.unwrap()
63+
.clone()
64+
65+
// crate::schema_cache::retrieve_schema(py, self.model.as_any().clone())
66+
// .downcast_bound::<SchemaSerializer>(py)
67+
// // FIXME: This actually will always trigger as we cache a `CoreSchema` lol
68+
// .expect("Cached validator was not a `SchemaSerializer`")
69+
// .clone()
70+
}
71+
}
72+
73+
impl TypeSerializer for NestedModelSerializer {
74+
fn to_python(
75+
&self,
76+
value: &Bound<'_, PyAny>,
77+
include: Option<&Bound<'_, PyAny>>,
78+
exclude: Option<&Bound<'_, PyAny>>,
79+
extra: &Extra,
80+
) -> PyResult<PyObject> {
81+
self.nested_serializer(value.py())
82+
.get()
83+
.serializer
84+
.to_python(value, include, exclude, extra)
85+
}
86+
87+
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
88+
self.nested_serializer(key.py()).get().serializer.json_key(key, extra)
89+
}
90+
91+
fn serde_serialize<S: serde::ser::Serializer>(
92+
&self,
93+
value: &Bound<'_, PyAny>,
94+
serializer: S,
95+
include: Option<&Bound<'_, PyAny>>,
96+
exclude: Option<&Bound<'_, PyAny>>,
97+
extra: &Extra,
98+
) -> Result<S::Ok, S::Error> {
99+
self.nested_serializer(value.py())
100+
.get()
101+
.serializer
102+
.serde_serialize(value, serializer, include, exclude, extra)
103+
}
104+
105+
fn get_name(&self) -> &str {
106+
&self.name
107+
}
108+
}

src/validators/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ mod list;
4949
mod literal;
5050
mod model;
5151
mod model_fields;
52+
mod nested_model;
5253
mod none;
5354
mod nullable;
5455
mod set;
@@ -584,6 +585,7 @@ pub fn build_validator(
584585
definitions::DefinitionRefValidator,
585586
definitions::DefinitionsValidatorBuilder,
586587
complex::ComplexValidator,
588+
nested_model::NestedModelValidator,
587589
)
588590
}
589591

@@ -738,6 +740,8 @@ pub enum CombinedValidator {
738740
// input dependent
739741
JsonOrPython(json_or_python::JsonOrPython),
740742
Complex(complex::ComplexValidator),
743+
// Schema for a model inside of another schema
744+
NestedModel(nested_model::NestedModelValidator),
741745
}
742746

743747
/// This trait must be implemented by all validators, it allows various validators to be accessed consistently,

src/validators/model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ impl BuildValidator for ModelValidator {
7777

7878
let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
7979
let sub_schema = schema.get_as_req(intern!(py, "schema"))?;
80-
let validator = build_validator(&sub_schema, config.as_ref(), definitions)?;
80+
let validator: CombinedValidator = build_validator(&sub_schema, config.as_ref(), definitions)?;
8181
let name = class.getattr(intern!(py, "__name__"))?.extract()?;
8282

8383
Ok(Self {

src/validators/nested_model.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
use pyo3::{
2+
intern,
3+
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
4+
Bound, Py, PyObject, PyResult, Python,
5+
};
6+
7+
use crate::{definitions::DefinitionsBuilder, errors::ValResult, input::Input};
8+
9+
use super::{BuildValidator, CombinedValidator, SchemaValidator, ValidationState, Validator};
10+
11+
#[derive(Debug, Clone)]
12+
pub struct NestedModelValidator {
13+
model: Py<PyType>,
14+
name: String,
15+
}
16+
17+
impl_py_gc_traverse!(NestedModelValidator { model });
18+
19+
impl BuildValidator for NestedModelValidator {
20+
const EXPECTED_TYPE: &'static str = "nested-model";
21+
22+
fn build(
23+
schema: &Bound<'_, PyDict>,
24+
_config: Option<&Bound<'_, PyDict>>,
25+
_definitions: &mut DefinitionsBuilder<super::CombinedValidator>,
26+
) -> PyResult<super::CombinedValidator> {
27+
let py = schema.py();
28+
let model = schema
29+
.get_item(intern!(py, "model"))?
30+
.expect("Invalid core schema for `nested-model` type")
31+
.downcast::<PyType>()
32+
.expect("Invalid core schema for `nested-model` type")
33+
.clone();
34+
35+
let name = model.getattr(intern!(py, "__name__"))?.extract()?;
36+
37+
Ok(CombinedValidator::NestedModel(NestedModelValidator {
38+
model: model.clone().unbind(),
39+
name,
40+
}))
41+
}
42+
}
43+
44+
impl Validator for NestedModelValidator {
45+
fn validate<'py>(
46+
&self,
47+
py: Python<'py>,
48+
input: &(impl Input<'py> + ?Sized),
49+
state: &mut ValidationState<'_, 'py>,
50+
) -> ValResult<PyObject> {
51+
self.model
52+
.bind(py)
53+
.call_method(intern!(py, "model_rebuild"), (), None)
54+
.unwrap();
55+
56+
let validator = self
57+
.model
58+
.getattr(py, intern!(py, "__pydantic_validator__"))
59+
.unwrap()
60+
.downcast_bound::<SchemaValidator>(py)
61+
.unwrap()
62+
.clone();
63+
64+
// let validator = crate::schema_cache::retrieve_schema(py, self.model.as_any().clone())
65+
// .downcast_bound::<SchemaValidator>(py)
66+
// // FIXME: This actually will always trigger as we cache a `CoreSchema` lol
67+
// .expect("Cached validator was not a `SchemaValidator`")
68+
// .clone();
69+
70+
validator.get().validator.validate(py, input, state)
71+
}
72+
73+
fn get_name(&self) -> &str {
74+
&self.name
75+
}
76+
}

0 commit comments

Comments
 (0)