perf: Defer some data uploads to execution time #1658

Draft
wants to merge 1 commit into main

11 changes: 10 additions & 1 deletion bigframes/core/array_value.py
@@ -133,8 +133,17 @@ def from_table(
ordering=ordering,
n_rows=n_rows,
)
return cls.from_bq_data_source(source_def, scan_list, session)

@classmethod
def from_bq_data_source(
cls,
source: nodes.BigqueryDataSource,
scan_list: nodes.ScanList,
session: Session,
):
node = nodes.ReadTableNode(
source=source_def,
source=source,
scan_list=scan_list,
table_session=session,
)
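
Note: the split above is the core refactor. from_table still assembles the BigqueryDataSource itself, while the new from_bq_data_source classmethod lets callers that already hold an uploaded source (see the loader and caching executor changes below) wrap it in an ArrayValue directly. A rough sketch of the intended call, using only names that appear in this diff; source_def, scan_list, and session are assumed to already exist:

# Illustration only, not part of the diff: wrap an already-materialized BigQuery
# source in an ArrayValue. source_def is a nodes.BigqueryDataSource, scan_list a
# nodes.ScanList, and session a bigframes Session.
import bigframes.core as core

array_value = core.ArrayValue.from_bq_data_source(
    source=source_def,
    scan_list=scan_list,
    session=session,
)
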
5 changes: 5 additions & 0 deletions bigframes/core/nodes.py
@@ -614,6 +614,11 @@ def project(
result = ScanList((self.items[:1]))
return result

def append(
self, source_id: str, dtype: bigframes.dtypes.Dtype, id: identifiers.ColumnId
) -> ScanList:
return ScanList((*self.items, ScanItem(id, dtype, source_id)))


@dataclasses.dataclass(frozen=True, eq=False)
class ReadLocalNode(LeafNode):
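
Note: ScanList.append is a small immutable helper; it returns a new ScanList with one extra ScanItem rather than mutating in place. The caching executor below uses it to expose the uploaded offsets column. A minimal sketch of the behavior, relying only on constructors that appear elsewhere in this PR (the column names are made up):

# Illustration only, not part of the diff.
from bigframes.core import identifiers, nodes
import bigframes.dtypes

original = nodes.ScanList(
    (nodes.ScanItem(identifiers.ColumnId("col_0"), bigframes.dtypes.INT_DTYPE, "col_0"),)
)
extended = original.append(
    "uploaded_offsets_0",             # source_id: column name in the uploaded table
    bigframes.dtypes.INT_DTYPE,       # dtype of the appended column
    identifiers.ColumnId("offsets"),  # id exposed to downstream plan nodes
)
assert len(original.items) == 1 and len(extended.items) == 2
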
15 changes: 8 additions & 7 deletions bigframes/session/__init__.py
@@ -244,13 +244,6 @@ def __init__(
self._temp_storage_manager = (
self._session_resource_manager or self._anon_dataset_manager
)
self._executor: executor.Executor = bq_caching_executor.BigQueryCachingExecutor(
bqclient=self._clients_provider.bqclient,
bqstoragereadclient=self._clients_provider.bqstoragereadclient,
storage_manager=self._temp_storage_manager,
strictly_ordered=self._strictly_ordered,
metrics=self._metrics,
)
self._loader = bigframes.session.loader.GbqDataLoader(
session=self,
bqclient=self._clients_provider.bqclient,
@@ -261,6 +254,14 @@
force_total_order=self._strictly_ordered,
metrics=self._metrics,
)
self._executor: executor.Executor = bq_caching_executor.BigQueryCachingExecutor(
bqclient=self._clients_provider.bqclient,
bqstoragereadclient=self._clients_provider.bqstoragereadclient,
storage_manager=self._temp_storage_manager,
loader=self._loader,
strictly_ordered=self._strictly_ordered,
metrics=self._metrics,
)

def __del__(self):
"""Automatic cleanup of internal resources."""
30 changes: 29 additions & 1 deletion bigframes/session/bq_caching_executor.py
@@ -35,7 +35,7 @@
import bigframes.dtypes
import bigframes.exceptions as bfe
import bigframes.features
from bigframes.session import executor, local_scan_executor, read_api_execution
from bigframes.session import executor, loader, local_scan_executor, read_api_execution
import bigframes.session._io.bigquery as bq_io
import bigframes.session.metrics
import bigframes.session.planner
@@ -47,6 +47,7 @@
MAX_SUBTREE_FACTORINGS = 5
_MAX_CLUSTER_COLUMNS = 4
MAX_SMALL_RESULT_BYTES = 10 * 1024 * 1024 * 1024 # 10G
MAX_INLINE_BYTES = 5000


class BigQueryCachingExecutor(executor.Executor):
@@ -63,6 +64,7 @@ def __init__(
bqclient: bigquery.Client,
storage_manager: bigframes.session.temporary_storage.TemporaryStorageManager,
bqstoragereadclient: google.cloud.bigquery_storage_v1.BigQueryReadClient,
loader: loader.GbqDataLoader,
*,
strictly_ordered: bool = True,
metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None,
@@ -72,6 +74,7 @@
self.compiler: bigframes.core.compile.SQLCompiler = (
bigframes.core.compile.SQLCompiler()
)
self.loader = loader
self.strictly_ordered: bool = strictly_ordered
self._cached_executions: weakref.WeakKeyDictionary[
nodes.BigFrameNode, nodes.BigFrameNode
@@ -437,6 +440,31 @@ def _simplify_with_caching(self, array_value: bigframes.core.ArrayValue):
if not did_cache:
return

def _upload_large_local_sources(self, root: nodes.BigFrameNode):
for leaf in root.unique_nodes():
if isinstance(leaf, nodes.ReadLocalNode):
if leaf.local_data_source.metadata.total_bytes > MAX_INLINE_BYTES:
self._cache_local_table(leaf)

def _cache_local_table(self, node: nodes.ReadLocalNode):
offsets_col = bigframes.core.guid.generate_guid()
# TODO: Best effort go through available upload paths
bq_data_source = self.loader.write_data(
node.local_data_source, offsets_col=offsets_col
)
scan_list = node.scan_list
if node.offsets_col is not None:
scan_list = scan_list.append(
offsets_col, bigframes.dtypes.INT_DTYPE, node.offsets_col
)
cache_node = nodes.CachedTableNode(
source=bq_data_source,
scan_list=scan_list,
table_session=self.loader._session,
original_node=node,
)
self._cached_executions[node] = cache_node

def _cache_most_complex_subtree(self, node: nodes.BigFrameNode) -> bool:
# TODO: If query fails, retry with lower complexity limit
selection = tree_properties.select_cache_target(
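
Note: these two new methods are where the deferral happens. At execution time the executor scans the plan for ReadLocalNode leaves whose managed data exceeds MAX_INLINE_BYTES (5000 bytes here), uploads each one through the loader, and records the resulting CachedTableNode in _cached_executions so the compiled SQL reads a temp table instead of inlined data. A condensed, hypothetical sketch of how this could slot into the execute path; the helpers marked as assumed are not shown in this diff:

# Illustration only, not part of the diff.
def _sketch_execute(self, array_value):
    plan = array_value.node
    # 1. Upload any local data too large to inline; each upload is recorded in
    #    self._cached_executions as a CachedTableNode keyed by the original node.
    self._upload_large_local_sources(plan)
    # 2. Swap cached nodes into the plan (assumed helper; the executor already
    #    rewrites plans against its cache for other purposes).
    plan = tree_properties.replace_nodes(plan, dict(self._cached_executions))
    # 3. Compile and run as before (assumed compiler entry point and client attribute).
    sql = self.compiler.compile(plan)
    return self.bqclient.query(sql).result()
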
120 changes: 71 additions & 49 deletions bigframes/session/loader.py
Expand Up @@ -43,7 +43,7 @@
import pandas
import pyarrow as pa

from bigframes.core import guid, local_data, utils
from bigframes.core import guid, identifiers, local_data, nodes, ordering, utils
import bigframes.core as core
import bigframes.core.blocks as blocks
import bigframes.core.schema as schemata
@@ -183,35 +183,59 @@ def read_pandas(
)
managed_data = local_data.ManagedArrowTable.from_pandas(prepared_df)

block = blocks.Block(
self.read_managed_data(managed_data, method=method, api_name=api_name),
index_columns=idx_cols,
column_labels=pandas_dataframe.columns,
index_labels=pandas_dataframe.index.names,
)
return dataframe.DataFrame(block)

def read_managed_data(
self,
data: local_data.ManagedArrowTable,
method: Literal["load", "stream", "write"],
api_name: str,
) -> core.ArrayValue:
offsets_col = guid.generate_guid("upload_offsets_")
if method == "load":
array_value = self.load_data(managed_data, api_name=api_name)
gbq_source = self.load_data(
data, offsets_col=offsets_col, api_name=api_name
)
elif method == "stream":
array_value = self.stream_data(managed_data)
gbq_source = self.stream_data(data, offsets_col=offsets_col)
elif method == "write":
array_value = self.write_data(managed_data)
gbq_source = self.write_data(data, offsets_col=offsets_col)
else:
raise ValueError(f"Unsupported read method {method}")

block = blocks.Block(
array_value,
index_columns=idx_cols,
column_labels=pandas_dataframe.columns,
index_labels=pandas_dataframe.index.names,
return core.ArrayValue.from_bq_data_source(
source=gbq_source,
scan_list=nodes.ScanList(
tuple(
nodes.ScanItem(
identifiers.ColumnId(item.column), item.dtype, item.column
)
for item in data.schema.items
)
),
session=self._session,
)
return dataframe.DataFrame(block)

def load_data(
self, data: local_data.ManagedArrowTable, api_name: Optional[str] = None
) -> core.ArrayValue:
self,
data: local_data.ManagedArrowTable,
offsets_col: str,
api_name: Optional[str] = None,
) -> nodes.BigqueryDataSource:
"""Load managed data into bigquery"""
ordering_col = guid.generate_guid("load_offsets_")

# JSON support incomplete
for item in data.schema.items:
_validate_dtype_can_load(item.column, item.dtype)

schema_w_offsets = data.schema.append(
schemata.SchemaItem(ordering_col, bigframes.dtypes.INT_DTYPE)
schemata.SchemaItem(offsets_col, bigframes.dtypes.INT_DTYPE)
)
bq_schema = schema_w_offsets.to_bigquery(_LOAD_JOB_TYPE_OVERRIDES)

@@ -222,13 +246,13 @@ def load_data(
job_config.labels = {"bigframes-api": api_name}

load_table_destination = self._storage_manager.create_temp_table(
bq_schema, [ordering_col]
bq_schema, [offsets_col]
)

buffer = io.BytesIO()
data.to_parquet(
buffer,
offsets_col=ordering_col,
offsets_col=offsets_col,
geo_format="wkt",
duration_type="duration",
json_type="string",
@@ -240,23 +264,24 @@ def load_data(
self._start_generic_job(load_job)
# must get table metadata after load job for accurate metadata
destination_table = self._bqclient.get_table(load_table_destination)
return core.ArrayValue.from_table(
table=destination_table,
schema=schema_w_offsets,
session=self._session,
offsets_col=ordering_col,
n_rows=data.data.num_rows,
).drop_columns([ordering_col])
return nodes.BigqueryDataSource(
nodes.GbqTable.from_table(destination_table),
ordering=ordering.TotalOrdering.from_offset_col(offsets_col),
n_rows=destination_table.num_rows,
)

def stream_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue:
def stream_data(
self,
data: local_data.ManagedArrowTable,
offsets_col: str,
) -> nodes.BigqueryDataSource:
"""Load managed data into bigquery"""
ordering_col = guid.generate_guid("stream_offsets_")
schema_w_offsets = data.schema.append(
schemata.SchemaItem(ordering_col, bigframes.dtypes.INT_DTYPE)
schemata.SchemaItem(offsets_col, bigframes.dtypes.INT_DTYPE)
)
bq_schema = schema_w_offsets.to_bigquery(_STREAM_JOB_TYPE_OVERRIDES)
load_table_destination = self._storage_manager.create_temp_table(
bq_schema, [ordering_col]
bq_schema, [offsets_col]
)

rows = data.itertuples(
@@ -275,24 +300,23 @@ def stream_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue:
f"Problem loading at least one row from DataFrame: {errors}. {constants.FEEDBACK_LINK}"
)
destination_table = self._bqclient.get_table(load_table_destination)
return core.ArrayValue.from_table(
table=destination_table,
schema=schema_w_offsets,
session=self._session,
offsets_col=ordering_col,
n_rows=data.data.num_rows,
).drop_columns([ordering_col])
return nodes.BigqueryDataSource(
nodes.GbqTable.from_table(destination_table),
ordering=ordering.TotalOrdering.from_offset_col(offsets_col),
n_rows=destination_table.num_rows,
)

def write_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue:
def write_data(
self,
data: local_data.ManagedArrowTable,
offsets_col: str,
) -> nodes.BigqueryDataSource:
"""Load managed data into bigquery"""
ordering_col = guid.generate_guid("stream_offsets_")
schema_w_offsets = data.schema.append(
schemata.SchemaItem(ordering_col, bigframes.dtypes.INT_DTYPE)
schemata.SchemaItem(offsets_col, bigframes.dtypes.INT_DTYPE)
)
bq_schema = schema_w_offsets.to_bigquery(_STREAM_JOB_TYPE_OVERRIDES)
bq_table_ref = self._storage_manager.create_temp_table(
bq_schema, [ordering_col]
)
bq_table_ref = self._storage_manager.create_temp_table(bq_schema, [offsets_col])

requested_stream = bq_storage_types.stream.WriteStream()
requested_stream.type_ = bq_storage_types.stream.WriteStream.Type.COMMITTED # type: ignore
@@ -304,7 +328,7 @@ def write_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue:

def request_gen() -> Generator[bq_storage_types.AppendRowsRequest, None, None]:
schema, batches = data.to_arrow(
offsets_col=ordering_col, duration_type="int"
offsets_col=offsets_col, duration_type="int"
)
offset = 0
for batch in batches:
@@ -330,13 +354,11 @@ def request_gen() -> Generator[bq_storage_types.AppendRowsRequest, None, None]:
assert response.row_count == data.data.num_rows

destination_table = self._bqclient.get_table(bq_table_ref)
return core.ArrayValue.from_table(
table=destination_table,
schema=schema_w_offsets,
session=self._session,
offsets_col=ordering_col,
n_rows=data.data.num_rows,
).drop_columns([ordering_col])
return nodes.BigqueryDataSource(
nodes.GbqTable.from_table(destination_table),
ordering=ordering.TotalOrdering.from_offset_col(offsets_col),
n_rows=destination_table.num_rows,
)

def _start_generic_job(self, job: formatting_helpers.GenericJob):
if bigframes.options.display.progress_bar is not None:
@@ -533,7 +555,7 @@ def read_gbq_table(
if not primary_key:
array_value = array_value.order_by(
[
bigframes.core.ordering.OrderingExpression(
ordering.OrderingExpression(
bigframes.operations.RowKey().as_expr(
*(id for id in array_value.column_ids)
),
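
Note: the net effect of the loader changes is that load_data, stream_data, and write_data no longer build an ArrayValue and no longer generate their own offsets column; each returns a plain nodes.BigqueryDataSource totally ordered by a caller-supplied offsets column, and read_managed_data (or the caching executor) decides how to wrap it. A condensed sketch of the new round trip, assuming loader is a GbqDataLoader and prepared_df is an already-cleaned pandas DataFrame:

# Illustration only, not part of the diff.
from bigframes.core import guid, identifiers, local_data, nodes
import bigframes.core as core

managed_data = local_data.ManagedArrowTable.from_pandas(prepared_df)
offsets_col = guid.generate_guid("upload_offsets_")

# load_data / stream_data / write_data all now return a BigqueryDataSource whose
# total ordering comes from the uploaded offsets column.
gbq_source = loader.write_data(managed_data, offsets_col=offsets_col)

array_value = core.ArrayValue.from_bq_data_source(
    source=gbq_source,
    scan_list=nodes.ScanList(
        tuple(
            nodes.ScanItem(identifiers.ColumnId(item.column), item.dtype, item.column)
            for item in managed_data.schema.items
        )
    ),
    session=loader._session,
)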