Skip to content

Commit 6497f58

Browse files
authored
Fix sample validation for complex types (#1973)
This pull request fixes a problem with type validation in the experimental dataset and sample modules. When we use the `is_list` boolean of `score_field`, a complex ndarray type is generated that includes an `Any` type, which is not accepted by `isinstance`. To circumvent problems with complex types, whenever `isinstance` fails with a `TypeError`, we only validate against the origin type (`ndarray` instead of `ndarray[float32]`, for example). ### Type validation improvements * Updated the `_validate_attribute_type` method in `src/datumaro/experimental/dataset.py` to correctly handle type validation for generic types by using `origin` when available, improving support for complex type annotations. ### Test enhancements * Added a new test, `test_sample_with_is_list`, in `tests/unit/experimental/test_sample.py` to verify that samples with list-type fields (using `is_list=True` in `score_field`) are created without validation errors. Resolves #1971 --------- Signed-off-by: Jort Bergfeld <[email protected]>
1 parent 42138a8 commit 6497f58

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

src/datumaro/experimental/dataset.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,22 @@ def _validate_attribute_type(self, expected_type: Any, value: Any) -> bool:
7878
# Union and Callable types have to be handled separately,
7979
# because isinstance() does not work with Callable types.
8080
origin = get_origin(expected_type)
81-
if origin is Union:
81+
if origin in {Union, types.UnionType}:
8282
# Check each type in the Union
83-
return any(self._validate_attribute_type(typ, value) for typ in get_args(expected_type))
84-
if origin in {typing.Callable, collections.abc.Callable} or expected_type in {
83+
result = any(self._validate_attribute_type(typ, value) for typ in get_args(expected_type))
84+
elif origin in {typing.Callable, collections.abc.Callable} or expected_type in {
8585
typing.Callable,
8686
collections.abc.Callable,
8787
}:
88-
return callable(value)
89-
return isinstance(value, expected_type)
88+
result = callable(value)
89+
else:
90+
try:
91+
result = isinstance(value, expected_type)
92+
except TypeError:
93+
# Some complex types cannot be validated, for example, sometimes when a numpy dtype is turned
94+
# into a list using Polars List, the resulting complex dtype will contain a generic Any.
95+
result = isinstance(value, origin)
96+
return result
9097

9198
@classmethod
9299
@cache

tests/unit/experimental/test_sample.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any
66

77
import numpy as np
8+
import numpy.typing as npt
89
import polars as pl
910
import pytest
1011

@@ -17,6 +18,7 @@
1718
bbox_field,
1819
image_field,
1920
image_info_field,
21+
numeric_field,
2022
)
2123
from datumaro.experimental.fields.images import image_path_field
2224
from datumaro.experimental.schema import Schema
@@ -185,3 +187,11 @@ class ExtendedSample(BaseSample):
185187
assert len(extended_schema.attributes) == 3
186188
assert "image_info" in extended_schema.attributes
187189
assert "image_info" not in base_schema.attributes
190+
191+
192+
def test_sample_with_is_list():
193+
class MySample(Sample):
194+
confidence: npt.NDArray[np.float32] | None = numeric_field(dtype=pl.Float32(), is_list=True)
195+
196+
# Assert that sample can be created without validation errors
197+
MySample(confidence=np.array([0.8]))

0 commit comments

Comments
 (0)