Skip to content

Commit 8047313

Browse files
Fix RGB/BGR color converter and legacy dataset detection (#1972)
Fix the RGB color converter. It did not properly handle images returned from the Polars DataFrame, which are stored as flattened lists. Also add support for BGR to RGB conversion. Make the legacy converter more robust by adding explicit arguments (hierarchical, multi-label, anomaly) that specify which kind of dataset is being converted, because it is not always possible to detect a hierarchical dataset from its contents alone. <!-- Please add a summary of changes. You may use Copilot to auto-generate the PR description but please consider including any other relevant facts which Copilot may be unaware of (such as design choices and testing procedure). Add references to the relevant issues and pull requests if any like so: Resolves #111 and #222. Depends on #1000 (for series of dependent commits). --> ### Checklist <!-- Put an 'x' in all the boxes that apply --> - [ ] I have added tests to cover my changes or documented any manual tests. - [ ] I have updated the [documentation](https://github.com/open-edge-platform/datumaro/tree/develop/docs) accordingly --------- Signed-off-by: Albert van Houten <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 3f3287b commit 8047313

File tree

5 files changed

+49
-43
lines changed

5 files changed

+49
-43
lines changed

src/datumaro/experimental/converters/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
ImageCallableToImageConverter,
1717
ImagePathToImageConverter,
1818
ImageToImageInfo,
19-
RGBToBGRConverter,
19+
RedBlueColorConverter,
2020
UInt8ToFloat32Converter,
2121
)
2222
from datumaro.experimental.converters.mask_converters import (
@@ -67,7 +67,7 @@
6767
# Mask converters
6868
"PolygonToMaskConverter",
6969
# Image converters
70-
"RGBToBGRConverter",
70+
"RedBlueColorConverter",
7171
"RotatedBBoxToPolygonConverter",
7272
"UInt8ToFloat32Converter",
7373
"_can_lazy_converter_handle_conversion",

src/datumaro/experimental/converters/image_converters.py

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,68 +16,58 @@
1616

1717

1818
@converter
19-
class RGBToBGRConverter(Converter):
19+
class RedBlueColorConverter(Converter):
2020
"""
21-
Converter that transforms RGB image format to BGR format.
22-
23-
This converter swaps the red and blue channels of RGB images to produce
24-
BGR format images, commonly used for OpenCV compatibility.
21+
Converter that transforms RGB image format to BGR format and vice-versa.
2522
"""
2623

2724
input_image: AttributeSpec[ImageField]
2825
output_image: AttributeSpec[ImageField]
2926

3027
def filter_output_spec(self) -> bool:
3128
"""
32-
Check if input is RGB and configure output for BGR conversion.
29+
Check if input is RGB/BGR and configure output for BGR/RGB conversion.
3330
3431
Returns:
35-
True if the converter should be applied (RGB to BGR), False otherwise
32+
True if the converter should be applied (RGB to BGR/BGR to RGB), False otherwise
3633
"""
3734
input_format = self.input_image.field.format
3835
output_format = self.output_image.field.format
3936

40-
# Configure output specification for BGR format
4137
self.output_image = AttributeSpec(
4238
name=self.output_image.name,
4339
field=ImageField(
4440
semantic=self.input_image.field.semantic,
4541
dtype=self.input_image.field.dtype,
4642
channels_first=self.output_image.field.channels_first,
47-
format="BGR", # Set output format to BGR
43+
format=self.output_image.field.format,
4844
),
4945
)
5046

51-
# Only apply if input is RGB and output should be BGR
52-
return input_format == "RGB" and output_format == "BGR"
47+
return (input_format == "RGB" and output_format == "BGR") or (input_format == "BGR" and output_format == "RGB")
5348

5449
def convert(self, df: pl.DataFrame) -> pl.DataFrame:
55-
"""
56-
Convert RGB image format to BGR using numpy channel swapping.
57-
58-
Args:
59-
df: Input DataFrame containing RGB image data
60-
61-
Returns:
62-
DataFrame with BGR image data in the output column
63-
"""
6450
input_column_name = self.input_image.name
6551
output_column_name = self.output_image.name
6652

6753
input_shape_column_name = self.input_image.name + "_shape"
6854
output_shape_column_name = self.output_image.name + "_shape"
6955

70-
def rgb_to_bgr(tensor_data: pl.Series) -> Any:
71-
"""Convert RGB tensor data to BGR by reversing the channel order."""
72-
data = tensor_data.to_numpy().copy()
73-
data = data.reshape(-1, 3)
74-
data = np.flip(data, 1) # Flip along channel axis
75-
return data.reshape(-1)
56+
expected_dtype = polars_to_numpy_dtype(self.input_image.field.dtype)
57+
58+
def red_blue_swap(row: dict[str, Any]) -> Any:
59+
"""Swaps the first and third channels in the image. Images are stored as a flattened list in polars"""
60+
flat_data = np.array(row[input_column_name], dtype=expected_dtype, copy=True)
61+
shape = tuple(row[input_shape_column_name])
62+
reshaped = flat_data.reshape(shape)
63+
swapped = reshaped[..., ::-1]
64+
return np.asarray(swapped.reshape(-1), dtype=expected_dtype)
7665

77-
dtype = df.schema[input_column_name]
7866
# Apply the conversion using map_elements for efficient processing
7967
return df.with_columns(
80-
pl.col(input_column_name).map_elements(rgb_to_bgr, return_dtype=dtype).alias(output_column_name),
68+
pl.struct([pl.col(input_column_name), pl.col(input_shape_column_name)])
69+
.map_elements(red_blue_swap, return_dtype=pl.List(self.input_image.field.dtype))
70+
.alias(output_column_name),
8171
pl.col(input_shape_column_name).alias(output_shape_column_name),
8272
)
8373

src/datumaro/experimental/legacy/dataset_converters.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,21 @@ class AnalysisResult:
4444
is_hierarchical: bool
4545

4646

47-
def analyze_legacy_dataset(legacy_dataset: LegacyDataset, semantic: Semantic = Semantic.Default) -> AnalysisResult:
47+
def analyze_legacy_dataset(
48+
legacy_dataset: LegacyDataset,
49+
semantic: Semantic = Semantic.Default,
50+
hierarchical: bool = False,
51+
multi_label: bool = False,
52+
anomaly: bool = False,
53+
) -> AnalysisResult:
4854
"""Analyze legacy dataset and generate schema using registered converters.
4955
5056
Args:
5157
legacy_dataset: The legacy Datumaro dataset to analyze
5258
semantic: The semantic type for the converted fields
59+
hierarchical: Boolean indicating if the dataset should be treated as hierarchical
60+
multi_label: Boolean indicating if the dataset should be treated as multi-label
61+
anomaly: Boolean indicating if the dataset should be treated as anomaly
5362
5463
Returns:
5564
AnalysisResult containing the inferred schema and converters
@@ -70,11 +79,11 @@ def analyze_legacy_dataset(legacy_dataset: LegacyDataset, semantic: Semantic = S
7079

7180
# Check if project has a hierarchical structure
7281
label_names = [item.name for item in legacy_dataset.categories()[AnnotationType.label].items]
73-
is_hierarchical = _has_derived_labels(label_group_names) or _has_derived_labels(label_names)
82+
is_hierarchical = (_has_derived_labels(label_group_names) or _has_derived_labels(label_names)) or hierarchical
7483

7584
# Look for multi label classification groups
7685
multi_label_group_names = [name for name in label_group_names if name.startswith("Classification labels__")]
77-
is_multi_label = len(multi_label_group_names) > 1 and not is_hierarchical
86+
is_multi_label = (len(multi_label_group_names) > 1 and not is_hierarchical) or multi_label
7887
else:
7988
is_hierarchical = False
8089
is_multi_label = False
@@ -89,7 +98,7 @@ def analyze_legacy_dataset(legacy_dataset: LegacyDataset, semantic: Semantic = S
8998

9099
# If we have a label converter plus other converters, assume that this is an anomaly task.
91100
# To avoid conflicts between the label attribute and the other ones, use semantic to distinguish them.
92-
is_anomaly = AnnotationType.label in ann_types and len(ann_types) > 1
101+
is_anomaly = (AnnotationType.label in ann_types and len(ann_types) > 1) or anomaly
93102

94103
for ann_type in ann_types:
95104
ann_semantic = Semantic.Anomaly if is_anomaly and ann_type != AnnotationType.label else semantic
@@ -177,11 +186,16 @@ def _convert_legacy_item(item: DatasetItem, analysis_result: AnalysisResult) ->
177186
return attributes
178187

179188

180-
def convert_from_legacy(legacy_dataset: LegacyDataset) -> Dataset[Sample]:
189+
def convert_from_legacy(
190+
legacy_dataset: LegacyDataset, hierarchical: bool = False, multi_label: bool = False, anomaly: bool = False
191+
) -> Dataset[Sample]:
181192
"""Convert legacy dataset to new dataset format with automatic schema inference.
182193
183194
Args:
184195
legacy_dataset: The legacy Datumaro dataset to convert
196+
hierarchical: If True, forces hierarchical classification; otherwise, uses automatic detection.
197+
multi_label: If True, forces multi-label classification; otherwise, uses automatic detection.
198+
anomaly: If True, forces anomaly detection; otherwise, uses automatic detection.
185199
Returns:
186200
A new Dataset with inferred schema and converted data
187201
@@ -194,7 +208,9 @@ def convert_from_legacy(legacy_dataset: LegacyDataset) -> Dataset[Sample]:
194208
"""
195209

196210
# Step 1: Analyze dataset to infer schema
197-
analysis_result = analyze_legacy_dataset(legacy_dataset)
211+
analysis_result = analyze_legacy_dataset(
212+
legacy_dataset, hierarchical=hierarchical, multi_label=multi_label, anomaly=anomaly
213+
)
198214

199215
# Step 2: Create new dataset with inferred schema
200216
experimental_dataset = Dataset(analysis_result.schema)

src/datumaro/experimental/legacy/media_converters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from datumaro import Dataset as LegacyDataset
1212
from datumaro import DatasetItem, Image, MediaElement
13-
from datumaro.components.media import FromDataMixin, FromFileMixin
13+
from datumaro.components.media import FromDataMixin, FromFileMixin, ImageFromBytes
1414
from datumaro.experimental import AttributeInfo, Sample, Schema, Semantic
1515
from datumaro.experimental.fields import (
1616
ImageInfo,
@@ -172,7 +172,7 @@ def get_schema_attributes(self) -> dict[str, AttributeInfo]:
172172
def convert_item_media(self, item: DatasetItem) -> dict[str, Any]:
173173
result: dict[str, Any] = {}
174174

175-
if isinstance(item.media, Image): # pyright: ignore[reportUnknownMemberType]
175+
if isinstance(item.media, (Image, ImageFromBytes)):
176176
if self.media_mixin == FromDataMixin:
177177
if self.has_callable_data:
178178
# Use a top-level callable to ensure picklability across workers

tests/unit/experimental/test_converters.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
PolygonToBBoxConverter,
2727
PolygonToInstanceMaskConverter,
2828
PolygonToMaskConverter,
29-
RGBToBGRConverter,
29+
RedBlueColorConverter,
3030
RotatedBBoxToPolygonConverter,
3131
UInt8ToFloat32Converter,
3232
converter,
@@ -82,12 +82,12 @@ def convert(self, df: pl.DataFrame) -> pl.DataFrame:
8282

8383
def test_rgb_to_bgr_converter():
8484
"""Test RGB to BGR format conversion."""
85-
converter_instance = RGBToBGRConverter() # type: ignore[call-arg]
85+
converter_instance = RedBlueColorConverter() # type: ignore[call-arg]
8686

8787
# Create test data
88-
rgb_data = np.array([[[255, 0, 0], [0, 255, 0]], [[0, 0, 255], [128, 128, 128]]])
88+
rgb_data = np.array([255, 0, 0, 0, 255, 0, 0, 0, 255, 128, 128, 128])
8989
df = pl.DataFrame(
90-
{"image": [rgb_data.reshape(-1)], "image_shape": [[2, 2, 3]]},
90+
{"image": [rgb_data], "image_shape": [[2, 2, 3]]},
9191
schema=pl.Schema({"image": pl.List(pl.UInt8()), "image_shape": pl.List(pl.Int64)}),
9292
)
9393

@@ -443,7 +443,7 @@ def test_multiple_converter_chaining():
443443
# If successful, should have multiple steps
444444
assert len(path.converters["image"]) == 2
445445
assert type(path.converters["image"][0]) is UInt8ToFloat32Converter
446-
assert type(path.converters["image"][1]) is RGBToBGRConverter
446+
assert type(path.converters["image"][1]) is RedBlueColorConverter
447447

448448
# FIXME(gdlg): the BBoxCoordinateConverter needs an image
449449
# and it does not matter if the image is 8 bits or 32 bits,

0 commit comments

Comments
 (0)