33 changes: 22 additions & 11 deletions dlt/extract/validation.py

@@ -40,18 +40,29 @@ def __call__(self, item: TDataItems, meta: Any = None) -> TDataItems:
         from dlt.common.libs.pydantic import validate_and_filter_item, validate_and_filter_items

         if isinstance(item, list):
-            return [
-                model.dict(by_alias=True)
-                for model in validate_and_filter_items(
-                    self.table_name, self.list_model, item, self.column_mode, self.data_mode
-                )
-            ]
-        item = validate_and_filter_item(
+            input_is_model = bool(item) and isinstance(item[0], PydanticBaseModel)
+            validated_list = validate_and_filter_items(
+                self.table_name, self.list_model, item, self.column_mode, self.data_mode
+            )
+            if input_is_model:
+                input_fields = set(item[0].__class__.model_fields.keys())
Collaborator (inline review comment on the line above):

    this is a good catch. The model used to validate may be different from the model instances passed to the pipeline; even if the model is the same, there are some weird cases, or revalidation was requested, so we cannot just skip validation here.

    Still, doing this check and going back to dict is IMO not intuitive and expensive. We also probe just the first element and assume that the list holds items of a uniform type.

    Overall I think we need something simpler here.
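To make the comment concrete, here is a hypothetical sketch (both models invented for illustration, not from this PR) of the case where the field-subset check fails because the model configured on the resource carries fewer fields than the instances the user yields:

from pydantic import BaseModel

class Incoming(BaseModel):  # hypothetical: what the user's instances are built from
    number: int
    secret: str

class ResourceModel(BaseModel):  # hypothetical: the model set on the resource
    number: int

# The same subset check as in the diff, applied at class level:
input_fields = set(Incoming.model_fields.keys())           # {"number", "secret"}
validated_fields = set(ResourceModel.model_fields.keys())  # {"number"}

# Not a subset: the validated instances do not cover all of the caller's
# fields, so the new code keeps the old behavior and emits dicts instead
# of handing back instances of a different model class.
assert not input_fields.issubset(validated_fields)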

+                validated_fields = set(validated_list[0].__class__.model_fields.keys())
+                if input_fields.issubset(validated_fields):
+                    return validated_list
+            return [m.dict(by_alias=True) for m in validated_list]
+
+        input_is_model = isinstance(item, PydanticBaseModel)
+        validated = validate_and_filter_item(
             self.table_name, self.model, item, self.column_mode, self.data_mode
         )
-        if item is not None:
-            item = item.dict(by_alias=True)
-        return item
+        if validated is None:
+            return None
+        if input_is_model:
+            input_fields = set(item.__class__.model_fields.keys())
+            validated_fields = set(validated.__class__.model_fields.keys())
+            if input_fields.issubset(validated_fields):
+                return validated
+        return validated.dict(by_alias=True)

     def __str__(self, *args: Any, **kwargs: Any) -> str:
         return f"PydanticValidator(model={self.model.__qualname__})"
@@ -93,4 +104,4 @@ def create_item_validator(
             ),
             schema_contract or expanded_schema_contract,
         )
-    return None, schema_contract
\ No newline at end of file
+    return None, schema_contract
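One detail worth noting on the fallback path above: the conversion uses dict(by_alias=True), so aliased fields are emitted under their alias names. A minimal sketch with an invented model (in Pydantic v2 terms, model_dump(by_alias=True) is the non-deprecated equivalent):

from pydantic import BaseModel, Field

class Row(BaseModel):  # hypothetical model with an aliased field
    user_id: int = Field(alias="userId")

row = Row(userId=7)
print(row.dict(by_alias=True))   # {'userId': 7}  -- alias name preserved
print(row.dict(by_alias=False))  # {'user_id': 7} -- field name instead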
50 changes: 50 additions & 0 deletions tests/pipeline/test_pipeline_extra.py

@@ -855,3 +855,53 @@ def inconsistent_data(dtype: str):
     # generates variant column on non-nullable column. original "foo" will receive null
     pipeline.run(inconsistent_data("text"))
     assert pip_ex.value.step == "normalize"
+
+
+@pytest.mark.parametrize(
+    "as_model, as_list",
+    [
+        (False, False),
+        (True, False),
+        (False, True),
+        (True, True),
+    ],
+)
+def test_pydantic_validator_preserves_model_instances(as_model, as_list):
+    class Result(BaseModel):
+        number: int
+
+    @dlt.resource(columns=Result)
+    def data():
+        if as_model:
+            item = Result(number=1)
+        else:
+            item = {"number": 1}  # type: ignore[assignment]
+        if as_list:
+            yield [item, item, item]
+        else:
+            yield item
+
+    seen = []
+
+    @dlt.transformer(data_from=data)
+    def check(x):
+        seen.append(x)
+        yield x
+
+    pipeline = dlt.pipeline(destination="duckdb", dev_mode=True)
+    pipeline.run(check)
+
+    assert len(seen) == 1
+    v = seen[0]
+
+    if as_list:
+        assert isinstance(v, list)
+        if as_model:
+            assert all(isinstance(el, BaseModel) for el in v)
+        else:
+            assert all(isinstance(el, dict) for el in v)
+    else:
+        if as_model:
+            assert isinstance(v, BaseModel)
+        else:
+            assert isinstance(v, dict)
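For reviewers who want to exercise only the new test, one way to invoke it through pytest's Python entry point (assuming a dev install with the duckdb extra, since the pipeline targets duckdb):

import pytest

# Select just the new parametrized test; all four parameter combinations run.
pytest.main([
    "tests/pipeline/test_pipeline_extra.py",
    "-k", "test_pydantic_validator_preserves_model_instances",
])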