|
1 | 1 | from collections import defaultdict
|
| 2 | +import threading |
2 | 3 | from typing import Any, Iterable, Optional, Union, Tuple, List
|
3 | 4 |
|
4 | 5 | import numpy as np
|
@@ -38,6 +39,8 @@ def __init__(
|
38 | 39 | self.batch_size = batch_size
|
39 | 40 | self.is_disaggregated = is_disaggregated
|
40 | 41 | self.pod_slice_name = pod_slice_name
|
| 42 | + if not self.is_disaggregated: |
| 43 | + self._lock = threading.Lock() |
41 | 44 |
|
42 | 45 | # pylint: disable-next=all
|
43 | 46 | def load_params(self) -> Params:
|
@@ -66,6 +69,31 @@ def prefill(
|
66 | 69 | existing_prefix: Optional[Prefix] = None,
|
67 | 70 | padded_tokens: np.ndarray, # PrefillInputs[np.ndarray],
|
68 | 71 | true_length: int,
|
| 72 | + ) -> Prefix: |
| 73 | + if self.is_disaggregated: |
| 74 | + return self.prefill_impl( |
| 75 | + params=params, |
| 76 | + existing_prefix=existing_prefix, |
| 77 | + padded_tokens=padded_tokens, |
| 78 | + true_length=true_length, |
| 79 | + ) |
| 80 | + |
| 81 | + with self._lock: |
| 82 | + return self.prefill_impl( |
| 83 | + params=params, |
| 84 | + existing_prefix=existing_prefix, |
| 85 | + padded_tokens=padded_tokens, |
| 86 | + true_length=true_length, |
| 87 | + ) |
| 88 | + |
| 89 | + # pylint: disable-next=all |
| 90 | + def prefill_impl( |
| 91 | + self, |
| 92 | + *, |
| 93 | + params: Any, # Weights |
| 94 | + existing_prefix: Optional[Prefix] = None, |
| 95 | + padded_tokens: np.ndarray, # PrefillInputs[np.ndarray], |
| 96 | + true_length: int, |
69 | 97 | ) -> Prefix:
|
70 | 98 | all_outputs = []
|
71 | 99 | for worker in self.engine_workers:
|
@@ -116,6 +144,15 @@ def insert(
|
116 | 144 |
|
117 | 145 | def generate(
|
118 | 146 | self, params: Any, decode_state: DecodeState
|
| 147 | + ) -> tuple[None, engine_api.ResultTokens]: |
| 148 | + if self.is_disaggregated: |
| 149 | + return self.generate_impl(params=params, decode_state=decode_state) |
| 150 | + with self._lock: |
| 151 | + return self.generate_impl(params=params, decode_state=decode_state) |
| 152 | + |
| 153 | + # pylint: disable-next=all |
| 154 | + def generate_impl( |
| 155 | + self, params: Any, decode_state: DecodeState |
119 | 156 | ) -> tuple[None, engine_api.ResultTokens]:
|
120 | 157 | all_outputs = []
|
121 | 158 | for worker in self.engine_workers:
|
|
0 commit comments