import triton
import triton.language as tl

- from vllm.config import VllmConfig, set_current_vllm_config
+ from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader.loader import get_model_loader
@@ -26,10 +26,41 @@ def __init__(
        device: torch.device,
    ):
        self.vllm_config = vllm_config
+         self.method = self.vllm_config.speculative_config.method
        self.num_speculative_tokens = (
            vllm_config.speculative_config.num_speculative_tokens)
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
+
+         self.dtype = vllm_config.model_config.dtype
+
+         self.max_num_tokens = vllm_config.scheduler_config \
+             .max_num_batched_tokens
+
+         self.hidden_size = vllm_config.model_config.get_hidden_size()
+
+         # TODO: make eagle3 compatible with cudagraph
+         self.use_cuda_graph = self.method != 'eagle3' and \
+             (self.vllm_config.compilation_config.level
+              == CompilationLevel.PIECEWISE and
+              not self.vllm_config.model_config.enforce_eager)
+
+         self.cudagraph_batch_sizes = list(
+             reversed(
+                 self.vllm_config.compilation_config.cudagraph_capture_sizes))
+
+         # persistent buffers for cuda graph
+         self.input_ids = torch.zeros(self.max_num_tokens,
+                                      dtype=torch.int32,
+                                      device=device)
+         self.positions = torch.zeros(self.max_num_tokens,
+                                      dtype=torch.int64,
+                                      device=device)
+
+         self.hidden_states = torch.zeros(
+             (self.max_num_tokens, self.hidden_size),
+             dtype=self.dtype,
+             device=device)
        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
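The buffers added above exist because CUDA graphs replay fixed shapes: inputs are staged into preallocated tensors of size max_num_tokens, and each batch is later padded up to the nearest captured size before replay. Below is a minimal sketch of the rounding rule that the pad_for_cudagraph calls in this diff are relied on to perform (hypothetical helper and assumed behavior, not vLLM's implementation):

import bisect

def pad_for_cudagraph_sketch(num_tokens: int, capture_sizes: list[int]) -> int:
    """Round num_tokens up to the smallest captured batch size.

    Assumes num_tokens <= max(capture_sizes), matching the guard in the diff.
    """
    sizes = sorted(capture_sizes)
    return sizes[bisect.bisect_left(sizes, num_tokens)]

# With capture sizes [1, 2, 4, 8, 16], a 5-token batch replays the graph
# captured for 8 tokens; the trailing 3 buffer slots are padding.
assert pad_for_cudagraph_sketch(5, [1, 2, 4, 8, 16]) == 8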
@@ -59,13 +90,12 @@ def propose(
        batch_size = next_token_ids.shape[0]
        last_token_indices = cu_num_tokens[1:] - 1

-         input_ids = torch.empty_like(target_token_ids)
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
-         input_ids[:-1] = target_token_ids[1:]
+         self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
-         input_ids[last_token_indices] = next_token_ids
+         self.input_ids[last_token_indices] = next_token_ids

        # FA requires seq_len to have dtype int32.
        seq_lens = (target_positions[last_token_indices] + 1).int()
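For reference, here is a standalone run of the shift-and-replace step above, written against a persistent buffer like self.input_ids (toy token ids, not from the source):

import torch

# Flattened batch: request a has 1 token, b has 2, c has 3.
target_token_ids = torch.tensor([101, 201, 202, 301, 302, 303])
next_token_ids = torch.tensor([102, 203, 304])      # one sampled token per request
cu_num_tokens = torch.tensor([0, 1, 3, 6])
last_token_indices = cu_num_tokens[1:] - 1          # tensor([0, 2, 5])

num_tokens = target_token_ids.shape[0]
input_ids = torch.zeros(8, dtype=torch.int64)       # stands in for self.input_ids
input_ids[:num_tokens - 1] = target_token_ids[1:]   # shift left by one
input_ids[last_token_indices] = next_token_ids      # splice in the sampled tokens
print(input_ids[:num_tokens])                       # tensor([102, 202, 203, 302, 303, 304])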
@@ -88,14 +118,30 @@ def propose(
            prefix_kv_lens=None,
            suffix_kv_lens=None,
        )
+         if self.use_cuda_graph and \
+                 num_tokens <= self.cudagraph_batch_sizes[-1]:
+             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+         else:
+             num_input_tokens = num_tokens
+         # copy inputs to buffer for cudagraph
+         self.positions[:num_tokens] = target_positions

-         with set_forward_context(attn_metadata, self.vllm_config):
-             hidden_states_logits, hidden_states_fwd = self.model(
-                 input_ids=input_ids,
-                 hidden_states=target_hidden_states,
-                 positions=target_positions,
+         if self.method == 'eagle':
+             self.hidden_states[:num_tokens] = target_hidden_states
+             hidden_states = self.hidden_states
+         else:
+             # TODO: make eagle3 compatible with cuda graph
+             hidden_states = target_hidden_states
+
+         with set_forward_context(attn_metadata,
+                                  self.vllm_config,
+                                  num_tokens=num_input_tokens):
+             last_hidden_states, hidden_states = self.model(
+                 input_ids=self.input_ids[:num_input_tokens],
+                 positions=self.positions[:num_input_tokens],
+                 hidden_states=hidden_states[:num_input_tokens],
            )
-         sample_hidden_states = hidden_states_logits[last_token_indices]
+         sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
        draft_token_ids = logits.argmax(dim=-1)

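The forward pass above runs on the padded prefix of the persistent buffers, but only rows at last_token_indices (always below num_tokens) reach compute_logits; the attention metadata is what keeps the padded rows inert in the real model. A toy version of that pad-then-gather pattern, with an elementwise stand-in for the draft model (not vLLM code):

import torch

hidden = 4
num_tokens, num_input_tokens = 5, 8                 # 3 rows of CUDA-graph padding
buffer = torch.zeros(16, hidden)                    # persistent hidden_states buffer
buffer[:num_tokens] = torch.randn(num_tokens, hidden)
last_token_indices = torch.tensor([1, 4])           # by construction < num_tokens

def fake_model(x):                                  # stand-in per-token map
    return x * 2.0

out_a = fake_model(buffer[:num_input_tokens])[last_token_indices]
buffer[num_tokens:num_input_tokens] = 123.0         # scribble on the padded tail
out_b = fake_model(buffer[:num_input_tokens])[last_token_indices]
assert torch.equal(out_a, out_b)                    # padding never reaches sampling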
@@ -108,13 +154,20 @@ def propose(
        draft_token_ids_list = [draft_token_ids]

        positions = target_positions[last_token_indices]
-         hidden_states = hidden_states_fwd[last_token_indices]
+         hidden_states = hidden_states[last_token_indices]
+         if self.use_cuda_graph and \
+                 batch_size <= self.cudagraph_batch_sizes[-1]:
+             input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
+         else:
+             input_batch_size = batch_size
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        attn_metadata.query_start_loc = self.arange[:batch_size + 1]
        for _ in range(self.num_speculative_tokens - 1):
            # Update the inputs.
-             input_ids = draft_token_ids_list[-1]
+             # cast to int32 is crucial when eagle model is compiled.
+             # tensor.argmax() returns int64 by default.
+             input_ids = draft_token_ids_list[-1].int()
            positions += 1

            # NOTE(woosuk): We should handle the case where the draft model
@@ -152,14 +205,27 @@ def propose(
            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                    PADDING_SLOT_ID)

+             # copy inputs to buffer for cudagraph
+             self.input_ids[:batch_size] = input_ids
+             self.positions[:batch_size] = clamped_positions
+
+             if self.method == 'eagle':
+                 # TODO: make eagle3 compatible with cudagraph.
+                 self.hidden_states[:batch_size] = hidden_states
+                 hidden_states = self.hidden_states
+
            # Run the model.
-             with set_forward_context(attn_metadata, self.vllm_config):
-                 hidden_states_logits, hidden_states = self.model(
-                     input_ids=input_ids,
-                     hidden_states=hidden_states,
-                     positions=clamped_positions,
+             with set_forward_context(attn_metadata,
+                                      self.vllm_config,
+                                      num_tokens=input_batch_size):
+                 last_hidden_states, hidden_states = self.model(
+                     input_ids=self.input_ids[:input_batch_size],
+                     positions=self.positions[:input_batch_size],
+                     hidden_states=hidden_states[:input_batch_size],
                )
-             logits = self.model.compute_logits(hidden_states_logits, None)
+             hidden_states = hidden_states[:batch_size]
+             logits = self.model.compute_logits(last_hidden_states[:batch_size],
+                                                None)
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

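As the in-loop comment notes, the persistent input_ids buffer is int32 while argmax yields int64, hence the explicit .int() before copying draft tokens back in when the model is compiled. A quick dtype check, independent of vLLM:

import torch

logits = torch.randn(4, 32000)
draft_token_ids = logits.argmax(dim=-1)
assert draft_token_ids.dtype == torch.int64          # argmax always returns int64

input_ids_buf = torch.zeros(8, dtype=torch.int32)    # mirrors the persistent buffer
input_ids_buf[:4] = draft_token_ids.int()            # explicit down-cast before copy
assert input_ids_buf.dtype == torch.int32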
@@ -227,13 +293,11 @@ def load_model(self, target_model: nn.Module) -> None:
        draft_model_cls, arch = ModelRegistry.resolve_model_cls(
            draft_model_config.architectures)
        self.model = draft_model_cls(
-             model_config=draft_model_config,
+             vllm_config=self.vllm_config,
            start_layer_id=target_layer_num).to(target_device)

        loaded_weights = self.model.load_weights(
-             loader.get_all_weights(
-                 self.vllm_config.speculative_config.draft_model_config,
-                 self.model))
+             loader.get_all_weights(draft_model_config, self.model))
        if self.vllm_config.speculative_config.method == "eagle3":
            if "model.embed_tokens.weight" not in loaded_weights:
                logger.info(
@@ -243,6 +307,20 @@ def load_model(self, target_model: nn.Module) -> None:
            logger.info("Loading EAGLE LM head weights from the target model.")
            self.model.lm_head = target_model.lm_head

+     @torch.inference_mode()
+     def dummy_run(
+         self,
+         num_tokens: int,
+     ) -> None:
+         with set_forward_context(None, self.vllm_config,
+                                  num_tokens=num_tokens):
+             if self.method == 'eagle':
+                 self.model(
+                     input_ids=self.input_ids[:num_tokens],
+                     positions=self.positions[:num_tokens],
+                     hidden_states=self.hidden_states[:num_tokens],
+                 )
+

# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
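The new dummy_run gives the runner something to invoke while warming up or capturing CUDA graphs for each batch size. A hedged sketch of how a caller might drive it (hypothetical driver code; the real capture loop lives in the GPU model runner and is not part of this diff):

# Hypothetical helper: `proposer` is an instance of the proposer modified above.
def capture_draft_graphs(proposer) -> None:
    if not proposer.use_cuda_graph:
        return
    # Capture the largest size first so smaller replays can reuse the pools,
    # mirroring the usual capture order for the target model.
    for num_tokens in sorted(proposer.cudagraph_batch_sizes, reverse=True):
        proposer.dummy_run(num_tokens)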