Skip to content

Add lock in prefill and generate to prevent starvation #126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions jetstream_pt/ray_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import defaultdict
import threading
from typing import Any, Iterable, Optional, Union, Tuple, List

import numpy as np
Expand Down Expand Up @@ -38,6 +39,8 @@ def __init__(
self.batch_size = batch_size
self.is_disaggregated = is_disaggregated
self.pod_slice_name = pod_slice_name
if not self.is_disaggregated:
self._lock = threading.Lock()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use RLock. With a plain Lock, if the thread that already holds the lock tries to acquire it again, it will deadlock.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using Lock (without the R) should be the right way. JetStream may have had this issue before (the same thread calling prefill or decode concurrently), but there is no recursion in the prefill or decode code.

If a thread calls both prefill and decode, or calls the same method multiple times without releasing the lock, then that is a bug in JetStream.


# pylint: disable-next=all
def load_params(self) -> Params:
Expand Down Expand Up @@ -66,6 +69,31 @@ def prefill(
existing_prefix: Optional[Prefix] = None,
padded_tokens: np.ndarray, # PrefillInputs[np.ndarray],
true_length: int,
) -> Prefix:
if self.is_disaggregated:
return self.prefill_impl(
params=params,
existing_prefix=existing_prefix,
padded_tokens=padded_tokens,
true_length=true_length,
)

with self._lock:
return self.prefill_impl(
params=params,
existing_prefix=existing_prefix,
padded_tokens=padded_tokens,
true_length=true_length,
)

# pylint: disable-next=all
def prefill_impl(
self,
*,
params: Any, # Weights
existing_prefix: Optional[Prefix] = None,
padded_tokens: np.ndarray, # PrefillInputs[np.ndarray],
true_length: int,
) -> Prefix:
all_outputs = []
for worker in self.engine_workers:
Expand Down Expand Up @@ -116,6 +144,15 @@ def insert(

def generate(
    self, params: Any, decode_state: DecodeState
) -> tuple[None, engine_api.ResultTokens]:
  """Dispatch one generate (decode) step, serialized unless disaggregated.

  In interleaved mode prefill and generate share the same accelerator, so
  the instance lock is held to keep concurrent callers from starving one
  another; in disaggregated mode each stage owns its own hardware and the
  call goes straight through.
  """

  def _run_step() -> tuple[None, engine_api.ResultTokens]:
    return self.generate_impl(params=params, decode_state=decode_state)

  # Disaggregated deployments don't contend for the device — no lock needed.
  if self.is_disaggregated:
    return _run_step()
  # Shared-device (interleaved) mode: serialize with prefill via the lock.
  with self._lock:
    return _run_step()

# pylint: disable-next=all
def generate_impl(
self, params: Any, decode_state: DecodeState
) -> tuple[None, engine_api.ResultTokens]:
all_outputs = []
for worker in self.engine_workers:
Expand Down
Loading