
feat: support names parameter in read_csv for bigquery engine #1659

Merged: 2 commits, Apr 30, 2025
6 changes: 4 additions & 2 deletions bigframes/core/utils.py
@@ -41,8 +41,10 @@ def get_axis_number(axis: typing.Union[str, int]) -> typing.Literal[0, 1]:
     raise ValueError(f"Not a valid axis: {axis}")
 
 
-def is_list_like(obj: typing.Any) -> typing_extensions.TypeGuard[typing.Sequence]:
-    return pd.api.types.is_list_like(obj)
+def is_list_like(
+    obj: typing.Any, allow_sets: bool = True
+) -> typing_extensions.TypeGuard[typing.Sequence]:
+    return pd.api.types.is_list_like(obj, allow_sets=allow_sets)
 
 
 def is_dict_like(obj: typing.Any) -> typing_extensions.TypeGuard[typing.Mapping]:
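The new allow_sets flag is forwarded straight to pandas, where allow_sets=False makes unordered sets fail the list-like check. A minimal sketch of the resulting behavior, assuming standard pandas semantics:

    import pandas as pd

    pd.api.types.is_list_like(["id", "name"], allow_sets=False)  # True: lists keep their order
    pd.api.types.is_list_like(("id", "name"), allow_sets=False)  # True: tuples do too
    pd.api.types.is_list_like({"id", "name"}, allow_sets=False)  # False: sets have no stable order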
36 changes: 20 additions & 16 deletions bigframes/session/__init__.py
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+from collections import abc
 import datetime
 import logging
 import os
@@ -569,7 +570,7 @@ def read_gbq_table(
             columns = col_order
 
         return self._loader.read_gbq_table(
-            query=query,
+            table_id=query,
             index_col=index_col,
             columns=columns,
             max_results=max_results,
@@ -953,14 +954,21 @@ def _read_csv_w_bigquery_engine(
         native CSV loading capabilities, making it suitable for large datasets
         that may not fit into local memory.
         """
 
-        if any(param is not None for param in (dtype, names)):
-            not_supported = ("dtype", "names")
+        if dtype is not None:
             raise NotImplementedError(
-                f"BigQuery engine does not support these arguments: {not_supported}. "
+                f"BigQuery engine does not support the `dtype` argument."
                 f"{constants.FEEDBACK_LINK}"
             )
 
+        if names is not None:
+            if len(names) != len(set(names)):
+                raise ValueError("Duplicated names are not allowed.")
+            if not (
+                bigframes.core.utils.is_list_like(names, allow_sets=False)
+                or isinstance(names, abc.KeysView)
+            ):
+                raise ValueError("Names should be an ordered collection.")
 
         if index_col is True:
             raise ValueError("The value of index_col couldn't be 'True'")

@@ -1004,11 +1012,9 @@ def _read_csv_w_bigquery_engine(
         elif header > 0:
             job_config.skip_leading_rows = header + 1
 
-        return self._loader.read_bigquery_load_job(
-            filepath_or_buffer,
-            job_config=job_config,
-            index_col=index_col,
-            columns=columns,
+        table_id = self._loader.load_file(filepath_or_buffer, job_config=job_config)
+        return self._loader.read_gbq_table(
+            table_id, index_col=index_col, columns=columns, names=names
         )
 
     def read_pickle(
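Taken together, the validation hunk and this refactor are what let `names` flow from read_csv into the table read. A rough sketch of what the new checks accept and of the user-facing call they enable; the GCS path and column names here are invented for illustration:

    # Accepted: an ordered, duplicate-free collection (list, tuple, or dict keys).
    names = ["id", "name", "score"]
    # Rejected: duplicates raise ValueError, and unordered sets are refused as well.
    # names = ["id", "id", "score"]
    # names = {"id", "name", "score"}

    import bigframes.pandas as bpd

    df = bpd.read_csv(
        "gs://my-bucket/data.csv",  # hypothetical path
        engine="bigquery",
        names=["id", "name", "score"],
        index_col="id",
    )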
@@ -1049,8 +1055,8 @@ def read_parquet(
             job_config = bigquery.LoadJobConfig()
             job_config.source_format = bigquery.SourceFormat.PARQUET
             job_config.labels = {"bigframes-api": "read_parquet"}
-
-            return self._loader.read_bigquery_load_job(path, job_config=job_config)
+            table_id = self._loader.load_file(path, job_config=job_config)
+            return self._loader.read_gbq_table(table_id)
         else:
             if "*" in path:
                 raise ValueError(
@@ -1121,10 +1127,8 @@ def read_json(
             job_config.encoding = encoding
             job_config.labels = {"bigframes-api": "read_json"}
 
-            return self._loader.read_bigquery_load_job(
-                path_or_buf,
-                job_config=job_config,
-            )
+            table_id = self._loader.load_file(path_or_buf, job_config=job_config)
+            return self._loader.read_gbq_table(table_id)
         else:
             if any(arg in kwargs for arg in ("chunksize", "iterator")):
                 raise NotImplementedError(
14 changes: 14 additions & 0 deletions bigframes/session/_io/bigquery/read_gbq_table.py
@@ -235,6 +235,8 @@ def get_index_cols(
     | Iterable[int]
     | int
     | bigframes.enums.DefaultIndexKind,
+    *,
+    names: Optional[Iterable[str]] = None,
 ) -> List[str]:
     """
     If we can get a total ordering from the table, such as via primary key
@@ -245,6 +247,14 @@
     # Transform index_col -> index_cols so we have a variable that is
     # always a list of column names (possibly empty).
     schema_len = len(table.schema)
+
+    # If the `names` is provided, the index_col provided by the user is the new
+    # name, so we need to rename it to the original name in the table schema.
+    renamed_schema: Optional[Dict[str, str]] = None
+    if names is not None:
+        assert len(list(names)) == schema_len
+        renamed_schema = {name: field.name for name, field in zip(names, table.schema)}
+
     index_cols: List[str] = []
     if isinstance(index_col, bigframes.enums.DefaultIndexKind):
         if index_col == bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64:
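To make the renaming step concrete, here is a small standalone sketch (the schema and names are made up) of how a user-supplied index_col is mapped back to the physical column name before lookup:

    schema_names = ["string_field_0", "string_field_1", "string_field_2"]  # hypothetical table schema
    names = ["id", "name", "score"]                                        # user-provided names
    renamed_schema = {new: old for new, old in zip(names, schema_names)}
    # index_col="id" resolves to the physical column "string_field_0".
    assert renamed_schema["id"] == "string_field_0"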
@@ -261,6 +271,8 @@
             f"Got unexpected index_col {repr(index_col)}. {constants.FEEDBACK_LINK}"
         )
     elif isinstance(index_col, str):
+        if renamed_schema is not None:
+            index_col = renamed_schema.get(index_col, index_col)
         index_cols = [index_col]
     elif isinstance(index_col, int):
         if not 0 <= index_col < schema_len:
@@ -272,6 +284,8 @@
     elif isinstance(index_col, Iterable):
         for item in index_col:
             if isinstance(item, str):
+                if renamed_schema is not None:
+                    item = renamed_schema.get(item, item)
                 index_cols.append(item)
             elif isinstance(item, int):
                 if not 0 <= item < schema_len:
64 changes: 42 additions & 22 deletions bigframes/session/loader.py
@@ -348,14 +348,15 @@ def _start_generic_job(self, job: formatting_helpers.GenericJob):
 
     def read_gbq_table(
         self,
-        query: str,
+        table_id: str,
         *,
         index_col: Iterable[str]
         | str
         | Iterable[int]
         | int
         | bigframes.enums.DefaultIndexKind = (),
         columns: Iterable[str] = (),
+        names: Optional[Iterable[str]] = None,
         max_results: Optional[int] = None,
         api_name: str = "read_gbq_table",
         use_cache: bool = True,
@@ -375,7 +376,7 @@ def read_gbq_table(
             )
 
         table_ref = google.cloud.bigquery.table.TableReference.from_string(
-            query, default_project=self._bqclient.project
+            table_id, default_project=self._bqclient.project
         )
 
         columns = list(columns)
@@ -411,12 +412,37 @@ def read_gbq_table(
                     f"Column '{key}' of `columns` not found in this table. Did you mean '{possibility}'?"
                 )
 
+        # TODO(b/408499371): check `names` work with `use_cols` for read_csv method.
+        if names is not None:
+            len_names = len(list(names))
+            len_columns = len(table.schema)
+            if len_names > len_columns:
+                raise ValueError(
+                    f"Too many columns specified: expected {len_columns}"
+                    f" and found {len_names}"
+                )
+            elif len_names < len_columns:
+                if (
+                    isinstance(index_col, bigframes.enums.DefaultIndexKind)
+                    or index_col != ()
+                ):
+                    raise KeyError(
+                        "When providing both `index_col` and `names`, ensure the "
+                        "number of `names` matches the number of columns in your "
+                        "data."
+                    )
+                index_col = range(len_columns - len_names)
+                names = [
+                    field.name for field in table.schema[: len_columns - len_names]
+                ] + list(names)
+
         # Converting index_col into a list of column names requires
         # the table metadata because we might use the primary keys
         # when constructing the index.
         index_cols = bf_read_gbq_table.get_index_cols(
             table=table,
             index_col=index_col,
+            names=names,
         )
         _check_column_duplicates(index_cols, columns)

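The elif branch mirrors pandas.read_csv: when fewer names than columns are given, the leftover leading columns become the index. A quick illustration with plain pandas (the data is invented), assuming that documented behavior:

    import io
    import pandas as pd

    csv = io.StringIO("1,alice,10\n2,bob,20\n")
    # Three columns of data but only two names: the first column becomes the index.
    df = pd.read_csv(csv, names=["name", "score"])
    print(df.index.tolist())  # [1, 2]
    print(list(df.columns))   # ['name', 'score']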
@@ -443,15 +469,15 @@ def read_gbq_table(
         # TODO(b/338419730): We don't need to fallback to a query for wildcard
         # tables if we allow some non-determinism when time travel isn't supported.
         if max_results is not None or bf_io_bigquery.is_table_with_wildcard_suffix(
-            query
+            table_id
         ):
             # TODO(b/338111344): If we are running a query anyway, we might as
             # well generate ROW_NUMBER() at the same time.
             all_columns: Iterable[str] = (
                 itertools.chain(index_cols, columns) if columns else ()
             )
             query = bf_io_bigquery.to_query(
-                query,
+                table_id,
                 columns=all_columns,
                 sql_predicate=bf_io_bigquery.compile_filters(filters)
                 if filters
@@ -561,6 +587,15 @@ def read_gbq_table(
             index_names = [None]
 
         value_columns = [col for col in array_value.column_ids if col not in index_cols]
+        if names is not None:
+            renamed_cols: Dict[str, str] = {
+                col: new_name for col, new_name in zip(array_value.column_ids, names)
+            }
+            index_names = [
+                renamed_cols.get(index_col, index_col) for index_col in index_cols
+            ]
+            value_columns = [renamed_cols.get(col, col) for col in value_columns]
+
         block = blocks.Block(
             array_value,
             index_columns=index_cols,
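A minimal sketch (the column ids are invented) of how this final mapping renames the index and the value columns at once:

    column_ids = ["bf_col_0", "bf_col_1", "bf_col_2"]  # internal ids, hypothetical
    names = ["id", "name", "score"]
    index_cols = ["bf_col_0"]

    renamed_cols = {col: new for col, new in zip(column_ids, names)}
    index_names = [renamed_cols.get(c, c) for c in index_cols]                            # ['id']
    value_columns = [renamed_cols.get(c, c) for c in column_ids if c not in index_cols]   # ['name', 'score']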
@@ -576,18 +611,12 @@ def read_gbq_table(
             df.sort_index()
         return df
 
-    def read_bigquery_load_job(
+    def load_file(
         self,
         filepath_or_buffer: str | IO["bytes"],
         *,
         job_config: bigquery.LoadJobConfig,
-        index_col: Iterable[str]
-        | str
-        | Iterable[int]
-        | int
-        | bigframes.enums.DefaultIndexKind = (),
-        columns: Iterable[str] = (),
-    ) -> dataframe.DataFrame:
+    ) -> str:
         # Need to create session table beforehand
         table = self._storage_manager.create_temp_table(_PLACEHOLDER_SCHEMA)
         # but, we just overwrite the placeholder schema immediately with the load job
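The rename reflects a narrower contract: the method now only loads the file and reports where it landed, leaving DataFrame construction to read_gbq_table. A rough sketch of how a caller composes the two steps under that assumption (arguments invented):

    # Before: a single call both loaded the file and produced a DataFrame.
    # df = loader.read_bigquery_load_job(path, job_config=job_config, index_col="id")

    # After: loading returns a fully qualified table id; reading is a separate, reusable step.
    table_id = loader.load_file(path, job_config=job_config)
    df = loader.read_gbq_table(table_id, index_col="id")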
@@ -615,16 +644,7 @@ def read_bigquery_load_job(
 
         self._start_generic_job(load_job)
         table_id = f"{table.project}.{table.dataset_id}.{table.table_id}"
-
-        # The BigQuery REST API for tables.get doesn't take a session ID, so we
-        # can't get the schema for a temp table that way.
-
-        return self.read_gbq_table(
-            query=table_id,
-            index_col=index_col,
-            columns=columns,
-            api_name="read_gbq_table",
-        )
+        return table_id
 
     def read_gbq_query(
         self,