From 9d39b436a2be8621b8ea153ee40f41c5966d6d50 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Fri, 25 Apr 2025 21:16:04 +0000 Subject: [PATCH] perf: Defer some data uploads to execution time --- bigframes/core/array_value.py | 11 ++- bigframes/core/nodes.py | 5 + bigframes/session/__init__.py | 15 +-- bigframes/session/bq_caching_executor.py | 30 +++++- bigframes/session/loader.py | 120 ++++++++++++++--------- 5 files changed, 123 insertions(+), 58 deletions(-) diff --git a/bigframes/core/array_value.py b/bigframes/core/array_value.py index eba63ad72e..c58cd08ed8 100644 --- a/bigframes/core/array_value.py +++ b/bigframes/core/array_value.py @@ -133,8 +133,17 @@ def from_table( ordering=ordering, n_rows=n_rows, ) + return cls.from_bq_data_source(source_def, scan_list, session) + + @classmethod + def from_bq_data_source( + cls, + source: nodes.BigqueryDataSource, + scan_list: nodes.ScanList, + session: Session, + ): node = nodes.ReadTableNode( - source=source_def, + source=source, scan_list=scan_list, table_session=session, ) diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index f7327f2a7a..63ed000983 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -614,6 +614,11 @@ def project( result = ScanList((self.items[:1])) return result + def append( + self, source_id: str, dtype: bigframes.dtypes.Dtype, id: identifiers.ColumnId + ) -> ScanList: + return ScanList((*self.items, ScanItem(id, dtype, source_id))) + @dataclasses.dataclass(frozen=True, eq=False) class ReadLocalNode(LeafNode): diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 6379a6f2e8..bd4b406c8a 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -244,13 +244,6 @@ def __init__( self._temp_storage_manager = ( self._session_resource_manager or self._anon_dataset_manager ) - self._executor: executor.Executor = bq_caching_executor.BigQueryCachingExecutor( - bqclient=self._clients_provider.bqclient, - bqstoragereadclient=self._clients_provider.bqstoragereadclient, - storage_manager=self._temp_storage_manager, - strictly_ordered=self._strictly_ordered, - metrics=self._metrics, - ) self._loader = bigframes.session.loader.GbqDataLoader( session=self, bqclient=self._clients_provider.bqclient, @@ -261,6 +254,14 @@ def __init__( force_total_order=self._strictly_ordered, metrics=self._metrics, ) + self._executor: executor.Executor = bq_caching_executor.BigQueryCachingExecutor( + bqclient=self._clients_provider.bqclient, + bqstoragereadclient=self._clients_provider.bqstoragereadclient, + storage_manager=self._temp_storage_manager, + loader=self._loader, + strictly_ordered=self._strictly_ordered, + metrics=self._metrics, + ) def __del__(self): """Automatic cleanup of internal resources.""" diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py index 584b41452a..65a9af1ccb 100644 --- a/bigframes/session/bq_caching_executor.py +++ b/bigframes/session/bq_caching_executor.py @@ -35,7 +35,7 @@ import bigframes.dtypes import bigframes.exceptions as bfe import bigframes.features -from bigframes.session import executor, local_scan_executor, read_api_execution +from bigframes.session import executor, loader, local_scan_executor, read_api_execution import bigframes.session._io.bigquery as bq_io import bigframes.session.metrics import bigframes.session.planner @@ -47,6 +47,7 @@ MAX_SUBTREE_FACTORINGS = 5 _MAX_CLUSTER_COLUMNS = 4 MAX_SMALL_RESULT_BYTES = 10 * 1024 * 1024 * 1024 # 10G +MAX_INLINE_BYTES = 5000 class 
BigQueryCachingExecutor(executor.Executor): @@ -63,6 +64,7 @@ def __init__( bqclient: bigquery.Client, storage_manager: bigframes.session.temporary_storage.TemporaryStorageManager, bqstoragereadclient: google.cloud.bigquery_storage_v1.BigQueryReadClient, + loader: loader.GbqDataLoader, *, strictly_ordered: bool = True, metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None, @@ -72,6 +74,7 @@ def __init__( self.compiler: bigframes.core.compile.SQLCompiler = ( bigframes.core.compile.SQLCompiler() ) + self.loader = loader self.strictly_ordered: bool = strictly_ordered self._cached_executions: weakref.WeakKeyDictionary[ nodes.BigFrameNode, nodes.BigFrameNode @@ -437,6 +440,31 @@ def _simplify_with_caching(self, array_value: bigframes.core.ArrayValue): if not did_cache: return + def _upload_large_local_sources(self, root: nodes.BigFrameNode): + for leaf in root.unique_nodes(): + if isinstance(leaf, nodes.ReadLocalNode): + if leaf.local_data_source.metadata.total_bytes > MAX_INLINE_BYTES: + self._cache_local_table(leaf) + + def _cache_local_table(self, node: nodes.ReadLocalNode): + offsets_col = bigframes.core.guid.generate_guid() + # TODO: Best effort go through available upload paths + bq_data_source = self.loader.write_data( + node.local_data_source, offsets_col=offsets_col + ) + scan_list = node.scan_list + if node.offsets_col is not None: + scan_list = scan_list.append( + offsets_col, bigframes.dtypes.INT_DTYPE, node.offsets_col + ) + cache_node = nodes.CachedTableNode( + source=bq_data_source, + scan_list=scan_list, + table_session=self.loader._session, + original_node=node, + ) + self._cached_executions[node] = cache_node + def _cache_most_complex_subtree(self, node: nodes.BigFrameNode) -> bool: # TODO: If query fails, retry with lower complexity limit selection = tree_properties.select_cache_target( diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py index 76f12ae438..ad23f7dfbb 100644 --- a/bigframes/session/loader.py +++ b/bigframes/session/loader.py @@ -43,7 +43,7 @@ import pandas import pyarrow as pa -from bigframes.core import guid, local_data, utils +from bigframes.core import guid, identifiers, local_data, nodes, ordering, utils import bigframes.core as core import bigframes.core.blocks as blocks import bigframes.core.schema as schemata @@ -183,35 +183,59 @@ def read_pandas( ) managed_data = local_data.ManagedArrowTable.from_pandas(prepared_df) + block = blocks.Block( + self.read_managed_data(managed_data, method=method, api_name=api_name), + index_columns=idx_cols, + column_labels=pandas_dataframe.columns, + index_labels=pandas_dataframe.index.names, + ) + return dataframe.DataFrame(block) + + def read_managed_data( + self, + data: local_data.ManagedArrowTable, + method: Literal["load", "stream", "write"], + api_name: str, + ) -> core.ArrayValue: + offsets_col = guid.generate_guid("upload_offsets_") if method == "load": - array_value = self.load_data(managed_data, api_name=api_name) + gbq_source = self.load_data( + data, offsets_col=offsets_col, api_name=api_name + ) elif method == "stream": - array_value = self.stream_data(managed_data) + gbq_source = self.stream_data(data, offsets_col=offsets_col) elif method == "write": - array_value = self.write_data(managed_data) + gbq_source = self.write_data(data, offsets_col=offsets_col) else: raise ValueError(f"Unsupported read method {method}") - block = blocks.Block( - array_value, - index_columns=idx_cols, - column_labels=pandas_dataframe.columns, - index_labels=pandas_dataframe.index.names, + return 
core.ArrayValue.from_bq_data_source( + source=gbq_source, + scan_list=nodes.ScanList( + tuple( + nodes.ScanItem( + identifiers.ColumnId(item.column), item.dtype, item.column + ) + for item in data.schema.items + ) + ), + session=self._session, ) - return dataframe.DataFrame(block) def load_data( - self, data: local_data.ManagedArrowTable, api_name: Optional[str] = None - ) -> core.ArrayValue: + self, + data: local_data.ManagedArrowTable, + offsets_col: str, + api_name: Optional[str] = None, + ) -> nodes.BigqueryDataSource: """Load managed data into bigquery""" - ordering_col = guid.generate_guid("load_offsets_") # JSON support incomplete for item in data.schema.items: _validate_dtype_can_load(item.column, item.dtype) schema_w_offsets = data.schema.append( - schemata.SchemaItem(ordering_col, bigframes.dtypes.INT_DTYPE) + schemata.SchemaItem(offsets_col, bigframes.dtypes.INT_DTYPE) ) bq_schema = schema_w_offsets.to_bigquery(_LOAD_JOB_TYPE_OVERRIDES) @@ -222,13 +246,13 @@ def load_data( job_config.labels = {"bigframes-api": api_name} load_table_destination = self._storage_manager.create_temp_table( - bq_schema, [ordering_col] + bq_schema, [offsets_col] ) buffer = io.BytesIO() data.to_parquet( buffer, - offsets_col=ordering_col, + offsets_col=offsets_col, geo_format="wkt", duration_type="duration", json_type="string", @@ -240,23 +264,24 @@ def load_data( self._start_generic_job(load_job) # must get table metadata after load job for accurate metadata destination_table = self._bqclient.get_table(load_table_destination) - return core.ArrayValue.from_table( - table=destination_table, - schema=schema_w_offsets, - session=self._session, - offsets_col=ordering_col, - n_rows=data.data.num_rows, - ).drop_columns([ordering_col]) + return nodes.BigqueryDataSource( + nodes.GbqTable.from_table(destination_table), + ordering=ordering.TotalOrdering.from_offset_col(offsets_col), + n_rows=destination_table.num_rows, + ) - def stream_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue: + def stream_data( + self, + data: local_data.ManagedArrowTable, + offsets_col: str, + ) -> nodes.BigqueryDataSource: """Load managed data into bigquery""" - ordering_col = guid.generate_guid("stream_offsets_") schema_w_offsets = data.schema.append( - schemata.SchemaItem(ordering_col, bigframes.dtypes.INT_DTYPE) + schemata.SchemaItem(offsets_col, bigframes.dtypes.INT_DTYPE) ) bq_schema = schema_w_offsets.to_bigquery(_STREAM_JOB_TYPE_OVERRIDES) load_table_destination = self._storage_manager.create_temp_table( - bq_schema, [ordering_col] + bq_schema, [offsets_col] ) rows = data.itertuples( @@ -275,24 +300,23 @@ def stream_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue: f"Problem loading at least one row from DataFrame: {errors}. 
{constants.FEEDBACK_LINK}" ) destination_table = self._bqclient.get_table(load_table_destination) - return core.ArrayValue.from_table( - table=destination_table, - schema=schema_w_offsets, - session=self._session, - offsets_col=ordering_col, - n_rows=data.data.num_rows, - ).drop_columns([ordering_col]) + return nodes.BigqueryDataSource( + nodes.GbqTable.from_table(destination_table), + ordering=ordering.TotalOrdering.from_offset_col(offsets_col), + n_rows=destination_table.num_rows, + ) - def write_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue: + def write_data( + self, + data: local_data.ManagedArrowTable, + offsets_col: str, + ) -> nodes.BigqueryDataSource: """Load managed data into bigquery""" - ordering_col = guid.generate_guid("stream_offsets_") schema_w_offsets = data.schema.append( - schemata.SchemaItem(ordering_col, bigframes.dtypes.INT_DTYPE) + schemata.SchemaItem(offsets_col, bigframes.dtypes.INT_DTYPE) ) bq_schema = schema_w_offsets.to_bigquery(_STREAM_JOB_TYPE_OVERRIDES) - bq_table_ref = self._storage_manager.create_temp_table( - bq_schema, [ordering_col] - ) + bq_table_ref = self._storage_manager.create_temp_table(bq_schema, [offsets_col]) requested_stream = bq_storage_types.stream.WriteStream() requested_stream.type_ = bq_storage_types.stream.WriteStream.Type.COMMITTED # type: ignore @@ -304,7 +328,7 @@ def write_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue: def request_gen() -> Generator[bq_storage_types.AppendRowsRequest, None, None]: schema, batches = data.to_arrow( - offsets_col=ordering_col, duration_type="int" + offsets_col=offsets_col, duration_type="int" ) offset = 0 for batch in batches: @@ -330,13 +354,11 @@ def request_gen() -> Generator[bq_storage_types.AppendRowsRequest, None, None]: assert response.row_count == data.data.num_rows destination_table = self._bqclient.get_table(bq_table_ref) - return core.ArrayValue.from_table( - table=destination_table, - schema=schema_w_offsets, - session=self._session, - offsets_col=ordering_col, - n_rows=data.data.num_rows, - ).drop_columns([ordering_col]) + return nodes.BigqueryDataSource( + nodes.GbqTable.from_table(destination_table), + ordering=ordering.TotalOrdering.from_offset_col(offsets_col), + n_rows=destination_table.num_rows, + ) def _start_generic_job(self, job: formatting_helpers.GenericJob): if bigframes.options.display.progress_bar is not None: @@ -533,7 +555,7 @@ def read_gbq_table( if not primary_key: array_value = array_value.order_by( [ - bigframes.core.ordering.OrderingExpression( + ordering.OrderingExpression( bigframes.operations.RowKey().as_expr( *(id for id in array_value.column_ids) ),
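
Illustrative sketch (not part of the patch): assuming the executor calls _upload_large_local_sources on the plan root before compiling SQL (the call site is outside this excerpt), local data under MAX_INLINE_BYTES can stay inlined, while anything larger is written to a temporary BigQuery table via loader.write_data only when a query actually runs. A hypothetical usage flow, with names taken from this patch:

    import pandas as pd
    import bigframes.pandas as bpd

    # Small frame: comfortably below MAX_INLINE_BYTES (5000 bytes), so no
    # upload should be needed at construction time.
    small = bpd.read_pandas(pd.DataFrame({"a": [1, 2, 3]}))

    # Large frame: exceeds MAX_INLINE_BYTES. Under this change the upload is
    # expected to happen inside the executor (via GbqDataLoader.write_data)
    # when a result is requested, not when the DataFrame is created.
    large = bpd.read_pandas(pd.DataFrame({"a": range(100_000)}))
    result = large.sum().to_pandas()  # execution triggers the deferred upload

The 5000-byte threshold and the write-API path are taken directly from the diff; whether read_pandas keeps a given frame as a ReadLocalNode in the first place depends on code outside this excerpt.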