Skip to content
38 changes: 33 additions & 5 deletions python/ray/data/_internal/execution/operators/hash_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
)
from ray.data._internal.execution.operators.sub_progress import SubProgressBarMixin
from ray.data._internal.logical.interfaces import LogicalOperator
from ray.data._internal.output_buffer import BlockOutputBuffer, OutputBlockSizeOption
from ray.data._internal.stats import OpRuntimeMetrics
from ray.data._internal.table_block import TableBlockAccessor
from ray.data._internal.util import GiB, MiB
Expand Down Expand Up @@ -1335,7 +1336,12 @@ def start(self):

aggregator = HashShuffleAggregator.options(
**self._aggregator_ray_remote_args
).remote(aggregator_id, target_partition_ids, self._aggregation_factory_ref)
).remote(
aggregator_id,
target_partition_ids,
self._aggregation_factory_ref,
self._data_context,
)

self._aggregators.append(aggregator)

Expand Down Expand Up @@ -1547,11 +1553,21 @@ def __init__(
aggregator_id: int,
target_partition_ids: List[int],
agg_factory: StatefulShuffleAggregationFactory,
data_context: DataContext,
):
self._lock = threading.Lock()
self._agg: StatefulShuffleAggregation = agg_factory(
aggregator_id, target_partition_ids
)
# One buffer per partition to enable concurrent finalization
self._output_buffers: Dict[int, BlockOutputBuffer] = {
partition_id: BlockOutputBuffer(
output_block_size_option=OutputBlockSizeOption(
target_max_block_size=data_context.target_max_block_size
)
)
for partition_id in target_partition_ids
}

def submit(self, input_seq_id: int, partition_id: int, partition_shard: Block):
with self._lock:
Expand All @@ -1560,17 +1576,29 @@ def submit(self, input_seq_id: int, partition_id: int, partition_shard: Block):
def finalize(
    self, partition_id: int
) -> AsyncGenerator[Union[Block, "BlockMetadataWithSchema"], None]:
    """Finalize aggregation for ``partition_id`` and stream out its blocks.

    The stateful aggregation's ``finalize``/``clear`` run under the
    aggregator lock. The single (potentially large) finalized block is then
    re-chunked to the configured target block size through this partition's
    dedicated ``BlockOutputBuffer``; that step needs no locking because each
    partition owns its own buffer, enabling concurrent finalization.

    NOTE(review): despite the annotation this is a regular (synchronous)
    generator, not an async one — the annotation should eventually become
    ``Generator[Union[Block, "BlockMetadataWithSchema"], None, None]`` once
    ``Generator`` is available in the file's imports.

    Args:
        partition_id: Id of the partition to finalize; must be one of the
            ``target_partition_ids`` this aggregator was constructed with.

    Yields:
        For every output block produced: the block itself, immediately
        followed by its ``BlockMetadataWithSchema`` (carrying the exec
        stats measured for this finalize call).
    """
    # Start timing before taking the lock so stats cover the full finalize.
    exec_stats_builder = BlockExecStats.builder()

    with self._lock:
        # Finalize given partition id.
        block = self._agg.finalize(partition_id)
        exec_stats = exec_stats_builder.build()
        # Clear any remaining state (to release resources).
        self._agg.clear(partition_id)

    # No lock needed from here on — each partition has its own buffer.
    output_buffer = self._output_buffers[partition_id]

    # Feed the finalized block through the buffer so oversized blocks are
    # split down to the configured target max block size.
    output_buffer.add_block(block)
    while output_buffer.has_next():
        out_block = output_buffer.next()
        yield out_block
        yield BlockMetadataWithSchema.from_block(out_block, stats=exec_stats)

    # Flush any remainder smaller than the target size.
    output_buffer.finalize()
    while output_buffer.has_next():
        out_block = output_buffer.next()
        yield out_block
        yield BlockMetadataWithSchema.from_block(out_block, stats=exec_stats)


def _get_total_cluster_resources() -> ExecutionResources:
Expand Down