Skip to content

Commit df3dd96

Browse files
authored
[data][train] Fix deadlocks caused by streaming_split (#42601)
Fix a deadlock issue for training jobs. The issue happens in the following situation: * The output blocks of `streaming_split` are assigned to multiple splits (`output_split_idx`). * When one split has finished reading all blocks, it won't stop the iteration until all the other splits have also finished, because of [this](https://github.com/ray-project/ray/blob/fae8d2ff814377eb027d63d73a23d5c5bf3b02bd/python/ray/data/_internal/execution/streaming_executor_state.py#L288). * This is usually fine. But when the unfinished splits are waiting for the finished splits (e.g., there is a gradient synchronization), there will be a deadlock due to circular dependencies. This PR lets the finished splits finish iteration immediately without waiting for the others. --------- Signed-off-by: Hao Chen <chenh1024@gmail.com>
1 parent f6da38f commit df3dd96

File tree

2 files changed

+136
-51
lines changed

2 files changed

+136
-51
lines changed

python/ray/data/_internal/execution/streaming_executor_state.py

Lines changed: 71 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -106,62 +106,94 @@ class DownstreamMemoryInfo:
106106
object_store_memory: float
107107

108108

109-
class RefBundleDeque(deque):
110-
"""Thread-safe wrapper around collections.deque that stores current stats."""
109+
class OpBufferQueue:
110+
"""A FIFO queue to buffer RefBundles between upstream and downstream operators.
111+
This class is thread-safe.
112+
"""
111113

112114
def __init__(self):
113115
self._memory_usage = 0
114116
self._num_blocks = 0
117+
self._queue = deque()
118+
self._num_per_split = defaultdict(int)
115119
self._lock = threading.Lock()
116120
super().__init__()
117121

118122
@property
119123
def memory_usage(self) -> int:
124+
"""The total memory usage of the queue in bytes."""
120125
with self._lock:
121126
return self._memory_usage
122127

123128
@property
124129
def num_blocks(self) -> int:
130+
"""The total number of blocks in the queue."""
125131
with self._lock:
126132
return self._num_blocks
127133

128-
def append(self, ref: RefBundle):
129-
with self._lock:
130-
self._memory_usage += ref.size_bytes()
131-
self._num_blocks += len(ref.blocks)
132-
super().append(ref)
134+
def __len__(self):
135+
return len(self._queue)
133136

134-
def appendleft(self, ref: RefBundle):
135-
with self._lock:
136-
self._memory_usage += ref.size_bytes()
137-
self._num_blocks += len(ref.blocks)
138-
super().appendleft(ref)
137+
def has_next(self, output_split_idx: Optional[int] = None) -> bool:
138+
"""Whether next RefBundle is available.
139139
140-
def pop(self) -> RefBundle:
141-
ref = super().pop()
142-
with self._lock:
143-
self._memory_usage -= ref.size_bytes()
144-
self._num_blocks -= len(ref.blocks)
145-
return ref
140+
Args:
141+
output_split_idx: If specified, only check ref bundles with the
142+
given output split.
143+
"""
144+
if output_split_idx is None:
145+
return len(self._queue) > 0
146+
else:
147+
with self._lock:
148+
return self._num_per_split[output_split_idx] > 0
146149

147-
def popleft(self) -> RefBundle:
148-
ref = super().popleft()
150+
def append(self, ref: RefBundle):
151+
"""Append a RefBundle to the queue."""
152+
self._queue.append(ref)
149153
with self._lock:
150-
self._memory_usage -= ref.size_bytes()
151-
self._num_blocks -= len(ref.blocks)
152-
return ref
153-
154-
def remove(self, ref: RefBundle):
155-
super().remove(ref)
154+
self._memory_usage += ref.size_bytes()
155+
self._num_blocks += len(ref.blocks)
156+
if ref.output_split_idx is not None:
157+
self._num_per_split[ref.output_split_idx] += 1
158+
159+
def pop(self, output_split_idx: Optional[int] = None) -> Optional[RefBundle]:
160+
"""Pop a RefBundle from the queue.
161+
Args:
162+
output_split_idx: If specified, only pop a RefBundle
163+
with the given output split.
164+
Returns:
165+
A RefBundle if available, otherwise None.
166+
"""
167+
ret = None
168+
if output_split_idx is None:
169+
try:
170+
ret = self._queue.popleft()
171+
except IndexError:
172+
pass
173+
else:
174+
# TODO(hchen): Index the queue by output_split_idx to
175+
# avoid linear scan.
176+
for i in range(len(self._queue)):
177+
ref = self._queue[i]
178+
if ref.output_split_idx == output_split_idx:
179+
ret = ref
180+
del self._queue[i]
181+
break
182+
if ret is None:
183+
return None
156184
with self._lock:
157-
self._memory_usage -= ref.size_bytes()
158-
self._num_blocks -= len(ref.blocks)
185+
self._memory_usage -= ret.size_bytes()
186+
self._num_blocks -= len(ret.blocks)
187+
if ret.output_split_idx is not None:
188+
self._num_per_split[ret.output_split_idx] -= 1
189+
return ret
159190

160191
def clear(self):
161-
super().clear()
162192
with self._lock:
193+
self._queue.clear()
163194
self._memory_usage = 0
164195
self._num_blocks = 0
196+
self._num_per_split.clear()
165197

166198

167199
class OpState:
@@ -174,17 +206,17 @@ class OpState:
174206
operator queues to be shared across threads.
175207
"""
176208

177-
def __init__(self, op: PhysicalOperator, inqueues: List[RefBundleDeque]):
209+
def __init__(self, op: PhysicalOperator, inqueues: List[OpBufferQueue]):
178210
# Each inqueue is connected to another operator's outqueue.
179211
assert len(inqueues) == len(op.input_dependencies), (op, inqueues)
180-
self.inqueues: List[RefBundleDeque] = inqueues
212+
self.inqueues: List[OpBufferQueue] = inqueues
181213
# The outqueue is connected to another operator's inqueue (they physically
182214
# share the same Python list reference).
183215
#
184216
# Note: this queue is also accessed concurrently from the consumer thread.
185217
# (in addition to the streaming executor thread). Hence, it must be a
186218
# thread-safe type such as `deque`.
187-
self.outqueue: RefBundleDeque = RefBundleDeque()
219+
self.outqueue: OpBufferQueue = OpBufferQueue()
188220
self.op = op
189221
self.progress_bar = None
190222
self.num_completed_tasks = 0
@@ -266,8 +298,9 @@ def summary_str(self) -> str:
266298
def dispatch_next_task(self) -> None:
267299
"""Move a bundle from the operator inqueue to the operator itself."""
268300
for i, inqueue in enumerate(self.inqueues):
269-
if inqueue:
270-
self.op.add_input(inqueue.popleft(), input_index=i)
301+
ref = inqueue.pop()
302+
if ref is not None:
303+
self.op.add_input(ref, input_index=i)
271304
return
272305
assert False, "Nothing to dispatch"
273306

@@ -285,24 +318,11 @@ def get_output_blocking(self, output_split_idx: Optional[int]) -> RefBundle:
285318
# Check if StreamingExecutor has caught an exception or is done execution.
286319
if self._exception is not None:
287320
raise self._exception
288-
elif self._finished and len(self.outqueue) == 0:
321+
elif self._finished and not self.outqueue.has_next(output_split_idx):
289322
raise StopIteration()
290-
try:
291-
# Non-split output case.
292-
if output_split_idx is None:
293-
return self.outqueue.popleft()
294-
295-
# Scan the queue and look for outputs tagged for the given index.
296-
for i in range(len(self.outqueue)):
297-
bundle = self.outqueue[i]
298-
if bundle.output_split_idx == output_split_idx:
299-
self.outqueue.remove(bundle)
300-
return bundle
301-
302-
# Didn't find any outputs matching this index, repeat the loop until
303-
# we find one or hit a None.
304-
except IndexError:
305-
pass
323+
ref = self.outqueue.pop(output_split_idx)
324+
if ref is not None:
325+
return ref
306326
time.sleep(0.01)
307327

308328
def inqueue_memory_usage(self) -> int:

python/ray/data/tests/test_streaming_integration.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,71 @@ def consume(x, times):
276276
)
277277

278278

279+
def test_streaming_split_independent_finish(ray_start_10_cpus_shared):
280+
"""Test that stream_split iterators can finish independently without
281+
waiting for other iterators to finish. Otherwise, this would cause
282+
deadlocks.
283+
"""
284+
num_blocks_per_split = 10
285+
num_splits = 2
286+
ds = ray.data.range(
287+
num_splits * num_blocks_per_split,
288+
parallelism=num_splits * num_blocks_per_split,
289+
)
290+
(
291+
i1,
292+
i2,
293+
) = ds.streaming_split(num_splits, equal=True)
294+
295+
@ray.remote(max_concurrency=2)
296+
class SignalActor:
297+
def __init__(self):
298+
self._event = threading.Event()
299+
300+
def wait(self):
301+
self._event.wait()
302+
303+
def set(self):
304+
self._event.set()
305+
306+
@ray.remote
307+
class Consumer:
308+
def consume(self, it, signal_actor, split_index):
309+
for i, _ in enumerate(it.iter_batches(batch_size=None, prefetch_batches=0)):
310+
if i == num_blocks_per_split // 2 and split_index == 0:
311+
# The first consumer waits for the second consumer to
312+
# finish first in the middle of the iteration.
313+
print("before wait")
314+
ray.get(signal_actor.wait.remote())
315+
print("after wait")
316+
if split_index == 1:
317+
# The second consumer sends a signal to unblock the
318+
# first consumer. It should finish the iteration independently.
319+
# Otherwise, there will be a deadlock.
320+
print("before set")
321+
# Sleep some time to make sure the other
322+
# consume calls wait first.
323+
time.sleep(2)
324+
ray.get(signal_actor.set.remote())
325+
print("after set")
326+
pass
327+
328+
signal_actor = SignalActor.remote()
329+
consumer1 = Consumer.remote()
330+
consumer2 = Consumer.remote()
331+
332+
ready, _ = ray.wait(
333+
[
334+
consumer1.consume.remote(i1, signal_actor, 0),
335+
consumer2.consume.remote(i2, signal_actor, 1),
336+
],
337+
num_returns=2,
338+
timeout=20,
339+
)
340+
341+
assert len(ready) == 2
342+
343+
279344
@pytest.mark.skip(
280345
reason="Incomplete implementation of _validate_dag causes other errors, so we "
281346
"remove DAG validation for now; see https://github.com/ray-project/ray/pull/37829"

0 commit comments

Comments
 (0)