feat: Add deferred data uploading #1720


Draft: wants to merge 1 commit into base: main
11 changes: 10 additions & 1 deletion bigframes/core/array_value.py
@@ -133,8 +133,17 @@ def from_table(
ordering=ordering,
n_rows=n_rows,
)
return cls.from_bq_data_source(source_def, scan_list, session)

@classmethod
def from_bq_data_source(
cls,
source: nodes.BigqueryDataSource,
scan_list: nodes.ScanList,
session: Session,
):
node = nodes.ReadTableNode(
source=source_def,
source=source,
scan_list=scan_list,
table_session=session,
)
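A minimal sketch of how the new `from_bq_data_source` classmethod might be called once a data source is known; `source_def`, `scan_list`, and `session` are assumed to already exist (as they do inside `from_table` above), and the return value is presumed to be an `ArrayValue` wrapping the constructed `ReadTableNode`.

```python
# Hypothetical reuse of the refactored constructor: from_table builds the
# BigqueryDataSource and ScanList and then delegates, while other callers can
# construct the node from an already-known source without going through
# from_table.
array_value = ArrayValue.from_bq_data_source(
    source=source_def,    # nodes.BigqueryDataSource describing the table
    scan_list=scan_list,  # nodes.ScanList mapping physical columns to column ids
    session=session,      # the owning bigframes Session
)
```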
20 changes: 19 additions & 1 deletion bigframes/core/nodes.py
@@ -578,6 +578,9 @@ class ScanItem(typing.NamedTuple):
def with_id(self, id: identifiers.ColumnId) -> ScanItem:
return ScanItem(id, self.dtype, self.source_id)

def with_source_id(self, source_id: str) -> ScanItem:
return ScanItem(self.id, self.dtype, source_id)


@dataclasses.dataclass(frozen=True)
class ScanList:
@@ -614,16 +617,31 @@ def project(
result = ScanList((self.items[:1]))
return result

def remap_source_ids(
self,
mapping: Mapping[str, str],
) -> ScanList:
items = tuple(
item.with_source_id(mapping.get(item.source_id, item.source_id))
for item in self.items
)
return ScanList(items)

def append(
self, source_id: str, dtype: bigframes.dtypes.Dtype, id: identifiers.ColumnId
) -> ScanList:
return ScanList((*self.items, ScanItem(id, dtype, source_id)))


@dataclasses.dataclass(frozen=True, eq=False)
class ReadLocalNode(LeafNode):
# TODO: Track nullability for local data
local_data_source: local_data.ManagedArrowTable
# Mapping of local ids to bfet id.
scan_list: ScanList
session: bigframes.session.Session
# Offsets are generated only if this is non-null
offsets_col: Optional[identifiers.ColumnId] = None
session: typing.Optional[bigframes.session.Session] = None

@property
def fields(self) -> Sequence[Field]:
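A small, self-contained sketch of the new `ScanList` helpers; the column names, the INT dtype, and the assumption that `identifiers.ColumnId` can be built directly from a string name are illustrative only.

```python
from bigframes.core import identifiers, nodes
import bigframes.dtypes

# Hypothetical scan list over two local columns "a" and "b".
scan = nodes.ScanList(
    (
        nodes.ScanItem(identifiers.ColumnId("a"), bigframes.dtypes.INT_DTYPE, "a"),
        nodes.ScanItem(identifiers.ColumnId("b"), bigframes.dtypes.INT_DTYPE, "b"),
    )
)

# After a deferred upload, source ids are rewritten to the uploaded table's
# physical column names; ids missing from the mapping are left untouched.
remapped = scan.remap_source_ids({"a": "bqcol_0", "b": "bqcol_1"})

# append() is how the executor exposes the generated offsets column of the
# uploaded table under the node's existing offsets ColumnId.
with_offsets = remapped.append(
    "bqcol_offsets", bigframes.dtypes.INT_DTYPE, identifiers.ColumnId("offsets")
)
```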
19 changes: 12 additions & 7 deletions bigframes/session/__init__.py
@@ -246,13 +246,6 @@ def __init__(
self._temp_storage_manager = (
self._session_resource_manager or self._anon_dataset_manager
)
self._executor: executor.Executor = bq_caching_executor.BigQueryCachingExecutor(
bqclient=self._clients_provider.bqclient,
bqstoragereadclient=self._clients_provider.bqstoragereadclient,
storage_manager=self._temp_storage_manager,
strictly_ordered=self._strictly_ordered,
metrics=self._metrics,
)
self._loader = bigframes.session.loader.GbqDataLoader(
session=self,
bqclient=self._clients_provider.bqclient,
@@ -263,6 +256,14 @@
force_total_order=self._strictly_ordered,
metrics=self._metrics,
)
self._executor: executor.Executor = bq_caching_executor.BigQueryCachingExecutor(
bqclient=self._clients_provider.bqclient,
bqstoragereadclient=self._clients_provider.bqstoragereadclient,
loader=self._loader,
storage_manager=self._temp_storage_manager,
strictly_ordered=self._strictly_ordered,
metrics=self._metrics,
)

def __del__(self):
"""Automatic cleanup of internal resources."""
@@ -929,6 +930,10 @@ def _read_pandas(
return self._loader.read_pandas(
pandas_dataframe, method="write", api_name=api_name
)
elif write_engine == "_deferred":
import bigframes.dataframe as dataframe

return dataframe.DataFrame(blocks.Block.from_local(pandas_dataframe, self))
else:
raise ValueError(f"Got unexpected write_engine '{write_engine}'")

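A hedged usage sketch of the new `_deferred` engine, assuming the public `read_pandas` entry points forward `write_engine` down to `_read_pandas` the same way the existing engines do; the underscore prefix suggests an internal engine value, shown here purely for illustration, and no load job runs at read time.

```python
import pandas as pd
import bigframes.pandas as bpd

pdf = pd.DataFrame({"x": range(100_000), "y": range(100_000)})

# Hypothetical call: the frame is backed by a ReadLocalNode and nothing is
# uploaded yet. The upload is deferred to the executor, which substitutes
# sufficiently large local sources with uploaded tables at execution time.
df = bpd.read_pandas(pdf, write_engine="_deferred")

# Executing the plan triggers the one-time upload (for large tables) and
# rewrites the local scan into a scan over the uploaded table.
result = df.to_pandas()
```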
76 changes: 72 additions & 4 deletions bigframes/session/bq_caching_executor.py
@@ -17,6 +17,7 @@
import dataclasses
import math
import os
import threading
from typing import cast, Literal, Mapping, Optional, Sequence, Tuple, Union
import warnings
import weakref
@@ -28,15 +29,15 @@
import google.cloud.bigquery_storage_v1

import bigframes.core
from bigframes.core import compile, rewrite
from bigframes.core import compile, local_data, rewrite
import bigframes.core.guid
import bigframes.core.nodes as nodes
import bigframes.core.ordering as order
import bigframes.core.tree_properties as tree_properties
import bigframes.dtypes
import bigframes.exceptions as bfe
import bigframes.features
from bigframes.session import executor, local_scan_executor, read_api_execution
from bigframes.session import executor, loader, local_scan_executor, read_api_execution
import bigframes.session._io.bigquery as bq_io
import bigframes.session.metrics
import bigframes.session.planner
@@ -65,12 +66,19 @@ def _get_default_output_spec() -> OutputSpec:
)


SourceIdMapping = Mapping[str, str]


class ExecutionCache:
def __init__(self):
# current assumption is only 1 cache of a given node
# in future, might have multiple caches, with different layout, localities
self._cached_executions: weakref.WeakKeyDictionary[
nodes.BigFrameNode, nodes.BigFrameNode
nodes.BigFrameNode, nodes.CachedTableNode
] = weakref.WeakKeyDictionary()
self._uploaded_local_data: weakref.WeakKeyDictionary[
local_data.ManagedArrowTable,
tuple[nodes.BigqueryDataSource, SourceIdMapping],
] = weakref.WeakKeyDictionary()

@property
@@ -103,6 +111,17 @@ def cache_results_table(
assert original_root.schema == cached_replacement.schema
self._cached_executions[original_root] = cached_replacement

def cache_remote_replacement(
self,
local_data: local_data.ManagedArrowTable,
bq_data: nodes.BigqueryDataSource,
):
mapping = {
local_data.schema.items[i].column: bq_data.table.physical_schema[i].name
for i in range(len(local_data.schema))
}
self._uploaded_local_data[local_data] = (bq_data, mapping)


class BigQueryCachingExecutor(executor.Executor):
"""Computes BigFrames values using BigQuery Engine.
@@ -118,6 +137,7 @@ def __init__(
bqclient: bigquery.Client,
storage_manager: bigframes.session.temporary_storage.TemporaryStorageManager,
bqstoragereadclient: google.cloud.bigquery_storage_v1.BigQueryReadClient,
loader: loader.GbqDataLoader,
*,
strictly_ordered: bool = True,
metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None,
@@ -127,6 +147,7 @@
self.strictly_ordered: bool = strictly_ordered
self.cache: ExecutionCache = ExecutionCache()
self.metrics = metrics
self.loader = loader
self.bqstoragereadclient = bqstoragereadclient
# Simple left-to-right precedence for now
self._semi_executors = (
@@ -136,6 +157,7 @@
),
local_scan_executor.LocalScanExecutor(),
)
self._upload_lock = threading.Lock()

def to_sql(
self,
@@ -147,6 +169,7 @@ def to_sql(
if offset_column:
array_value, _ = array_value.promote_offsets()
node = self.logical_plan(array_value.node) if enable_cache else array_value.node
node = self._substitute_large_local_sources(node)
compiled = compile.compile_sql(compile.CompileRequest(node, sort_rows=ordered))
return compiled.sql

@@ -378,6 +401,7 @@ def _cache_with_cluster_cols(
):
"""Executes the query and uses the resulting table to rewrite future executions."""
plan = self.logical_plan(array_value.node)
plan = self._substitute_large_local_sources(plan)
compiled = compile.compile_sql(
compile.CompileRequest(
plan, sort_rows=False, materialize_all_order_keys=True
@@ -398,7 +422,7 @@ def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue):
w_offsets, offset_column = array_value.promote_offsets()
compiled = compile.compile_sql(
compile.CompileRequest(
self.logical_plan(w_offsets.node),
self.logical_plan(self._substitute_large_local_sources(w_offsets.node)),
sort_rows=False,
)
)
@@ -509,6 +533,48 @@ def _validate_result_schema(
f"This error should only occur while testing. Ibis schema: {ibis_schema} does not match actual schema: {actual_schema}"
)

def _substitute_large_local_sources(self, original_root: nodes.BigFrameNode):
"""
Replace large local sources with the uploaded version of those datasources.
"""
# Step 1: Upload all previously un-uploaded data
for leaf in original_root.unique_nodes():
if isinstance(leaf, nodes.ReadLocalNode):
if leaf.local_data_source.metadata.total_bytes > 5000:
self._upload_local_data(leaf.local_data_source)

# Step 2: Replace local scans with remote scans
def map_local_scans(node: nodes.BigFrameNode):
if not isinstance(node, nodes.ReadLocalNode):
return node
if node.local_data_source not in self.cache._uploaded_local_data:
return node
bq_source, source_mapping = self.cache._uploaded_local_data[
node.local_data_source
]
scan_list = node.scan_list.remap_source_ids(source_mapping)
if node.offsets_col is not None:
scan_list = scan_list.append(
bq_source.table.physical_schema[-1].name,
bigframes.dtypes.INT_DTYPE,
node.offsets_col,
)
return nodes.ReadTableNode(bq_source, scan_list, node.session)

return original_root.bottom_up(map_local_scans)

def _upload_local_data(self, local_table: local_data.ManagedArrowTable):
if local_table in self.cache._uploaded_local_data:
return
# Lock prevents concurrent repeated work, but slows things down.
# Might be better as a queue and a worker thread
with self._upload_lock:
if local_table not in self.cache._uploaded_local_data:
uploaded = self.loader.load_data(
local_table, bigframes.core.guid.generate_guid()
)
self.cache.cache_remote_replacement(local_table, uploaded)

def _execute_plan(
self,
plan: nodes.BigFrameNode,
@@ -539,6 +605,8 @@ def _execute_plan(
# Use explicit destination to avoid 10GB limit of temporary table
if destination_table is not None:
job_config.destination = destination_table

plan = self._substitute_large_local_sources(plan)
compiled = compile.compile_sql(
compile.CompileRequest(plan, sort_rows=ordered, peek_count=peek)
)
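The upload path guards against duplicate work with a lock plus a re-check, so concurrent executions over the same local table trigger at most one load job. A minimal standalone sketch of that pattern follows; the class and names are hypothetical, while the real code keys a `WeakKeyDictionary` on `ManagedArrowTable` and calls `GbqDataLoader.load_data`.

```python
import threading

class DeferredUploader:
    """Illustrative double-checked upload guard mirroring _upload_local_data."""

    def __init__(self, load_fn):
        self._load_fn = load_fn  # stands in for GbqDataLoader.load_data
        self._uploaded = {}      # the real cache is a weakref.WeakKeyDictionary
        self._lock = threading.Lock()

    def upload_once(self, local_table):
        # Fast path: already uploaded, no lock contention.
        if local_table in self._uploaded:
            return self._uploaded[local_table]
        # Slow path: serialize uploads, then re-check under the lock so only
        # one thread actually performs the load job for this table.
        with self._lock:
            if local_table not in self._uploaded:
                self._uploaded[local_table] = self._load_fn(local_table)
            return self._uploaded[local_table]
```

A possible refinement noted in the diff's own comment is replacing the single lock with a queue and a worker thread so uploads do not block plan compilation.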