From cd2bb1b06c90bf63a03f4a8480afc23a01ebe589 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Tue, 6 May 2025 00:32:22 +0000 Subject: [PATCH 1/2] perf: Cache results opportunistically --- bigframes/core/compile/__init__.py | 5 + bigframes/core/pyarrow_utils.py | 37 ++++++ bigframes/session/bq_caching_executor.py | 147 ++++++++++++++++------- bigframes/session/executor.py | 8 +- bigframes/session/local_scan_executor.py | 26 ++-- bigframes/session/read_api_execution.py | 45 ++++--- tests/unit/polars_session.py | 2 +- 7 files changed, 182 insertions(+), 88 deletions(-) create mode 100644 bigframes/core/pyarrow_utils.py diff --git a/bigframes/core/compile/__init__.py b/bigframes/core/compile/__init__.py index 0bfdf2222d..2e1651cd71 100644 --- a/bigframes/core/compile/__init__.py +++ b/bigframes/core/compile/__init__.py @@ -14,8 +14,13 @@ from __future__ import annotations from bigframes.core.compile.api import SQLCompiler, test_only_ibis_inferred_schema +from bigframes.core.compile.compiler import compile_sql +from bigframes.core.compile.configs import CompileRequest, CompileResult __all__ = [ "SQLCompiler", "test_only_ibis_inferred_schema", + "compile_sql", + "CompileRequest", + "CompileResult", ] diff --git a/bigframes/core/pyarrow_utils.py b/bigframes/core/pyarrow_utils.py new file mode 100644 index 0000000000..d5ffdb65df --- /dev/null +++ b/bigframes/core/pyarrow_utils.py @@ -0,0 +1,37 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +from typing import Iterable, Iterator, Optional + +import pyarrow as pa + + +def peek_batches( + batch_iter: Iterable[pa.RecordBatch], max_bytes: int +) -> tuple[Iterator[pa.RecordBatch], Optional[tuple[pa.RecordBatch, ...]]]: + """ + Try to peek a pyarrow batch iterable. If greater than max_bytes, give up. + + Will consume max_bytes + one batch of memory at worst. 
+ """ + batch_list = [] + current_bytes = 0 + for batch in batch_iter: + batch_list.append(batch) + current_bytes += batch.nbytes + + if current_bytes > max_bytes: + return itertools.chain(batch_list, batch_iter), None + + return iter(batch_list), tuple(batch_list) diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py index 4c10d76253..f026190ca8 100644 --- a/bigframes/session/bq_caching_executor.py +++ b/bigframes/session/bq_caching_executor.py @@ -25,10 +25,10 @@ import google.cloud.bigquery.job as bq_job import google.cloud.bigquery.table as bq_table import google.cloud.bigquery_storage_v1 +import pyarrow as pa import bigframes.core -from bigframes.core import rewrite -import bigframes.core.compile +from bigframes.core import compile, local_data, pyarrow_utils, rewrite import bigframes.core.guid import bigframes.core.nodes as nodes import bigframes.core.ordering as order @@ -70,9 +70,6 @@ def __init__( ): self.bqclient = bqclient self.storage_manager = storage_manager - self.compiler: bigframes.core.compile.SQLCompiler = ( - bigframes.core.compile.SQLCompiler() - ) self.strictly_ordered: bool = strictly_ordered self._cached_executions: weakref.WeakKeyDictionary[ nodes.BigFrameNode, nodes.BigFrameNode @@ -97,8 +94,11 @@ def to_sql( ) -> str: if offset_column: array_value, _ = array_value.promote_offsets() - node = self.logical_plan(array_value.node) if enable_cache else array_value.node - return self.compiler.compile(node, ordered=ordered) + node = ( + self.simplify_plan(array_value.node) if enable_cache else array_value.node + ) + compiled = compile.compile_sql(compile.CompileRequest(node, sort_rows=ordered)) + return compiled.sql def execute( self, @@ -115,7 +115,6 @@ def execute( if bigframes.options.compute.enable_multi_query_execution: self._simplify_with_caching(array_value) - plan = self.logical_plan(array_value.node) # Use explicit destination to avoid 10GB limit of temporary table destination_table = ( self.storage_manager.create_temp_table( @@ -125,7 +124,7 @@ def execute( else None ) return self._execute_plan( - plan, + array_value.node, ordered=ordered, page_size=page_size, max_results=max_results, @@ -224,7 +223,7 @@ def peek( """ A 'peek' efficiently accesses a small number of rows in the dataframe. """ - plan = self.logical_plan(array_value.node) + plan = self.simplify_plan(array_value.node) if not tree_properties.can_fast_peek(plan): msg = bfe.format_message("Peeking this value cannot be done efficiently.") warnings.warn(msg) @@ -240,7 +239,7 @@ def peek( ) return self._execute_plan( - plan, ordered=False, destination=destination_table, peek=n_rows + array_value.node, ordered=False, destination=destination_table, peek=n_rows ) def cached( @@ -329,10 +328,10 @@ def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue): # Once rewriting is available, will want to rewrite before # evaluating execution cost. return tree_properties.is_trivially_executable( - self.logical_plan(array_value.node) + self.simplify_plan(array_value.node) ) - def logical_plan(self, root: nodes.BigFrameNode) -> nodes.BigFrameNode: + def simplify_plan(self, root: nodes.BigFrameNode) -> nodes.BigFrameNode: """ Apply universal logical simplifications that are helpful regardless of engine. 
""" @@ -345,18 +344,20 @@ def _cache_with_cluster_cols( self, array_value: bigframes.core.ArrayValue, cluster_cols: Sequence[str] ): """Executes the query and uses the resulting table to rewrite future executions.""" - - sql, schema, ordering_info = self.compiler.compile_raw( - self.logical_plan(array_value.node) + plan = self.simplify_plan(array_value.node) + compiled = compile.compile_sql( + compile.CompileRequest( + plan, sort_rows=False, materialize_all_order_keys=True + ) ) tmp_table = self._sql_as_cached_temp_table( - sql, - schema, - cluster_cols=bq_io.select_cluster_cols(schema, cluster_cols), + compiled.sql, + compiled.sql_schema, + cluster_cols=bq_io.select_cluster_cols(compiled.sql_schema, cluster_cols), ) cached_replacement = array_value.as_cached( cache_table=self.bqclient.get_table(tmp_table), - ordering=ordering_info, + ordering=compiled.row_order, ).node self._cached_executions[array_value.node] = cached_replacement @@ -364,10 +365,14 @@ def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue): """Executes the query and uses the resulting table to rewrite future executions.""" offset_column = bigframes.core.guid.generate_guid("bigframes_offsets") w_offsets, offset_column = array_value.promote_offsets() - sql = self.compiler.compile(self.logical_plan(w_offsets.node), ordered=False) + compiled = compile.compile_sql( + compile.CompileRequest( + array_value.node, sort_rows=False, materialize_all_order_keys=True + ) + ) tmp_table = self._sql_as_cached_temp_table( - sql, + compiled.sql, w_offsets.schema.to_bigquery(), cluster_cols=[offset_column], ) @@ -401,7 +406,7 @@ def _simplify_with_caching(self, array_value: bigframes.core.ArrayValue): # Apply existing caching first for _ in range(MAX_SUBTREE_FACTORINGS): if ( - self.logical_plan(array_value.node).planning_complexity + self.simplify_plan(array_value.node).planning_complexity < QUERY_COMPLEXITY_LIMIT ): return @@ -458,8 +463,8 @@ def _validate_result_schema( bq_schema: list[bigquery.SchemaField], ): actual_schema = _sanitize(tuple(bq_schema)) - ibis_schema = bigframes.core.compile.test_only_ibis_inferred_schema( - self.logical_plan(array_value.node) + ibis_schema = compile.test_only_ibis_inferred_schema( + self.simplify_plan(array_value.node) ).to_bigquery() internal_schema = _sanitize(array_value.schema.to_bigquery()) if not bigframes.features.PANDAS_VERSIONS.is_arrow_list_dtype_usable: @@ -477,7 +482,7 @@ def _validate_result_schema( def _execute_plan( self, - plan: nodes.BigFrameNode, + root: nodes.BigFrameNode, ordered: bool, page_size: Optional[int] = None, max_results: Optional[int] = None, @@ -490,7 +495,9 @@ def _execute_plan( # TODO: Allow page_size and max_results by rechunking/truncating results if (not page_size) and (not max_results) and (not destination) and (not peek): for semi_executor in self._semi_executors: - maybe_result = semi_executor.execute(plan, ordered=ordered) + maybe_result = semi_executor.execute( + self.simplify_plan(root), ordered=ordered + ) if maybe_result: return maybe_result @@ -500,9 +507,13 @@ def _execute_plan( # Use explicit destination to avoid 10GB limit of temporary table if destination is not None: job_config.destination = destination - sql = self.compiler.compile(plan, ordered=ordered, limit=peek) + compiled = compile.compile_sql( + compile.CompileRequest( + self.simplify_plan(root), sort_rows=ordered, peek_count=peek + ) + ) iterator, query_job = self._run_execute_query( - sql=sql, + sql=compiled.sql, job_config=job_config, page_size=page_size, 
max_results=max_results, @@ -510,21 +521,20 @@ def _execute_plan( ) # Though we provide the read client, iterator may or may not use it based on what is efficient for the result - def iterator_supplier(): - # Workaround issue fixed by: https://github.com/googleapis/python-bigquery/pull/2154 - if iterator._page_size is not None or iterator.max_results is not None: - return iterator.to_arrow_iterable(bqstorage_client=None) - else: - return iterator.to_arrow_iterable( - bqstorage_client=self.bqstoragereadclient - ) + # Workaround issue fixed by: https://github.com/googleapis/python-bigquery/pull/2154 + if iterator._page_size is not None or iterator.max_results is not None: + batch_iterator = iterator.to_arrow_iterable(bqstorage_client=None) + else: + batch_iterator = iterator.to_arrow_iterable( + bqstorage_client=self.bqstoragereadclient + ) if query_job: - size_bytes = self.bqclient.get_table(query_job.destination).num_bytes + table = self.bqclient.get_table(query_job.destination) else: - size_bytes = None + table = None - if size_bytes is not None and size_bytes >= MAX_SMALL_RESULT_BYTES: + if (table is not None) and (table.num_bytes or 0) >= MAX_SMALL_RESULT_BYTES: msg = bfe.format_message( "The query result size has exceeded 10 GB. In BigFrames 2.0 and " "later, you might need to manually set `allow_large_results=True` in " @@ -536,14 +546,63 @@ def iterator_supplier(): # Do not execute these validations outside of testing suite. if "PYTEST_CURRENT_TEST" in os.environ: self._validate_result_schema( - bigframes.core.ArrayValue(plan), iterator.schema + bigframes.core.ArrayValue(root), iterator.schema + ) + + # if destination is set, this is an externally managed table, which may mutated, cannot use as cache + if ( + (destination is not None) + and (table is not None) + and (compiled.row_order is not None) + and (peek is None) + ): + # Assumption: GBQ cached table uses field name as bq column name + scan_list = nodes.ScanList( + tuple( + nodes.ScanItem(field.id, field.dtype, field.id.name) + for field in root.fields + ) + ) + cached_replacement = nodes.CachedTableNode( + source=nodes.BigqueryDataSource( + nodes.GbqTable.from_table( + table, columns=tuple(f.id.name for f in root.fields) + ), + ordering=compiled.row_order, + n_rows=table.num_rows, + ), + scan_list=scan_list, + table_session=root.session, + original_node=root, + ) + self._cached_executions[root] = cached_replacement + else: # no explicit destination, can maybe peek iterator + # Assumption: GBQ cached table uses field name as bq column name + scan_list = nodes.ScanList( + tuple( + nodes.ScanItem(field.id, field.dtype, field.id.name) + for field in root.fields + ) + ) + # Will increase when have auto-upload, 5000 is most want to inline + batch_iterator, batches = pyarrow_utils.peek_batches( + batch_iterator, max_bytes=5000 ) + if batches: + local_cached = nodes.ReadLocalNode( + local_data_source=local_data.ManagedArrowTable.from_pyarrow( + pa.Table.from_batches(batches) + ), + scan_list=scan_list, + session=root.session, + ) + self._cached_executions[root] = local_cached return executor.ExecuteResult( - arrow_batches=iterator_supplier, - schema=plan.schema, + arrow_batches=batch_iterator, + schema=root.schema, query_job=query_job, - total_bytes=size_bytes, + total_bytes=table.num_bytes if table else None, total_rows=iterator.total_rows, ) diff --git a/bigframes/session/executor.py b/bigframes/session/executor.py index 9075f4eee6..6536317b28 100644 --- a/bigframes/session/executor.py +++ 
b/bigframes/session/executor.py @@ -18,7 +18,7 @@ import dataclasses import functools import itertools -from typing import Callable, Iterator, Literal, Mapping, Optional, Sequence, Union +from typing import Iterator, Literal, Mapping, Optional, Sequence, Union from google.cloud import bigquery import pandas as pd @@ -31,7 +31,7 @@ @dataclasses.dataclass(frozen=True) class ExecuteResult: - arrow_batches: Callable[[], Iterator[pyarrow.RecordBatch]] + arrow_batches: Iterator[pyarrow.RecordBatch] schema: bigframes.core.schema.ArraySchema query_job: Optional[bigquery.QueryJob] = None total_bytes: Optional[int] = None @@ -41,7 +41,7 @@ def to_arrow_table(self) -> pyarrow.Table: # Need to provide schema if no result rows, as arrow can't infer # If ther are rows, it is safest to infer schema from batches. # Any discrepencies between predicted schema and actual schema will produce errors. - batches = iter(self.arrow_batches()) + batches = iter(self.arrow_batches) peek_it = itertools.islice(batches, 0, 1) peek_value = list(peek_it) # TODO: Enforce our internal schema on the table for consistency @@ -58,7 +58,7 @@ def to_pandas(self) -> pd.DataFrame: def to_pandas_batches(self) -> Iterator[pd.DataFrame]: yield from map( functools.partial(io_pandas.arrow_to_pandas, schema=self.schema), - self.arrow_batches(), + self.arrow_batches, ) def to_py_scalar(self): diff --git a/bigframes/session/local_scan_executor.py b/bigframes/session/local_scan_executor.py index 67e381ab8a..88304fa181 100644 --- a/bigframes/session/local_scan_executor.py +++ b/bigframes/session/local_scan_executor.py @@ -35,30 +35,24 @@ def execute( return None # TODO: Can support some slicing, sorting - def iterator_supplier(): - offsets_col = ( - node.offsets_col.sql if (node.offsets_col is not None) else None - ) - arrow_table = node.local_data_source.to_pyarrow_table( - offsets_col=offsets_col - ) - if peek: - arrow_table = arrow_table.slice(0, peek) + offsets_col = node.offsets_col.sql if (node.offsets_col is not None) else None + arrow_table = node.local_data_source.to_pyarrow_table(offsets_col=offsets_col) + if peek: + arrow_table = arrow_table.slice(0, peek) - needed_cols = [item.source_id for item in node.scan_list.items] - if offsets_col is not None: - needed_cols.append(offsets_col) + needed_cols = [item.source_id for item in node.scan_list.items] + if offsets_col is not None: + needed_cols.append(offsets_col) - arrow_table = arrow_table.select(needed_cols) - arrow_table = arrow_table.rename_columns([id.sql for id in node.ids]) - yield from arrow_table.to_batches() + arrow_table = arrow_table.select(needed_cols) + arrow_table = arrow_table.rename_columns([id.sql for id in node.ids]) total_rows = node.row_count if (peek is not None) and (total_rows is not None): total_rows = min(peek, total_rows) return executor.ExecuteResult( - arrow_batches=iterator_supplier, + arrow_batches=arrow_table.to_batches(), schema=plan.schema, query_job=None, total_bytes=None, diff --git a/bigframes/session/read_api_execution.py b/bigframes/session/read_api_execution.py index ae1272e722..5d79efcbfe 100644 --- a/bigframes/session/read_api_execution.py +++ b/bigframes/session/read_api_execution.py @@ -13,7 +13,7 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any, Optional +from typing import Any, Iterator, Optional from google.cloud import bigquery_storage_v1 import pyarrow as pa @@ -66,41 +66,40 @@ def execute( table_mod_options["snapshot_time"] = snapshot_time = snapshot_time table_mods = bq_storage_types.ReadSession.TableModifiers(**table_mod_options) - def iterator_supplier(): - requested_session = bq_storage_types.stream.ReadSession( - table=bq_table.to_bqstorage(), - data_format=bq_storage_types.DataFormat.ARROW, - read_options=read_options, - table_modifiers=table_mods, - ) - # Single stream to maintain ordering - request = bq_storage_types.CreateReadSessionRequest( - parent=f"projects/{self.project}", - read_session=requested_session, - max_stream_count=1, - ) - session = self.bqstoragereadclient.create_read_session( - request=request, retry=None - ) - - if not session.streams: - return iter([]) + requested_session = bq_storage_types.stream.ReadSession( + table=bq_table.to_bqstorage(), + data_format=bq_storage_types.DataFormat.ARROW, + read_options=read_options, + table_modifiers=table_mods, + ) + # Single stream to maintain ordering + request = bq_storage_types.CreateReadSessionRequest( + parent=f"projects/{self.project}", + read_session=requested_session, + max_stream_count=1, + ) + session = self.bqstoragereadclient.create_read_session( + request=request, retry=None + ) + if not session.streams: + batches: Iterator[pa.RecordBatch] = iter([]) + else: reader = self.bqstoragereadclient.read_rows( session.streams[0].name, retry=None ) rowstream = reader.rows() - def process_page(page): + def process_page(page) -> pa.RecordBatch: pa_batch = page.to_arrow() return pa.RecordBatch.from_arrays( pa_batch.columns, names=[id.sql for id in node.ids] ) - return map(process_page, rowstream.pages) + batches = map(process_page, rowstream.pages) return executor.ExecuteResult( - arrow_batches=iterator_supplier, + arrow_batches=batches, schema=plan.schema, query_job=None, total_bytes=None, diff --git a/tests/unit/polars_session.py b/tests/unit/polars_session.py index a27db0e438..d592b49038 100644 --- a/tests/unit/polars_session.py +++ b/tests/unit/polars_session.py @@ -51,7 +51,7 @@ def execute( # Currently, pyarrow types might not quite be exactly the ones in the bigframes schema. # Nullability may be different, and might use large versions of list, string datatypes. 
return bigframes.session.executor.ExecuteResult( - arrow_batches=lambda: pa_table.to_batches(), + arrow_batches=pa_table.to_batches(), schema=array_value.schema, total_bytes=pa_table.nbytes, total_rows=pa_table.num_rows, From 4599ffe2bc8a0cf7a615ba88e090bf38983aa57f Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Tue, 6 May 2025 19:34:40 +0000 Subject: [PATCH 2/2] refactor executor caching --- bigframes/session/bq_caching_executor.py | 145 ++++++++++++----------- 1 file changed, 75 insertions(+), 70 deletions(-) diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py index f026190ca8..b1030f8367 100644 --- a/bigframes/session/bq_caching_executor.py +++ b/bigframes/session/bq_caching_executor.py @@ -16,7 +16,7 @@ import math import os -from typing import cast, Literal, Mapping, Optional, Sequence, Tuple, Union +from typing import cast, Iterator, Literal, Mapping, Optional, Sequence, Tuple, Union import warnings import weakref @@ -350,16 +350,14 @@ def _cache_with_cluster_cols( plan, sort_rows=False, materialize_all_order_keys=True ) ) - tmp_table = self._sql_as_cached_temp_table( + tmp_table_ref = self._sql_as_cached_temp_table( compiled.sql, compiled.sql_schema, cluster_cols=bq_io.select_cluster_cols(compiled.sql_schema, cluster_cols), ) - cached_replacement = array_value.as_cached( - cache_table=self.bqclient.get_table(tmp_table), - ordering=compiled.row_order, - ).node - self._cached_executions[array_value.node] = cached_replacement + tmp_table = self.bqclient.get_table(tmp_table_ref) + assert compiled.row_order is not None + self._cache_results_table(array_value.node, tmp_table, compiled.row_order) def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue): """Executes the query and uses the resulting table to rewrite future executions.""" @@ -370,17 +368,14 @@ def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue): array_value.node, sort_rows=False, materialize_all_order_keys=True ) ) - - tmp_table = self._sql_as_cached_temp_table( + tmp_table_ref = self._sql_as_cached_temp_table( compiled.sql, w_offsets.schema.to_bigquery(), cluster_cols=[offset_column], ) - cached_replacement = array_value.as_cached( - cache_table=self.bqclient.get_table(tmp_table), - ordering=order.TotalOrdering.from_offset_col(offset_column), - ).node - self._cached_executions[array_value.node] = cached_replacement + tmp_table = self.bqclient.get_table(tmp_table_ref) + assert compiled.row_order is not None + self._cache_results_table(array_value.node, tmp_table, compiled.row_order) def _cache_with_session_awareness( self, @@ -534,14 +529,6 @@ def _execute_plan( else: table = None - if (table is not None) and (table.num_bytes or 0) >= MAX_SMALL_RESULT_BYTES: - msg = bfe.format_message( - "The query result size has exceeded 10 GB. In BigFrames 2.0 and " - "later, you might need to manually set `allow_large_results=True` in " - "the IO method or adjust the BigFrames option: " - "`bigframes.options.bigquery.allow_large_results=True`." - ) - warnings.warn(msg, FutureWarning) # Runs strict validations to ensure internal type predictions and ibis are completely in sync # Do not execute these validations outside of testing suite. 
         if "PYTEST_CURRENT_TEST" in os.environ:
@@ -549,54 +536,16 @@ def _execute_plan(
                 bigframes.core.ArrayValue(root), iterator.schema
             )
 
-        # if destination is set, this is an externally managed table, which may mutated, cannot use as cache
-        if (
-            (destination is not None)
-            and (table is not None)
-            and (compiled.row_order is not None)
-            and (peek is None)
-        ):
-            # Assumption: GBQ cached table uses field name as bq column name
-            scan_list = nodes.ScanList(
-                tuple(
-                    nodes.ScanItem(field.id, field.dtype, field.id.name)
-                    for field in root.fields
-                )
-            )
-            cached_replacement = nodes.CachedTableNode(
-                source=nodes.BigqueryDataSource(
-                    nodes.GbqTable.from_table(
-                        table, columns=tuple(f.id.name for f in root.fields)
-                    ),
-                    ordering=compiled.row_order,
-                    n_rows=table.num_rows,
-                ),
-                scan_list=scan_list,
-                table_session=root.session,
-                original_node=root,
-            )
-            self._cached_executions[root] = cached_replacement
-        else:  # no explicit destination, can maybe peek iterator
-            # Assumption: GBQ cached table uses field name as bq column name
-            scan_list = nodes.ScanList(
-                tuple(
-                    nodes.ScanItem(field.id, field.dtype, field.id.name)
-                    for field in root.fields
-                )
-            )
-            # Will increase when have auto-upload, 5000 is most want to inline
-            batch_iterator, batches = pyarrow_utils.peek_batches(
-                batch_iterator, max_bytes=5000
-            )
-            if batches:
-                local_cached = nodes.ReadLocalNode(
-                    local_data_source=local_data.ManagedArrowTable.from_pyarrow(
-                        pa.Table.from_batches(batches)
-                    ),
-                    scan_list=scan_list,
-                    session=root.session,
-                )
-                self._cached_executions[root] = local_cached
+        if peek is None:  # peek is lossy, don't cache
+            if (destination is not None) and (table is not None):
+                if compiled.row_order is not None:
+                    # Assumption: GBQ cached table uses field name as bq column name
+                    self._cache_results_table(root, table, compiled.row_order)
+            elif (
+                ordered
+            ):  # no explicit destination; we can maybe cache from the iterator, but rows must be ordered
+                # need to reassign batch_iterator since this method consumes the head of the original iterator
+                batch_iterator = self._try_cache_results_iterator(root, batch_iterator)
 
         return executor.ExecuteResult(
             arrow_batches=batch_iterator,
@@ -606,6 +555,62 @@ def _execute_plan(
             total_rows=iterator.total_rows,
         )
 
+    def _cache_results_table(
+        self,
+        original_root: nodes.BigFrameNode,
+        table: bigquery.Table,
+        ordering: order.RowOrdering,
+    ):
+        # if destination is set, this is an externally managed table, which may be mutated, so cannot use as cache
+        # Assumption: GBQ cached table uses field name as bq column name
+        scan_list = nodes.ScanList(
+            tuple(
+                nodes.ScanItem(field.id, field.dtype, field.id.sql)
+                for field in original_root.fields
+            )
+        )
+        cached_replacement = nodes.CachedTableNode(
+            source=nodes.BigqueryDataSource(
+                nodes.GbqTable.from_table(
+                    table, columns=tuple(f.id.name for f in original_root.fields)
+                ),
+                ordering=ordering,
+                n_rows=table.num_rows,
+            ),
+            scan_list=scan_list,
+            table_session=original_root.session,
+            original_node=original_root,
+        )
+        assert original_root.schema == cached_replacement.schema
+        self._cached_executions[original_root] = cached_replacement
+
+    def _try_cache_results_iterator(
+        self,
+        original_root: nodes.BigFrameNode,
+        batch_iterator: Iterator[pa.RecordBatch],
+    ) -> Iterator[pa.RecordBatch]:
+        # Assumption: GBQ cached table uses field name as bq column name
+        scan_list = nodes.ScanList(
+            tuple(
+                nodes.ScanItem(field.id, field.dtype, field.id.name)
+                for field in original_root.fields
+            )
+        )
+        # Will increase once we have auto-upload; for now, 5000 bytes is the most we want to inline
+        batch_iterator, batches = pyarrow_utils.peek_batches(
+            batch_iterator, max_bytes=5000
+        )
+        if batches:
+            local_cached = nodes.ReadLocalNode(
+                local_data_source=local_data.ManagedArrowTable.from_pyarrow(
+                    pa.Table.from_batches(batches)
+                ),
+                scan_list=scan_list,
+                session=original_root.session,
+            )
+            self._cached_executions[original_root] = local_cached
+        return batch_iterator
+
 
 def _sanitize(
     schema: Tuple[bigquery.SchemaField, ...]