Skip to content

Commit ee2c40c

Browse files
authored
feat: include local data bytes in the dry run report when available (#2185)
* feat: include local data bytes in the dry run report when available
* fix test
1 parent 2c50310 commit ee2c40c

File tree

3 files changed

+53
-3
lines changed

3 files changed

+53
-3
lines changed

bigframes/core/blocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -967,7 +967,7 @@ def _compute_dry_run(
967967
}
968968

969969
dry_run_stats = dry_runs.get_query_stats_with_dtypes(
970-
query_job, column_dtypes, self.index.dtypes
970+
query_job, column_dtypes, self.index.dtypes, self.expr.node
971971
)
972972
return dry_run_stats, query_job
973973

bigframes/session/dry_runs.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pandas
2121

2222
from bigframes import dtypes
23+
from bigframes.core import bigframe_node, nodes
2324

2425

2526
def get_table_stats(table: bigquery.Table) -> pandas.Series:
@@ -86,13 +87,26 @@ def get_query_stats_with_dtypes(
8687
query_job: bigquery.QueryJob,
8788
column_dtypes: Dict[str, dtypes.Dtype],
8889
index_dtypes: Sequence[dtypes.Dtype],
90+
expr_root: bigframe_node.BigFrameNode | None = None,
8991
) -> pandas.Series:
92+
"""
93+
Returns important stats from the query job as a Pandas Series. The dtypes information is added too.
94+
95+
Args:
96+
expr_root (Optional):
97+
The root of the expression tree that may contain local data, whose size is added to the
98+
total bytes count if available.
99+
100+
"""
90101
index = ["columnCount", "columnDtypes", "indexLevel", "indexDtypes"]
91102
values = [len(column_dtypes), column_dtypes, len(index_dtypes), index_dtypes]
92103

93104
s = pandas.Series(values, index=index)
94105

95-
return pandas.concat([s, get_query_stats(query_job)])
106+
result = pandas.concat([s, get_query_stats(query_job)])
107+
if expr_root is not None:
108+
result["totalBytesProcessed"] += get_local_bytes(expr_root)
109+
return result
96110

97111

98112
def get_query_stats(
@@ -145,4 +159,24 @@ def get_query_stats(
145159
else None
146160
)
147161

148-
return pandas.Series(values, index=index)
162+
result = pandas.Series(values, index=index)
163+
if result["totalBytesProcessed"] is None:
164+
result["totalBytesProcessed"] = 0
165+
else:
166+
result["totalBytesProcessed"] = int(result["totalBytesProcessed"])
167+
168+
return result
169+
170+
171+
def get_local_bytes(root: bigframe_node.BigFrameNode) -> int:
172+
def get_total_bytes(
173+
root: bigframe_node.BigFrameNode, child_results: tuple[int, ...]
174+
) -> int:
175+
child_bytes = sum(child_results)
176+
177+
if isinstance(root, nodes.ReadLocalNode):
178+
return child_bytes + root.local_data_source.data.get_total_buffer_size()
179+
180+
return child_bytes
181+
182+
return root.reduce_up(get_total_bytes)

tests/system/small/test_session.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2173,6 +2173,22 @@ def test_read_gbq_query_dry_run(scalars_table_id, session):
21732173
_assert_query_dry_run_stats_are_valid(result)
21742174

21752175

2176+
def test_block_dry_run_includes_local_data(session):
2177+
df1 = bigframes.dataframe.DataFrame({"col_1": [1, 2, 3]}, session=session)
2178+
df2 = bigframes.dataframe.DataFrame({"col_2": [1, 2, 3]}, session=session)
2179+
2180+
result = df1.merge(df2, how="cross").to_pandas(dry_run=True)
2181+
2182+
assert isinstance(result, pd.Series)
2183+
_assert_query_dry_run_stats_are_valid(result)
2184+
assert result["totalBytesProcessed"] > 0
2185+
assert (
2186+
df1.to_pandas(dry_run=True)["totalBytesProcessed"]
2187+
+ df2.to_pandas(dry_run=True)["totalBytesProcessed"]
2188+
== result["totalBytesProcessed"]
2189+
)
2190+
2191+
21762192
def _assert_query_dry_run_stats_are_valid(result: pd.Series):
21772193
expected_index = pd.Index(
21782194
[

0 commit comments

Comments
 (0)