Skip to content
17 changes: 13 additions & 4 deletions python/ray/data/_internal/arrow_ops/transform_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,10 +481,19 @@ def _backfill_missing_fields(
ndim=field_type.ndim
)

# The schema should already be unified by unify_schemas, so types
# should be compatible. If not, let the error propagate up.
# No explicit casting needed - PyArrow will handle type compatibility
# during struct creation or raise appropriate errors.
# Handle type mismatches for primitive types
# The schema should already be unified by unify_schemas, but
# struct field types may still diverge (e.g., int64 vs float64).
# Cast the existing array to match the unified struct field type.
elif current_array.type != field_type:
try:
current_array = current_array.cast(field_type)
except pa.ArrowInvalid as e:
raise ValueError(
f"Cannot cast struct field '{field_name}' from "
f"{current_array.type} to {field_type}: {e}"
) from e

aligned_fields.append(current_array)
else:
# If the field is missing, fill with nulls
Expand Down
29 changes: 29 additions & 0 deletions python/ray/data/tests/test_map_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,35 @@ async def __call__(self, batch):
assert len(output) == len(expected_output), (len(output), len(expected_output))


def test_map_batches_struct_field_type_divergence(shutdown_only):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition to this E2E test, I think we should add a unit test for this bug as well (maybe at the concat function layer of abstraction). Unit tests are not only much faster to run, but also serve as documentation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added unit test in: 6a35990

"""Test map_batches with struct fields that have diverging primitive types."""

def generator_fn(batch):
for i, row_id in enumerate(batch["id"]):
if i % 2 == 0:
# Yield struct with fields (a: int64, b: string)
yield {"data": [{"a": 1, "b": "hello"}]}
else:
# Yield struct with fields (a: float64, c: int32)
# Field 'a' has different type, field 'b' missing, field 'c' new
yield {"data": [{"a": 1.5, "c": 100}]}

ds = ray.data.range(4, override_num_blocks=1)
ds = ds.map_batches(generator_fn, batch_size=4)
result = ds.materialize()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think this materialize is redundant. take_all() already materializes the dataset


rows = result.take_all()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redundant materialize() call before take_all()

Low Severity

The materialize() call on line 824 is redundant because take_all() on line 826 already materializes the dataset. This was noted in the PR review comments. The extra call doesn't cause incorrect behavior but adds unnecessary overhead and clutters the test code.

Fix in Cursor Fix in Web

assert len(rows) == 4

# Rows 0 and 2 should have int cast to float, with c=None
assert rows[0]["data"] == {"a": 1.0, "b": "hello", "c": None}
assert rows[2]["data"] == {"a": 1.0, "b": "hello", "c": None}

# Rows 1 and 3 should have float a, with b=None
assert rows[1]["data"] == {"a": 1.5, "b": None, "c": 100}
assert rows[3]["data"] == {"a": 1.5, "b": None, "c": 100}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though I don't think this is possible given the current Ray Data implementation, this row ordering isn't guaranteed by the interface. This has historically been the cause of a lot of Ray Data's flaky tests.

Could you refactor this test so that it doesn't depend on a particular test ordering?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in 773cd0d



if __name__ == "__main__":
import sys

Expand Down