
Commit b18aeaa

feat: yield the CPU in aggregation loops whose input does not go through Tokio MPSC (RecordBatchReceiverStream)
1 parent 2d12bf6 commit b18aeaa
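
Background (summarized here, not part of the commit text): when an operator's input never crosses a Tokio MPSC channel such as RecordBatchReceiverStream, its poll loop may never hit an await point that returns Pending, so a CPU-bound aggregation can monopolize its worker thread and starve other tasks on the cooperative executor. The sketch below illustrates the general remedy of periodically handing control back to the scheduler; process_all and expensive_work are hypothetical placeholders, and only the 64-batch interval mirrors the YIELD_BATCHES constant introduced in the diff.

use tokio::task;

// Hypothetical CPU-bound loop over in-memory batches (not from this commit).
// Without the explicit yield there is no await point that ever returns
// Pending, so on a cooperative executor the task can hog its worker thread.
async fn process_all(batches: Vec<Vec<u8>>) {
    for (i, batch) in batches.iter().enumerate() {
        expensive_work(batch);
        // Hand control back to the Tokio scheduler every 64 batches,
        // mirroring YIELD_BATCHES in the diff below.
        if (i + 1) % 64 == 0 {
            task::yield_now().await;
        }
    }
}

fn expensive_work(_batch: &[u8]) {
    // stand-in for per-batch aggregation work
}

The commit implements the same idea one level lower, inside the Stream implementation itself, so the yield happens regardless of how the stream is consumed.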

3 files changed: +55 -3 lines changed

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 4 additions & 2 deletions
@@ -47,6 +47,7 @@ use datafusion_physical_expr::{
     PhysicalSortRequirement,
 };

+use crate::aggregates::no_grouping::YieldStream;
 use datafusion_physical_expr_common::physical_expr::fmt_sql;
 use itertools::Itertools;

@@ -983,8 +984,9 @@ impl ExecutionPlan for AggregateExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        self.execute_typed(partition, context)
-            .map(|stream| stream.into())
+        let raw_stream = self.execute_typed(partition, context)?.into();
+        let wrapped = Box::pin(YieldStream::new(raw_stream));
+        Ok(wrapped)
     }

     fn metrics(&self) -> Option<MetricsSet> {
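
The wrapping is transparent to whoever consumes the returned stream. As a rough illustration (count_batches is hypothetical, not a DataFusion API), a downstream consumer keeps pulling batches exactly as before; the Pending returned every 64 batches only shows up as the executor briefly running other tasks between polls.

use futures::{Stream, StreamExt};

// Hypothetical consumer of a wrapped stream. Pin<Box<dyn ...>> streams are
// Unpin, so a boxed stream such as the one returned by execute() can be
// passed here directly.
async fn count_batches<S, B, E>(mut stream: S) -> Result<usize, E>
where
    S: Stream<Item = Result<B, E>> + Unpin,
{
    let mut n = 0;
    while let Some(batch) = stream.next().await {
        batch?; // propagate the first error, if any
        n += 1;
    }
    Ok(n)
}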

datafusion/physical-plan/src/aggregates/no_grouping.rs

Lines changed: 50 additions & 0 deletions
@@ -76,6 +76,7 @@ impl AggregateStream {

         let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
         let input = agg.input.execute(partition, Arc::clone(&context))?;
+        let input = Box::pin(YieldStream::new(input)) as SendableRecordBatchStream;

         let aggregate_expressions = aggregate_expressions(&agg.aggr_expr, &agg.mode, 0)?;
         let filter_expressions = match agg.mode {
@@ -170,6 +171,55 @@ impl AggregateStream {
     }
 }

+/// A stream that yields batches of data, yielding control back to the executor
+pub struct YieldStream {
+    inner: SendableRecordBatchStream,
+    batches_processed: usize,
+}
+
+impl YieldStream {
+    pub fn new(inner: SendableRecordBatchStream) -> Self {
+        Self {
+            inner,
+            batches_processed: 0,
+        }
+    }
+}
+
+// Stream<Item = Result<RecordBatch>> to poll_next_unpin
+impl Stream for YieldStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        const YIELD_BATCHES: usize = 64;
+        let this = &mut *self;
+
+        match this.inner.poll_next_unpin(cx) {
+            Poll::Ready(Some(Ok(batch))) => {
+                this.batches_processed += 1;
+                if this.batches_processed >= YIELD_BATCHES {
+                    this.batches_processed = 0;
+                    cx.waker().wake_by_ref();
+                    Poll::Pending
+                } else {
+                    Poll::Ready(Some(Ok(batch)))
+                }
+            }
+            other => other,
+        }
+    }
+}
+
+// RecordBatchStream schema()
+impl RecordBatchStream for YieldStream {
+    fn schema(&self) -> Arc<arrow_schema::Schema> {
+        self.inner.schema()
+    }
+}
+
 impl Stream for AggregateStream {
     type Item = Result<RecordBatch>;
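
To make the mechanism concrete, here is a small self-contained sketch of the same yield-every-N idea over a plain futures Stream. YieldEvery, yield_every, forwarded, and pending_item are this sketch's own names, and stashing the in-flight item across the yield is a conservative choice made here so the yield cannot drop data; none of it is code from the commit.

use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{Stream, StreamExt};

// Forward items from an inner stream, but after every `yield_every` items
// wake ourselves and return Pending once so the executor can run other tasks.
struct YieldEvery<S: Stream> {
    inner: S,
    yield_every: usize,
    forwarded: usize,
    // Item pulled right before a yield; returned on the next poll.
    pending_item: Option<S::Item>,
}

impl<S> Stream for YieldEvery<S>
where
    S: Stream + Unpin,
    S::Item: Unpin,
{
    type Item = S::Item;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = &mut *self;

        // First deliver any item stashed before the previous yield.
        if let Some(item) = this.pending_item.take() {
            return Poll::Ready(Some(item));
        }

        match this.inner.poll_next_unpin(cx) {
            Poll::Ready(Some(item)) => {
                this.forwarded += 1;
                if this.forwarded >= this.yield_every {
                    this.forwarded = 0;
                    // Keep the item, ask to be polled again, and yield.
                    this.pending_item = Some(item);
                    cx.waker().wake_by_ref();
                    Poll::Pending
                } else {
                    Poll::Ready(Some(item))
                }
            }
            other => other,
        }
    }
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let wrapped = YieldEvery {
        inner: futures::stream::iter(0..10),
        yield_every: 4,
        forwarded: 0,
        pending_item: None,
    };
    // The extra Pending polls are invisible to collect(): it is simply
    // re-polled when the waker fires, so all ten items arrive in order.
    let items: Vec<i32> = wrapped.collect().await;
    assert_eq!(items, (0..10).collect::<Vec<_>>());
}

Returning Poll::Pending right after cx.waker().wake_by_ref() is what lets the Tokio scheduler run other ready tasks before this stream is polled again; returning Pending without arranging a wake would leave the task parked forever.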

parquet-testing

Submodule parquet-testing updated 67 files
