Skip to content

Commit 6417998

Browse files
Additional fields and type inferring for union (#1834)
<!-- Contributing guide: https://github.com/open-edge-platform/datumaro/blob/develop/CONTRIBUTING.md --> ### Summary This pull request introduces robust support for Python `Union` types in the experimental Datumaro type registry and dataset schema inference. It enables seamless conversion between multiple candidate types (including both `typing.Union` and modern `A | B` syntax), with fallback logic and comprehensive test coverage. The changes also improve image type conversion and schema inference for datasets, making the system more flexible and reliable. ### Type registry and conversion improvements * Added full support for `Union` types in the type registry: both `typing.Union` and Python 3.10+ `A | B` syntax are now handled, with fallback to subsequent types if the first conversion fails. This includes updated logic in `from_polars_data` and new tests for ordering, error handling, and fallback behavior. [[1]](diffhunk://#diff-e324261812079d99ca2989612441e5df1dd15dabde37fb2e5e8c0c1b639dac0dR122-R154) [[2]](diffhunk://#diff-e324261812079d99ca2989612441e5df1dd15dabde37fb2e5e8c0c1b639dac0dR170-R269) [[3]](diffhunk://#diff-30f23b2869128577a39c918ed25c78229a30cb96578c33728d45e5ebce740ac2R1-R162) * Added comprehensive tests for type registry conversions, including basic types, union types, error cases, ordering, and converter functionality for numpy and torch tensors. ### Dataset and schema inference enhancements * Improved schema inference in `Dataset` to resolve string annotations to actual type objects, supporting cases where `from __future__ import annotations` is used, and added correct handling for `Union` types to preserve the original annotation. [[1]](diffhunk://#diff-4ac196ddc4dc8e6d33daf684ded18886ff8774fadb8b6cbd4bfa88ca424bb34fR65-R80) [[2]](diffhunk://#diff-4ac196ddc4dc8e6d33daf684ded18886ff8774fadb8b6cbd4bfa88ca424bb34fR94-R110) * Updated type variable definitions and method signatures in `dataset.py` for clarity and correctness, and removed unnecessary imports. [[1]](diffhunk://#diff-4ac196ddc4dc8e6d33daf684ded18886ff8774fadb8b6cbd4bfa88ca424bb34fR19-R25) [[2]](diffhunk://#diff-4ac196ddc4dc8e6d33daf684ded18886ff8774fadb8b6cbd4bfa88ca424bb34fL105-R128) [[3]](diffhunk://#diff-4ac196ddc4dc8e6d33daf684ded18886ff8774fadb8b6cbd4bfa88ca424bb34fL134-R157) ### API and import improvements * Updated the experimental module’s public API to expose new converters, dataset classes, fields, schema types, and registry functions. ### Test coverage * Added targeted tests for union type handling in dataset samples, ensuring both modern and legacy union syntax are supported. These changes significantly improve the flexibility and reliability of type conversion and schema inference in Datumaro’s experimental pipeline. <!-- Resolves #111 and #222. Depends on #1000 (for series of dependent commits). This PR introduces this capability to make the project better in this and that. - Added this feature - Removed that feature - Fixed the problem #1234 --> ### How to test <!-- Describe the testing procedure for reviewers, if changes are not fully covered by unit tests or manual testing can be complicated. --> ### Checklist <!-- Put an 'x' in all the boxes that apply --> - [ ] I have added unit tests to cover my changes.​ - [ ] I have added integration tests to cover my changes.​ - [ ] I have added the description of my changes into [CHANGELOG](https://github.com/open-edge-platform/datumaro/blob/develop/CHANGELOG.md).​ - [ ] I have updated the [documentation](https://github.com/open-edge-platform/datumaro/tree/develop/docs) accordingly ### License - [ ] I submit _my code changes_ under the same [MIT License](https://github.com/open-edge-platform/datumaro/blob/develop/LICENSE) that covers the project. Feel free to contact the maintainers if that's a concern. - [ ] I have updated the license header for each file (see an example below). ```python # Copyright (C) 2025 Intel Corporation # # SPDX-License-Identifier: MIT ```
2 parents 3756f46 + ec296a3 commit 6417998

File tree

7 files changed

+420
-11
lines changed

7 files changed

+420
-11
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
11
# Copyright (C) 2025 Intel Corporation
22
#
33
# SPDX-License-Identifier: MIT
4+
5+
from .converter_registry import ConverterRegistry, converter, find_conversion_path
6+
from .dataset import Dataset, Sample
7+
from .fields import (
8+
BBoxField,
9+
ImageField,
10+
ImageInfoField,
11+
ImagePathField,
12+
LabelField,
13+
TensorField,
14+
bbox_field,
15+
image_field,
16+
image_info_field,
17+
image_path_field,
18+
label_field,
19+
tensor_field,
20+
)
21+
from .schema import AttributeInfo, Field, Schema, Semantic
22+
from .type_registry import register_from_polars_converter, register_numpy_converter

src/datumaro/experimental/dataset.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
from __future__ import annotations
66

7-
import copy
7+
import sys
8+
import types
89
from functools import cache
910
from typing import (
1011
TYPE_CHECKING,
12+
Annotated,
1113
Any,
1214
Dict,
1315
Generic,
@@ -16,12 +18,13 @@
1618
Type,
1719
Union,
1820
cast,
21+
dataclass_transform,
1922
get_args,
2023
get_origin,
2124
)
2225

2326
import polars as pl
24-
from typing_extensions import Annotated, TypeGuard, TypeVar, dataclass_transform
27+
from typing_extensions import TypeGuard, TypeVar
2528

2629
from .converter_registry import Converter, find_conversion_path
2730
from .schema import AttributeInfo, Field, Schema
@@ -61,8 +64,21 @@ def infer_schema(cls) -> Schema:
6164
Raises:
6265
TypeError: If attributes don't have proper Field annotations
6366
"""
67+
6468
attributes: dict[str, AttributeInfo] = {}
6569
for name, annotation in cls.__annotations__.items():
70+
# Resolve string annotations to actual type objects
71+
# This handles cases where `from __future__ import annotations` is used
72+
if isinstance(annotation, str):
73+
try:
74+
# Get the module where the class is defined to resolve annotations
75+
module = sys.modules[cls.__module__]
76+
annotation = eval(annotation, module.__dict__)
77+
except Exception as e:
78+
raise TypeError(
79+
f"Failed to resolve type annotation '{annotation}' for attribute '{name}': {e}"
80+
)
81+
6682
origin = get_origin(annotation)
6783
if origin is Annotated:
6884
# Handle Annotated[Type, Field] approach
@@ -78,13 +94,18 @@ def infer_schema(cls) -> Schema:
7894
# Extract base class from generic types like MyClass[A, B, C] -> MyClass
7995
type_origin = get_origin(annotation)
8096

81-
final_type = type_origin if type_origin is not None else annotation
97+
# For Union types, keep the original annotation (the Union instance)
98+
# instead of the origin (which is just the UnionType class)
99+
if isinstance(annotation, types.UnionType) or type_origin is Union:
100+
final_type = annotation
101+
else:
102+
final_type = type_origin if type_origin is not None else annotation
82103
attributes[name] = AttributeInfo(type=final_type, annotation=field_annotation)
83104
return Schema(attributes=attributes)
84105

85106

86-
DType = TypeVar("DType", bound=Sample, default=Sample)
87-
DTargetType = TypeVar("DTargetType", bound=Sample, default=Sample)
107+
DType = TypeVar("DType", bound=Sample)
108+
DTargetType = TypeVar("DTargetType", bound=Sample)
88109

89110

90111
class Dataset(Generic[DType]):
@@ -102,7 +123,7 @@ class Dataset(Generic[DType]):
102123
def __init__(
103124
self,
104125
dtype_or_schema: Union[Schema, Type[DType]],
105-
categories: Dict[str, "Categories"] = None,
126+
categories: Categories = None,
106127
):
107128
"""
108129
Initialize dataset with either a schema or sample type.
@@ -131,7 +152,7 @@ def from_dataframe(
131152
df: pl.DataFrame,
132153
dtype_or_schema: Union[Schema, Type[DTargetType]],
133154
lazy_converters: List[Converter] | None = None,
134-
categories: Dict[str, "Categories"] = None,
155+
categories: Dict[str, Categories] = None,
135156
) -> "Dataset[DTargetType]":
136157
"""
137158
Create a Dataset from an existing DataFrame and lazy converters.
@@ -282,8 +303,7 @@ def convert_to_schema(
282303
A new Dataset instance with the converted schema
283304
"""
284305
# Import the converter implementations to register them
285-
# ruff: noqa: F401
286-
import datumaro.experimental.converters # pyright: ignore [reportUnusedImport, reportMissingImports]
306+
import datumaro.experimental.converters # type: ignore[import] # noqa: F401
287307

288308
# Determine target schema
289309
if isinstance(target_dtype_or_schema, Schema):

src/datumaro/experimental/fields.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,57 @@ def image_path_field(semantic: Semantic = Semantic.Default) -> Any:
246246
ImagePathField instance configured with the given semantic tags
247247
"""
248248
return ImagePathField(semantic=semantic)
249+
250+
251+
@dataclass(frozen=True)
252+
class LabelField(Field):
253+
"""
254+
Represents a unified label annotation field that supports both single and multi-label scenarios.
255+
256+
This field automatically detects whether the input is a single label or multiple labels
257+
and handles the conversion accordingly:
258+
- Single labels: stored as Int32
259+
- Multi-labels: stored as List(Int32)
260+
"""
261+
262+
semantic: Semantic
263+
dtype: Any
264+
multi_label: bool = False # Flag to indicate if this field should handle multi-labels
265+
266+
def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
267+
"""Generate schema based on whether this is single or multi-label."""
268+
if self.multi_label:
269+
return {name: pl.List(self.dtype)}
270+
return {name: self.dtype}
271+
272+
def to_polars(self, name: str, value: Any) -> dict[str, pl.Series]:
273+
"""Convert label(s) to Polars format for single or multi-label cases."""
274+
if value is None:
275+
return {name: pl.Series(name, [None], dtype=self.dtype)}
276+
277+
if self.multi_label:
278+
return {name: pl.Series(name, [to_numpy(value)], dtype=pl.List(self.dtype))}
279+
280+
return {name: pl.Series(name, [value], dtype=self.dtype)}
281+
282+
def from_polars(self, name: str, row_index: int, df: pl.DataFrame, target_type: type[T]) -> T:
283+
"""Reconstruct label(s) from Polars data."""
284+
data = df[name][row_index]
285+
return from_polars_data(data, target_type)
286+
287+
288+
def label_field(
289+
dtype: Any = pl.Int32(), semantic: Semantic = Semantic.Default, multi_label: bool = False
290+
) -> Any:
291+
"""
292+
Create a LabelField instance with the specified parameters.
293+
294+
Args:
295+
dtype: Polars data type for label values (defaults to pl.Int32())
296+
semantic: Semantic tags describing the label purpose (optional)
297+
multi_label: Whether this field should handle multiple labels (defaults to False)
298+
299+
Returns:
300+
LabelField instance configured with the given parameters
301+
"""
302+
return LabelField(semantic=semantic, dtype=dtype, multi_label=multi_label)

src/datumaro/experimental/legacy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def analyze_legacy_dataset(legacy_dataset: LegacyDataset) -> AnalysisResult:
194194
attributes.update(media_converter.get_schema_attributes())
195195
except ValueError:
196196
# No converter for this media type - skip
197-
media_converter = None
197+
pass
198198

199199
# Get annotation attributes from converters
200200
for ann_type in ann_types:

src/datumaro/experimental/type_registry.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
DataFrames. New types can be registered at runtime without modifying core code.
1010
"""
1111

12-
from typing import Any, Callable
12+
import types
13+
from typing import Any, Callable, Union
1314

1415
import numpy as np
1516
import polars as pl
@@ -118,9 +119,39 @@ def from_polars_data(polars_data: Any, target_type: type) -> Any:
118119
>>> isinstance(tensor, torch.Tensor)
119120
True
120121
"""
122+
# Handle direct type matches first
121123
if target_type in _from_polars_converters:
122124
return _from_polars_converters[target_type](polars_data)
123125

126+
# Handle Union types (e.g., torch.Tensor | np.ndarray)
127+
# Check if target_type is a Union type (Python 3.10+ style or typing.Union)
128+
is_union = False
129+
union_args = None
130+
131+
# Check for types.UnionType (Python 3.10+ syntax: A | B)
132+
if isinstance(target_type, types.UnionType):
133+
is_union = True
134+
union_args = target_type.__args__
135+
136+
# Check for typing.Union (older syntax: Union[A, B])
137+
try:
138+
from typing import get_args, get_origin
139+
140+
if get_origin(target_type) is Union:
141+
is_union = True
142+
union_args = get_args(target_type)
143+
except Exception:
144+
pass
145+
146+
if is_union and union_args:
147+
# Try each type in the union until one succeeds
148+
for union_type in union_args:
149+
if union_type in _from_polars_converters:
150+
try:
151+
return _from_polars_converters[union_type](polars_data)
152+
except KeyError:
153+
# If conversion fails, try the next type in the union
154+
continue
124155
raise TypeError(f"No converter registered for type {target_type}")
125156

126157

@@ -136,3 +167,101 @@ def from_polars_data(polars_data: Any, target_type: type) -> Any:
136167
) # pyright: ignore[reportUnknownMemberType, reportUnknownLambdaType, reportUnknownArgumentType]
137168
except ImportError:
138169
pass
170+
171+
172+
# Register PIL Image converters if available
173+
try:
174+
from PIL import Image
175+
176+
register_numpy_converter(Image.Image, lambda x: np.array(x))
177+
register_from_polars_converter(Image.Image, lambda x: Image.fromarray(np.array(x)))
178+
except ImportError:
179+
pass
180+
181+
182+
def convert_image_type(image: Any, target_type: type) -> Any:
183+
"""
184+
Convert an image between different types (numpy, PIL, torch).
185+
This function provides direct conversion between image types using
186+
the registered converters in the type registry.
187+
Args:
188+
image: Source image (numpy.ndarray, PIL.Image.Image, or torch.Tensor)
189+
target_type: Target type to convert to
190+
Returns:
191+
Image converted to the target type
192+
Raises:
193+
TypeError: If source or target type is not supported
194+
Example:
195+
>>> import numpy as np
196+
>>> from PIL import Image
197+
>>> import torch
198+
>>>
199+
>>> # Convert numpy array to PIL Image
200+
>>> np_image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
201+
>>> pil_image = convert_image_type(np_image, Image.Image)
202+
>>>
203+
>>> # Convert PIL Image to torch tensor
204+
>>> torch_image = convert_image_type(pil_image, torch.Tensor)
205+
"""
206+
current_type = type(image)
207+
208+
# Define supported image types - only numpy, PIL Image, and torch Tensor
209+
supported_image_types = get_supported_image_types()
210+
211+
# Validate that target_type is a supported image type
212+
if target_type not in supported_image_types:
213+
supported_names = [t.__name__ for t in supported_image_types]
214+
raise TypeError(
215+
f"Target type {target_type.__name__} not supported. Supported image types: {supported_names}"
216+
)
217+
218+
# If already the target type, return as-is
219+
if current_type == target_type:
220+
return image
221+
222+
# Convert via numpy as intermediate format
223+
try:
224+
# First convert to numpy if not already
225+
if current_type == np.ndarray:
226+
numpy_image = image
227+
else:
228+
numpy_image = to_numpy(image)
229+
230+
# Then convert from numpy to target type
231+
if target_type == np.ndarray:
232+
return numpy_image
233+
else:
234+
# Convert numpy to target via polars-style conversion
235+
return _from_polars_converters[target_type](numpy_image)
236+
237+
except Exception as e:
238+
raise TypeError(f"Cannot convert from {current_type} to {target_type}: {e}")
239+
240+
241+
def get_supported_image_types() -> list[type]:
242+
"""
243+
Get a list of all supported image types for conversion.
244+
Returns:
245+
List of supported image types
246+
"""
247+
supported_types = [np.ndarray] # numpy is always supported
248+
249+
# Add conditionally available types
250+
try:
251+
from PIL import Image
252+
253+
if Image.Image in _from_polars_converters:
254+
supported_types.append(Image.Image)
255+
except ImportError:
256+
pass
257+
258+
# Check for torch
259+
try:
260+
import torch
261+
262+
if torch.Tensor in _from_polars_converters:
263+
supported_types.append(torch.Tensor)
264+
except ImportError:
265+
pass
266+
267+
return supported_types

tests/unit/experimental/test_dataset.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,3 +599,28 @@ class TestSample(Sample):
599599

600600
assert len(schema3.attributes["image"].categories) == 2 # car, truck
601601
assert len(schema1.attributes["bbox"].categories) == 1 # person
602+
603+
604+
def test_union_type_handling():
605+
"""Test Union type handling with both modern (A | B) and typing.Union syntax."""
606+
try:
607+
import torch
608+
except ImportError:
609+
pytest.skip("PyTorch not available")
610+
611+
from typing import Union
612+
613+
from datumaro.experimental.type_registry import from_polars_data
614+
615+
# Modern syntax
616+
union_type_modern = torch.Tensor | np.ndarray
617+
polars_data = [1.0, 2.0, 3.0]
618+
result = from_polars_data(polars_data, union_type_modern)
619+
assert isinstance(result, torch.Tensor)
620+
assert result.tolist() == [1.0, 2.0, 3.0]
621+
622+
# typing.Union syntax
623+
union_type_typing = Union[torch.Tensor, np.ndarray]
624+
result2 = from_polars_data(polars_data, union_type_typing)
625+
assert isinstance(result2, torch.Tensor)
626+
assert result2.tolist() == [1.0, 2.0, 3.0]

0 commit comments

Comments
 (0)