Skip to content

Commit 60c2d36

Browse files
[Data] ArrowInvalid error when you backfill missing fields from map tasks (ray-project#60643)
## Description

Try type casting if struct field types mismatch when backfilling missing fields.

## Related issues

Closes ray-project#60628

---------

Signed-off-by: machichima <nary12321@gmail.com>
Co-authored-by: Balaji Veeramani <balaji@anyscale.com>
1 parent e09b889 commit 60c2d36

File tree

3 files changed

+111
-4
lines changed

3 files changed

+111
-4
lines changed

python/ray/data/_internal/arrow_ops/transform_pyarrow.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -481,10 +481,19 @@ def _backfill_missing_fields(
481481
ndim=field_type.ndim
482482
)
483483

484-
# The schema should already be unified by unify_schemas, so types
485-
# should be compatible. If not, let the error propagate up.
486-
# No explicit casting needed - PyArrow will handle type compatibility
487-
# during struct creation or raise appropriate errors.
484+
# Handle type mismatches for primitive types
485+
# The schema should already be unified by unify_schemas, but
486+
# struct field types may still diverge (e.g., int64 vs float64).
487+
# Cast the existing array to match the unified struct field type.
488+
elif current_array.type != field_type:
489+
try:
490+
current_array = current_array.cast(field_type)
491+
except pa.ArrowInvalid as e:
492+
raise ValueError(
493+
f"Cannot cast struct field '{field_name}' from "
494+
f"{current_array.type} to {field_type}: {e}"
495+
) from e
496+
488497
aligned_fields.append(current_array)
489498
else:
490499
# If the field is missing, fill with nulls

python/ray/data/tests/test_map_batches.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
import pytest
1313

1414
import ray
15+
from ray.data._internal.arrow_ops.transform_pyarrow import (
16+
MIN_PYARROW_VERSION_TYPE_PROMOTION,
17+
)
18+
from ray.data._internal.utils.arrow_utils import get_pyarrow_version
1519
from ray.data.context import DataContext
1620
from ray.data.dataset import Dataset
1721
from ray.data.exceptions import UserCodeException
@@ -798,6 +802,42 @@ async def __call__(self, batch):
798802
assert len(output) == len(expected_output), (len(output), len(expected_output))
799803

800804

805+
@pytest.mark.skipif(
806+
get_pyarrow_version() < MIN_PYARROW_VERSION_TYPE_PROMOTION,
807+
reason="Requires PyArrow >= 14.0.0 for type promotion in nested struct fields",
808+
)
809+
def test_map_batches_struct_field_type_divergence(shutdown_only):
810+
"""Test map_batches with struct fields that have diverging primitive types."""
811+
812+
def generator_fn(batch):
813+
for i, row_id in enumerate(batch["id"]):
814+
if i % 2 == 0:
815+
# Yield struct with fields (a: int64, b: string)
816+
yield {"data": [{"a": 1, "b": "hello"}]}
817+
else:
818+
# Yield struct with fields (a: float64, c: int32)
819+
# Field 'a' has different type, field 'b' missing, field 'c' new
820+
yield {"data": [{"a": 1.5, "c": 100}]}
821+
822+
ds = ray.data.range(4, override_num_blocks=1)
823+
ds = ds.map_batches(generator_fn, batch_size=4)
824+
result = ds.materialize()
825+
826+
rows = result.take_all()
827+
assert len(rows) == 4
828+
829+
# Sort to make the order deterministic.
830+
rows.sort(key=lambda r: (r["data"]["a"], str(r["data"]["b"])))
831+
832+
# Rows with a=1.0 (originally int) should have int cast to float, with c=None
833+
assert rows[0]["data"] == {"a": 1.0, "b": "hello", "c": None}
834+
assert rows[1]["data"] == {"a": 1.0, "b": "hello", "c": None}
835+
836+
# Rows with a=1.5 should have float a, with b=None
837+
assert rows[2]["data"] == {"a": 1.5, "b": None, "c": 100}
838+
assert rows[3]["data"] == {"a": 1.5, "b": None, "c": 100}
839+
840+
801841
if __name__ == "__main__":
802842
import sys
803843

python/ray/data/tests/test_transform_pyarrow.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,64 @@ def test_struct_with_arrow_variable_shaped_tensor_type(
473473
)
474474

475475

476+
@pytest.mark.skipif(
477+
get_pyarrow_version() < MIN_PYARROW_VERSION_TYPE_PROMOTION,
478+
reason="Requires PyArrow >= 14.0.0 for type promotion in nested struct fields",
479+
)
480+
def test_struct_with_diverging_primitive_types():
481+
"""Test concatenating tables with struct fields that have diverging primitive types.
482+
483+
This tests the scenario where struct fields have the same name but different
484+
primitive types (e.g., int64 vs float64), which requires type promotion.
485+
"""
486+
import pyarrow as pa
487+
488+
# Table 1: struct with (a: int64, b: string)
489+
t1 = pa.table(
490+
{
491+
"data": pa.array(
492+
[{"a": 1, "b": "hello"}, {"a": 2, "b": "world"}],
493+
type=pa.struct([pa.field("a", pa.int64()), pa.field("b", pa.string())]),
494+
)
495+
}
496+
)
497+
498+
# Table 2: struct with (a: float64, c: int32)
499+
# Field 'a' has different type, field 'b' missing, field 'c' new
500+
t2 = pa.table(
501+
{
502+
"data": pa.array(
503+
[{"a": 1.5, "c": 100}, {"a": 2.5, "c": 200}],
504+
type=pa.struct(
505+
[pa.field("a", pa.float64()), pa.field("c", pa.int32())]
506+
),
507+
)
508+
}
509+
)
510+
511+
# Concatenate with type promotion
512+
result = concat([t1, t2], promote_types=True)
513+
514+
# Verify schema: field 'a' should be promoted to float64
515+
expected_struct_type = pa.struct(
516+
[
517+
pa.field("a", pa.float64()),
518+
pa.field("b", pa.string()),
519+
pa.field("c", pa.int32()),
520+
]
521+
)
522+
assert result.schema == pa.schema([pa.field("data", expected_struct_type)])
523+
524+
# Verify data: int64 values should be cast to float64, missing fields filled with None
525+
expected_data = [
526+
{"a": 1.0, "b": "hello", "c": None},
527+
{"a": 2.0, "b": "world", "c": None},
528+
{"a": 1.5, "b": None, "c": 100},
529+
{"a": 2.5, "b": None, "c": 200},
530+
]
531+
assert result.column("data").to_pylist() == expected_data
532+
533+
476534
def test_arrow_concat_object_with_tensor_fails(object_with_tensor_fails_blocks):
477535
with pytest.raises(ArrowConversionError) as exc_info:
478536
concat(object_with_tensor_fails_blocks)

0 commit comments

Comments (0)