Skip to content

Commit 6a35990

Browse files
committed
test: add unit test using concat
Signed-off-by: machichima <nary12321@gmail.com>
1 parent dd8a662 commit 6a35990

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

python/ray/data/tests/test_transform_pyarrow.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,64 @@ def test_struct_with_arrow_variable_shaped_tensor_type(
473473
)
474474

475475

476+
@pytest.mark.skipif(
477+
get_pyarrow_version() < MIN_PYARROW_VERSION_TYPE_PROMOTION,
478+
reason="Requires PyArrow >= 14.0.0 for type promotion in nested struct fields",
479+
)
480+
def test_struct_with_diverging_primitive_types():
481+
"""Test concatenating tables with struct fields that have diverging primitive types.
482+
483+
This tests the scenario where struct fields have the same name but different
484+
primitive types (e.g., int64 vs float64), which requires type promotion.
485+
"""
486+
import pyarrow as pa
487+
488+
# Table 1: struct with (a: int64, b: string)
489+
t1 = pa.table(
490+
{
491+
"data": pa.array(
492+
[{"a": 1, "b": "hello"}, {"a": 2, "b": "world"}],
493+
type=pa.struct([pa.field("a", pa.int64()), pa.field("b", pa.string())]),
494+
)
495+
}
496+
)
497+
498+
# Table 2: struct with (a: float64, c: int32)
499+
# Field 'a' has different type, field 'b' missing, field 'c' new
500+
t2 = pa.table(
501+
{
502+
"data": pa.array(
503+
[{"a": 1.5, "c": 100}, {"a": 2.5, "c": 200}],
504+
type=pa.struct(
505+
[pa.field("a", pa.float64()), pa.field("c", pa.int32())]
506+
),
507+
)
508+
}
509+
)
510+
511+
# Concatenate with type promotion
512+
result = concat([t1, t2], promote_types=True)
513+
514+
# Verify schema: field 'a' should be promoted to float64
515+
expected_struct_type = pa.struct(
516+
[
517+
pa.field("a", pa.float64()),
518+
pa.field("b", pa.string()),
519+
pa.field("c", pa.int32()),
520+
]
521+
)
522+
assert result.schema == pa.schema([pa.field("data", expected_struct_type)])
523+
524+
# Verify data: int64 values should be cast to float64, missing fields filled with None
525+
expected_data = [
526+
{"a": 1.0, "b": "hello", "c": None},
527+
{"a": 2.0, "b": "world", "c": None},
528+
{"a": 1.5, "b": None, "c": 100},
529+
{"a": 2.5, "b": None, "c": 200},
530+
]
531+
assert result.column("data").to_pylist() == expected_data
532+
533+
476534
def test_arrow_concat_object_with_tensor_fails(object_with_tensor_fails_blocks):
477535
with pytest.raises(ArrowConversionError) as exc_info:
478536
concat(object_with_tensor_fails_blocks)

0 commit comments

Comments
 (0)