
perf: Cache results opportunistically #1694


Draft: wants to merge 3 commits into main
5 changes: 5 additions & 0 deletions bigframes/core/compile/__init__.py
@@ -14,8 +14,13 @@
from __future__ import annotations

from bigframes.core.compile.api import SQLCompiler, test_only_ibis_inferred_schema
from bigframes.core.compile.compiler import compile_sql
from bigframes.core.compile.configs import CompileRequest, CompileResult

__all__ = [
"SQLCompiler",
"test_only_ibis_inferred_schema",
"compile_sql",
"CompileRequest",
"CompileResult",
]
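
For reviewers, here is a minimal sketch of how the newly re-exported compile API is used later in this PR (see bq_caching_executor.py below). The CompileRequest arguments and the compiled.sql / compiled.sql_schema / compiled.row_order attributes are taken from that usage; the df._block.expr.node accessor and the need for an authenticated BigFrames session are assumptions for illustration only.

import bigframes.pandas as bpd
from bigframes.core import compile

# Build a tiny dataframe and grab its logical plan node (internal accessor, assumed).
df = bpd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
node = df._block.expr.node

# Compile the plan to SQL, requesting sorted rows (mirrors to_sql in this PR).
compiled = compile.compile_sql(compile.CompileRequest(node, sort_rows=True))

print(compiled.sql)         # SQL string sent to BigQuery
print(compiled.sql_schema)  # BigQuery schema of the SQL result
print(compiled.row_order)   # row-ordering info; asserted non-None in the caching
                            # paths that pass materialize_all_order_keys=True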
23 changes: 22 additions & 1 deletion bigframes/core/pyarrow_utils.py
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Iterator
import itertools
from typing import Iterable, Iterator, Optional

import pyarrow as pa

@@ -85,3 +86,23 @@ def truncate_pyarrow_iterable(
else:
yield batch
total_yielded += batch.num_rows


def peek_batches(
batch_iter: Iterable[pa.RecordBatch], max_bytes: int
) -> tuple[Iterator[pa.RecordBatch], Optional[tuple[pa.RecordBatch, ...]]]:
"""
Try to peek a pyarrow batch iterable. If the batches exceed max_bytes, give up.

Will consume at most max_bytes plus one additional batch of memory.
"""
batch_list = []
current_bytes = 0
for batch in batch_iter:
batch_list.append(batch)
current_bytes += batch.nbytes

if current_bytes > max_bytes:
return itertools.chain(batch_list, batch_iter), None

return iter(batch_list), tuple(batch_list)
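
As a quick illustration of the helper above (a sketch, not part of the PR; assumes a branch with this change installed and pyarrow available): peek_batches either returns a replayable tuple of batches when everything fits under max_bytes, or gives up and hands back an iterator that still yields every batch exactly once.

import pyarrow as pa

from bigframes.core import pyarrow_utils

# Three record batches of 100 int64 rows each (roughly 800 bytes per batch).
table = pa.Table.from_pydict({"x": list(range(300))})
batches = table.to_batches(max_chunksize=100)

# Tiny budget: the first batch already exceeds max_bytes, so no peeked copy is
# returned, but the chained iterator still yields every row exactly once.
batch_iter, peeked = pyarrow_utils.peek_batches(iter(batches), max_bytes=10)
assert peeked is None
assert sum(b.num_rows for b in batch_iter) == 300

# Generous budget: everything fits, so a materialized tuple comes back alongside
# an iterator that replays the same batches.
batch_iter, peeked = pyarrow_utils.peek_batches(iter(batches), max_bytes=1_000_000)
assert peeked is not None and len(peeked) == 3
assert sum(b.num_rows for b in batch_iter) == 300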
184 changes: 125 additions & 59 deletions bigframes/session/bq_caching_executor.py
@@ -16,7 +16,7 @@

import math
import os
from typing import cast, Literal, Mapping, Optional, Sequence, Tuple, Union
from typing import cast, Iterator, Literal, Mapping, Optional, Sequence, Tuple, Union
import warnings
import weakref

@@ -25,10 +25,10 @@
import google.cloud.bigquery.job as bq_job
import google.cloud.bigquery.table as bq_table
import google.cloud.bigquery_storage_v1
import pyarrow as pa

import bigframes.core
from bigframes.core import rewrite
import bigframes.core.compile
from bigframes.core import compile, local_data, pyarrow_utils, rewrite
import bigframes.core.guid
import bigframes.core.nodes as nodes
import bigframes.core.ordering as order
@@ -70,9 +70,6 @@ def __init__(
):
self.bqclient = bqclient
self.storage_manager = storage_manager
self.compiler: bigframes.core.compile.SQLCompiler = (
bigframes.core.compile.SQLCompiler()
)
self.strictly_ordered: bool = strictly_ordered
self._cached_executions: weakref.WeakKeyDictionary[
nodes.BigFrameNode, nodes.BigFrameNode
@@ -97,8 +94,11 @@ def to_sql(
) -> str:
if offset_column:
array_value, _ = array_value.promote_offsets()
node = self.logical_plan(array_value.node) if enable_cache else array_value.node
return self.compiler.compile(node, ordered=ordered)
node = (
self.simplify_plan(array_value.node) if enable_cache else array_value.node
)
compiled = compile.compile_sql(compile.CompileRequest(node, sort_rows=ordered))
return compiled.sql

def execute(
self,
@@ -113,7 +113,6 @@ def execute(
if bigframes.options.compute.enable_multi_query_execution:
self._simplify_with_caching(array_value)

plan = self.logical_plan(array_value.node)
# Use explicit destination to avoid 10GB limit of temporary table
destination_table = (
self.storage_manager.create_temp_table(
@@ -123,7 +122,7 @@
else None
)
return self._execute_plan(
plan,
array_value.node,
ordered=ordered,
destination=destination_table,
)
@@ -220,7 +219,7 @@ def peek(
"""
A 'peek' efficiently accesses a small number of rows in the dataframe.
"""
plan = self.logical_plan(array_value.node)
plan = self.simplify_plan(array_value.node)
if not tree_properties.can_fast_peek(plan):
msg = bfe.format_message("Peeking this value cannot be done efficiently.")
warnings.warn(msg)
@@ -236,7 +235,7 @@ def peek(
)

return self._execute_plan(
plan, ordered=False, destination=destination_table, peek=n_rows
array_value.node, ordered=False, destination=destination_table, peek=n_rows
)

def cached(
@@ -321,10 +320,10 @@ def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue):
# Once rewriting is available, will want to rewrite before
# evaluating execution cost.
return tree_properties.is_trivially_executable(
self.logical_plan(array_value.node)
self.simplify_plan(array_value.node)
)

def logical_plan(self, root: nodes.BigFrameNode) -> nodes.BigFrameNode:
def simplify_plan(self, root: nodes.BigFrameNode) -> nodes.BigFrameNode:
"""
Apply universal logical simplifications that are helpful regardless of engine.
"""
@@ -337,37 +336,38 @@ def _cache_with_cluster_cols(
self, array_value: bigframes.core.ArrayValue, cluster_cols: Sequence[str]
):
"""Executes the query and uses the resulting table to rewrite future executions."""

sql, schema, ordering_info = self.compiler.compile_raw(
self.logical_plan(array_value.node)
plan = self.simplify_plan(array_value.node)
compiled = compile.compile_sql(
compile.CompileRequest(
plan, sort_rows=False, materialize_all_order_keys=True
)
)
tmp_table = self._sql_as_cached_temp_table(
sql,
schema,
cluster_cols=bq_io.select_cluster_cols(schema, cluster_cols),
tmp_table_ref = self._sql_as_cached_temp_table(
compiled.sql,
compiled.sql_schema,
cluster_cols=bq_io.select_cluster_cols(compiled.sql_schema, cluster_cols),
)
cached_replacement = array_value.as_cached(
cache_table=self.bqclient.get_table(tmp_table),
ordering=ordering_info,
).node
self._cached_executions[array_value.node] = cached_replacement
tmp_table = self.bqclient.get_table(tmp_table_ref)
assert compiled.row_order is not None
self._cache_results_table(array_value.node, tmp_table, compiled.row_order)

def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue):
"""Executes the query and uses the resulting table to rewrite future executions."""
offset_column = bigframes.core.guid.generate_guid("bigframes_offsets")
w_offsets, offset_column = array_value.promote_offsets()
sql = self.compiler.compile(self.logical_plan(w_offsets.node), ordered=False)

tmp_table = self._sql_as_cached_temp_table(
sql,
compiled = compile.compile_sql(
compile.CompileRequest(
array_value.node, sort_rows=False, materialize_all_order_keys=True
)
)
tmp_table_ref = self._sql_as_cached_temp_table(
compiled.sql,
w_offsets.schema.to_bigquery(),
cluster_cols=[offset_column],
)
cached_replacement = array_value.as_cached(
cache_table=self.bqclient.get_table(tmp_table),
ordering=order.TotalOrdering.from_offset_col(offset_column),
).node
self._cached_executions[array_value.node] = cached_replacement
tmp_table = self.bqclient.get_table(tmp_table_ref)
assert compiled.row_order is not None
self._cache_results_table(array_value.node, tmp_table, compiled.row_order)

def _cache_with_session_awareness(
self,
@@ -393,7 +393,7 @@ def _simplify_with_caching(self, array_value: bigframes.core.ArrayValue):
# Apply existing caching first
for _ in range(MAX_SUBTREE_FACTORINGS):
if (
self.logical_plan(array_value.node).planning_complexity
self.simplify_plan(array_value.node).planning_complexity
< QUERY_COMPLEXITY_LIMIT
):
return
@@ -450,8 +450,8 @@ def _validate_result_schema(
bq_schema: list[bigquery.SchemaField],
):
actual_schema = _sanitize(tuple(bq_schema))
ibis_schema = bigframes.core.compile.test_only_ibis_inferred_schema(
self.logical_plan(array_value.node)
ibis_schema = compile.test_only_ibis_inferred_schema(
self.simplify_plan(array_value.node)
).to_bigquery()
internal_schema = _sanitize(array_value.schema.to_bigquery())
if not bigframes.features.PANDAS_VERSIONS.is_arrow_list_dtype_usable:
@@ -469,7 +469,7 @@

def _execute_plan(
self,
plan: nodes.BigFrameNode,
root: nodes.BigFrameNode,
ordered: bool,
destination: Optional[bq_table.TableReference] = None,
peek: Optional[int] = None,
@@ -479,7 +479,9 @@
# First try to execute fast-paths
if (not destination) and (not peek):
for semi_executor in self._semi_executors:
maybe_result = semi_executor.execute(plan, ordered=ordered)
maybe_result = semi_executor.execute(
self.simplify_plan(root), ordered=ordered
)
if maybe_result:
return maybe_result

@@ -489,45 +491,109 @@
# Use explicit destination to avoid 10GB limit of temporary table
if destination is not None:
job_config.destination = destination
sql = self.compiler.compile(plan, ordered=ordered, limit=peek)
compiled = compile.compile_sql(
compile.CompileRequest(
self.simplify_plan(root), sort_rows=ordered, peek_count=peek
)
)
iterator, query_job = self._run_execute_query(
sql=sql,
sql=compiled.sql,
job_config=job_config,
query_with_job=(destination is not None),
)

# Though we provide the read client, the iterator may or may not use it, depending on what is most efficient for the result
def iterator_supplier():
return iterator.to_arrow_iterable(bqstorage_client=self.bqstoragereadclient)
batch_iterator = iterator.to_arrow_iterable(
bqstorage_client=self.bqstoragereadclient
)

if query_job:
size_bytes = self.bqclient.get_table(query_job.destination).num_bytes
table = self.bqclient.get_table(query_job.destination)
else:
size_bytes = None

if size_bytes is not None and size_bytes >= MAX_SMALL_RESULT_BYTES:
msg = bfe.format_message(
"The query result size has exceeded 10 GB. In BigFrames 2.0 and "
"later, you might need to manually set `allow_large_results=True` in "
"the IO method or adjust the BigFrames option: "
"`bigframes.options.bigquery.allow_large_results=True`."
)
warnings.warn(msg, FutureWarning)
table = None

# Run strict validations to ensure internal type predictions and ibis are completely in sync.
# Do not execute these validations outside of the testing suite.
if "PYTEST_CURRENT_TEST" in os.environ:
self._validate_result_schema(
bigframes.core.ArrayValue(plan), iterator.schema
bigframes.core.ArrayValue(root), iterator.schema
)

if peek is None: # peek is lossy, don't cache
if (destination is not None) and (table is not None):
if compiled.row_order is not None:
# Assumption: GBQ cached table uses field name as bq column name
self._cache_results_table(root, table, compiled.row_order)
elif (
ordered
): # no explicit destination, can maybe peek iterator, but rows need to be ordered
# need to reassign batch_iterator since this method consumes the head of the original
batch_iterator = self._try_cache_results_iterator(root, batch_iterator)

return executor.ExecuteResult(
arrow_batches=iterator_supplier,
schema=plan.schema,
arrow_batches=batch_iterator,
schema=root.schema,
query_job=query_job,
total_bytes=size_bytes,
total_bytes=table.num_bytes if table else None,
total_rows=iterator.total_rows,
)

def _cache_results_table(
self,
original_root: nodes.BigFrameNode,
table: bigquery.Table,
ordering: order.RowOrdering,
):
# If destination is set, this is an externally managed table, which may be mutated, so it cannot be used as a cache.
# Assumption: GBQ cached table uses field name as bq column name
scan_list = nodes.ScanList(
tuple(
nodes.ScanItem(field.id, field.dtype, field.id.sql)
for field in original_root.fields
)
)
cached_replacement = nodes.CachedTableNode(
source=nodes.BigqueryDataSource(
nodes.GbqTable.from_table(
table, columns=tuple(f.id.name for f in original_root.fields)
),
ordering=ordering,
n_rows=table.num_rows,
),
scan_list=scan_list,
table_session=original_root.session,
original_node=original_root,
)
assert original_root.schema == cached_replacement.schema
self._cached_executions[original_root] = cached_replacement

def _try_cache_results_iterator(
self,
original_root: nodes.BigFrameNode,
batch_iterator: Iterator[pa.RecordBatch],
) -> Iterator[pa.RecordBatch]:
# Assumption: GBQ cached table uses field name as bq column name
scan_list = nodes.ScanList(
tuple(
nodes.ScanItem(field.id, field.dtype, field.id.name)
for field in original_root.fields
)
)
# Will increase when we have auto-upload; 5000 bytes is the most we want to inline for now.
batch_iterator, batches = pyarrow_utils.peek_batches(
batch_iterator, max_bytes=5000
)
if batches:
local_cached = nodes.ReadLocalNode(
local_data_source=local_data.ManagedArrowTable.from_pyarrow(
pa.Table.from_batches(batches)
),
scan_list=scan_list,
session=original_root.session,
)
self._cached_executions[original_root] = local_cached
return batch_iterator


def _sanitize(
schema: Tuple[bigquery.SchemaField, ...]