diff --git a/bigframes/core/local_data.py b/bigframes/core/local_data.py index d387e0b818..d23f3538dd 100644 --- a/bigframes/core/local_data.py +++ b/bigframes/core/local_data.py @@ -86,7 +86,7 @@ def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable: columns: list[pa.ChunkedArray] = [] fields: list[schemata.SchemaItem] = [] for name, arr in zip(table.column_names, table.columns): - new_arr, bf_type = _adapt_arrow_array(arr) + new_arr, bf_type = _adapt_chunked_array(arr) columns.append(new_arr) fields.append(schemata.SchemaItem(name, bf_type)) @@ -279,10 +279,26 @@ def _adapt_pandas_series( raise e -def _adapt_arrow_array( - array: Union[pa.ChunkedArray, pa.Array] -) -> tuple[Union[pa.ChunkedArray, pa.Array], bigframes.dtypes.Dtype]: +def _adapt_chunked_array( + chunked_array: pa.ChunkedArray, +) -> tuple[pa.ChunkedArray, bigframes.dtypes.Dtype]: + if len(chunked_array.chunks) == 0: + return _adapt_arrow_array(chunked_array.combine_chunks()) + dtype = None + arrays = [] + for chunk in chunked_array.chunks: + array, arr_dtype = _adapt_arrow_array(chunk) + arrays.append(array) + dtype = dtype or arr_dtype + assert dtype is not None + return pa.chunked_array(arrays), dtype + + +def _adapt_arrow_array(array: pa.Array) -> tuple[pa.Array, bigframes.dtypes.Dtype]: """Normalize the array to managed storage types. Preverse shapes, only transforms values.""" + if array.offset != 0: # Offset arrays don't have all operations implemented + return _adapt_arrow_array(pa.concat_arrays([array])) + if pa.types.is_struct(array.type): assert isinstance(array, pa.StructArray) assert isinstance(array.type, pa.StructType) diff --git a/tests/unit/test_local_data.py b/tests/unit/test_local_data.py index 9cd08787c9..bb7330aba4 100644 --- a/tests/unit/test_local_data.py +++ b/tests/unit/test_local_data.py @@ -44,3 +44,23 @@ def test_local_data_well_formed_round_trip(): local_entry = local_data.ManagedArrowTable.from_pandas(pd_data) result = pd.DataFrame(local_entry.itertuples(), columns=pd_data.columns) pandas.testing.assert_frame_equal(pd_data_normalized, result, check_dtype=False) + + +def test_local_data_well_formed_round_trip_chunked(): + pa_table = pa.Table.from_pandas(pd_data, preserve_index=False) + as_rechunked_pyarrow = pa.Table.from_batches(pa_table.to_batches(max_chunksize=2)) + local_entry = local_data.ManagedArrowTable.from_pyarrow(as_rechunked_pyarrow) + result = pd.DataFrame(local_entry.itertuples(), columns=pd_data.columns) + pandas.testing.assert_frame_equal(pd_data_normalized, result, check_dtype=False) + + +def test_local_data_well_formed_round_trip_sliced(): + pa_table = pa.Table.from_pandas(pd_data, preserve_index=False) + as_rechunked_pyarrow = pa.Table.from_batches(pa_table.slice(2, 4).to_batches()) + local_entry = local_data.ManagedArrowTable.from_pyarrow(as_rechunked_pyarrow) + result = pd.DataFrame(local_entry.itertuples(), columns=pd_data.columns) + pandas.testing.assert_frame_equal( + pd_data_normalized[2:4].reset_index(drop=True), + result.reset_index(drop=True), + check_dtype=False, + )