Skip to content

Commit dc90aea

Browse files
authored
Add lock in prefill and generate to prevent starvation (#126)
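Add a shared lock around prefill and generate (taken only when the engine is not disaggregated) so that the two calls are serialized, to prevent starvation.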
1 parent 97aaeae commit dc90aea

File tree

1 file changed

+37 -0 lines changed

1 file changed

+37 -0 lines changed

jetstream_pt/ray_engine.py

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from collections import defaultdict
2+
import threading
23
from typing import Any, Iterable, Optional, Union, Tuple, List
34

45
import numpy as np
@@ -38,6 +39,8 @@ def __init__(
3839
self.batch_size = batch_size
3940
self.is_disaggregated = is_disaggregated
4041
self.pod_slice_name = pod_slice_name
42+
if not self.is_disaggregated:
43+
self._lock = threading.Lock()
4144

4245
# pylint: disable-next=all
4346
def load_params(self) -> Params:
@@ -66,6 +69,31 @@ def prefill(
6669
existing_prefix: Optional[Prefix] = None,
6770
padded_tokens: np.ndarray, # PrefillInputs[np.ndarray],
6871
true_length: int,
72+
) -> Prefix:
73+
if self.is_disaggregated:
74+
return self.prefill_impl(
75+
params=params,
76+
existing_prefix=existing_prefix,
77+
padded_tokens=padded_tokens,
78+
true_length=true_length,
79+
)
80+
81+
with self._lock:
82+
return self.prefill_impl(
83+
params=params,
84+
existing_prefix=existing_prefix,
85+
padded_tokens=padded_tokens,
86+
true_length=true_length,
87+
)
88+
89+
# pylint: disable-next=all
90+
def prefill_impl(
91+
self,
92+
*,
93+
params: Any, # Weights
94+
existing_prefix: Optional[Prefix] = None,
95+
padded_tokens: np.ndarray, # PrefillInputs[np.ndarray],
96+
true_length: int,
6997
) -> Prefix:
7098
all_outputs = []
7199
for worker in self.engine_workers:
@@ -116,6 +144,15 @@ def insert(
116144

117145
def generate(
118146
self, params: Any, decode_state: DecodeState
147+
) -> tuple[None, engine_api.ResultTokens]:
148+
if self.is_disaggregated:
149+
return self.generate_impl(params=params, decode_state=decode_state)
150+
with self._lock:
151+
return self.generate_impl(params=params, decode_state=decode_state)
152+
153+
# pylint: disable-next=all
154+
def generate_impl(
155+
self, params: Any, decode_state: DecodeState
119156
) -> tuple[None, engine_api.ResultTokens]:
120157
all_outputs = []
121158
for worker in self.engine_workers:

0 commit comments

Comments
 (0)