Commit 70788bd

[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE (#17211)
Signed-off-by: Bryan Lu <[email protected]>
1 parent c9c1b59 commit 70788bd

File tree: 6 files changed (+152, -53 lines)

examples/offline_inference/eagle.py

Lines changed: 12 additions & 2 deletions
@@ -36,6 +36,10 @@ def parse_args():
         help="downloaded from the eagle repo " \
             "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
     )
+    parser.add_argument("--method",
+                        type=str,
+                        default='eagle',
+                        choices=['eagle', 'eagle3'])
     parser.add_argument("--max_num_seqs", type=int, default=8)
     parser.add_argument("--num_prompts", type=int, default=80)
     parser.add_argument("--num_spec_tokens", type=int, default=2)

@@ -53,7 +57,13 @@ def main():
     args = parse_args()

     model_dir = "meta-llama/Llama-3.1-8B-Instruct"
-    eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
+
+    if args.method == 'eagle':
+        eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+    elif args.method == 'eagle3':
+        eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
+    else:
+        raise ValueError(f"unknown method: {args.method}")

     max_model_len = 2048

@@ -81,7 +91,7 @@ def main():
         max_num_seqs=args.max_num_seqs,
         gpu_memory_utilization=0.8,
         speculative_config={
-            "method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle",
+            "method": args.method,
             "model": eagle_dir,
             "num_speculative_tokens": args.num_spec_tokens,
             "draft_tensor_parallel_size": args.draft_tp,

vllm/compilation/backends.py

Lines changed: 12 additions & 3 deletions
@@ -347,8 +347,12 @@ def configure_post_pass(self):
         PASS_KEY = "post_grad_custom_post_pass"
         if PASS_KEY in inductor_config:
             # Config should automatically wrap all inductor passes
-            assert isinstance(inductor_config[PASS_KEY], InductorPass)
-            self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
+            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
+                assert (inductor_config[PASS_KEY].uuid() ==
+                        self.post_grad_pass_manager.uuid())
+            else:
+                assert isinstance(inductor_config[PASS_KEY], InductorPass)
+                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
         inductor_config[PASS_KEY] = self.post_grad_pass_manager

     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

@@ -408,8 +412,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             )
             self.compilation_config.cache_dir = cache_dir

-        cache_dir = self.compilation_config.cache_dir
+        if compilation_counter.num_graphs_seen > 0:
+            cache_dir = self.compilation_config.cache_dir + \
+                f'-{compilation_counter.num_graphs_seen}'
+        else:
+            cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
+        self.compilation_config.cache_dir = cache_dir
         rank = vllm_config.parallel_config.rank
         dp_rank = vllm_config.parallel_config.data_parallel_rank
         local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
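
The cache-directory change exists because this commit compiles a second graph in the same process (the EAGLE draft model alongside the target model), and both would otherwise write their inductor artifacts into the same directory. A rough sketch of the naming behaviour, with a plain integer standing in for vLLM's compilation_counter.num_graphs_seen and an illustrative base path:

    def pick_cache_dir(base_dir: str, num_graphs_seen: int) -> str:
        # The first compiled graph keeps the configured directory; every
        # later graph (e.g. the EAGLE draft model) gets a "-<n>" suffix so
        # its artifacts do not collide with the target model's cache.
        if num_graphs_seen > 0:
            return f"{base_dir}-{num_graphs_seen}"
        return base_dir

    assert pick_cache_dir("/tmp/compile_cache/abc", 0) == "/tmp/compile_cache/abc"
    assert pick_cache_dir("/tmp/compile_cache/abc", 1) == "/tmp/compile_cache/abc-1"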

vllm/model_executor/models/llama_eagle.py

Lines changed: 14 additions & 11 deletions
@@ -6,7 +6,8 @@
 import torch.nn as nn
 from transformers import LlamaConfig

-from vllm.config import ModelConfig
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import (

@@ -37,17 +38,19 @@ def __init__(
         self.input_layernorm = nn.Identity()


+@support_torch_compile
 class LlamaModel(nn.Module):

     def __init__(
         self,
         *,
-        model_config: ModelConfig,
-        start_layer_id: int = 0,
+        vllm_config: VllmConfig,
         prefix: str = "",
+        start_layer_id: int = 0,
     ) -> None:
         super().__init__()
-        self.config = model_config.hf_config
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,

@@ -75,8 +78,7 @@ def forward(
         hidden_states = self.fc(
             torch.cat((input_embeds, hidden_states), dim=-1))
         residual = None
-        for i in range(len(self.layers)):
-            layer = self.layers[i]
+        for layer in self.layers:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,

@@ -117,12 +119,13 @@ def load_weights(self, weights: Iterable[Tuple[str,

 class EagleLlamaForCausalLM(LlamaForCausalLM):

-    def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
         nn.Module.__init__(self)
-        self.config = model_config.hf_config
-        self.model = LlamaModel(model_config=model_config,
-                                start_layer_id=start_layer_id,
-                                prefix="model")
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
+        self.model = LlamaModel(vllm_config=vllm_config,
+                                prefix="model",
+                                start_layer_id=start_layer_id)

         logit_scale = getattr(self.config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.config.vocab_size,
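
The signature changes above follow from the support_torch_compile class decorator: it wraps the decorated module's forward with torch.compile and expects the class to be constructed with vllm_config and prefix keyword arguments, which is why LlamaModel (and its caller) now receive a VllmConfig instead of a bare ModelConfig. A toy illustration of the expected shape; the class and layer below are made up for this sketch and are not part of vLLM:

    import torch
    import torch.nn as nn

    from vllm.compilation.decorators import support_torch_compile
    from vllm.config import VllmConfig


    @support_torch_compile
    class ToyDraftModel(nn.Module):
        # Hypothetical module; only the constructor/forward shape matters here.

        def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            super().__init__()
            hidden_size = vllm_config.model_config.get_hidden_size()
            self.proj = nn.Linear(hidden_size, hidden_size)

        def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                    hidden_states: torch.Tensor) -> torch.Tensor:
            # The decorator compiles this forward when the compilation level
            # allows it; eager execution is unchanged otherwise.
            return self.proj(hidden_states)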

vllm/model_executor/models/llama_eagle3.py

Lines changed: 3 additions & 2 deletions
@@ -6,7 +6,7 @@
 import torch.nn as nn
 from transformers import LlamaConfig

-from vllm.config import ModelConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear

@@ -167,8 +167,9 @@ def load_weights(self, weights: Iterable[Tuple[str,

 class Eagle3LlamaForCausalLM(LlamaForCausalLM):

-    def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
         nn.Module.__init__(self)
+        model_config = vllm_config.speculative_config.draft_model_config
         self.config = model_config.hf_config
         self.model = LlamaModel(model_config=model_config,
                                 start_layer_id=start_layer_id,

vllm/v1/spec_decode/eagle.py

Lines changed: 100 additions & 22 deletions
@@ -4,7 +4,7 @@
 import triton
 import triton.language as tl

-from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.loader import get_model_loader

@@ -26,10 +26,41 @@ def __init__(
         device: torch.device,
     ):
         self.vllm_config = vllm_config
+        self.method = self.vllm_config.speculative_config.method
         self.num_speculative_tokens = (
             vllm_config.speculative_config.num_speculative_tokens)
         self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
+
+        self.dtype = vllm_config.model_config.dtype
+
+        self.max_num_tokens = vllm_config.scheduler_config \
+            .max_num_batched_tokens
+
+        self.hidden_size = vllm_config.model_config.get_hidden_size()
+
+        # TODO: make eagle3 compatible with cudagraph
+        self.use_cuda_graph = self.method != 'eagle3' and \
+            (self.vllm_config.compilation_config.level
+             == CompilationLevel.PIECEWISE and
+             not self.vllm_config.model_config.enforce_eager)
+
+        self.cudagraph_batch_sizes = list(
+            reversed(
+                self.vllm_config.compilation_config.cudagraph_capture_sizes))
+
+        # persistent buffers for cuda graph
+        self.input_ids = torch.zeros(self.max_num_tokens,
+                                     dtype=torch.int32,
+                                     device=device)
+        self.positions = torch.zeros(self.max_num_tokens,
+                                     dtype=torch.int64,
+                                     device=device)
+
+        self.hidden_states = torch.zeros(
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=device)
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
         self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -59,13 +90,12 @@ def propose(
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1

-        input_ids = torch.empty_like(target_token_ids)
         # Shift the input ids by one token.
         # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
-        input_ids[:-1] = target_token_ids[1:]
+        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
         # Replace the last token with the next token.
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
-        input_ids[last_token_indices] = next_token_ids
+        self.input_ids[last_token_indices] = next_token_ids

         # FA requires seq_len to have dtype int32.
         seq_lens = (target_positions[last_token_indices] + 1).int()

@@ -88,14 +118,30 @@ def propose(
             prefix_kv_lens=None,
             suffix_kv_lens=None,
         )
+        if self.use_cuda_graph and \
+                num_tokens <= self.cudagraph_batch_sizes[-1]:
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+        else:
+            num_input_tokens = num_tokens
+        # copy inputs to buffer for cudagraph
+        self.positions[:num_tokens] = target_positions

-        with set_forward_context(attn_metadata, self.vllm_config):
-            hidden_states_logits, hidden_states_fwd = self.model(
-                input_ids=input_ids,
-                hidden_states=target_hidden_states,
-                positions=target_positions,
+        if self.method == 'eagle':
+            self.hidden_states[:num_tokens] = target_hidden_states
+            hidden_states = self.hidden_states
+        else:
+            # TODO: make eagle3 compatible with cuda graph
+            hidden_states = target_hidden_states
+
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 num_tokens=num_input_tokens):
+            last_hidden_states, hidden_states = self.model(
+                input_ids=self.input_ids[:num_input_tokens],
+                positions=self.positions[:num_input_tokens],
+                hidden_states=hidden_states[:num_input_tokens],
             )
-        sample_hidden_states = hidden_states_logits[last_token_indices]
+        sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
         draft_token_ids = logits.argmax(dim=-1)

@@ -108,13 +154,20 @@ def propose(
         draft_token_ids_list = [draft_token_ids]

         positions = target_positions[last_token_indices]
-        hidden_states = hidden_states_fwd[last_token_indices]
+        hidden_states = hidden_states[last_token_indices]
+        if self.use_cuda_graph and \
+                batch_size <= self.cudagraph_batch_sizes[-1]:
+            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
+        else:
+            input_batch_size = batch_size
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
         for _ in range(self.num_speculative_tokens - 1):
             # Update the inputs.
-            input_ids = draft_token_ids_list[-1]
+            # cast to int32 is crucial when eagle model is compiled.
+            # tensor.argmax() returns int64 by default.
+            input_ids = draft_token_ids_list[-1].int()
             positions += 1

             # NOTE(woosuk): We should handle the case where the draft model

@@ -152,14 +205,27 @@ def propose(
             attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                     PADDING_SLOT_ID)

+            # copy inputs to buffer for cudagraph
+            self.input_ids[:batch_size] = input_ids
+            self.positions[:batch_size] = clamped_positions
+
+            if self.method == 'eagle':
+                # TODO: make eagle3 compatible with cudagraph.
+                self.hidden_states[:batch_size] = hidden_states
+                hidden_states = self.hidden_states
+
             # Run the model.
-            with set_forward_context(attn_metadata, self.vllm_config):
-                hidden_states_logits, hidden_states = self.model(
-                    input_ids=input_ids,
-                    hidden_states=hidden_states,
-                    positions=clamped_positions,
+            with set_forward_context(attn_metadata,
+                                     self.vllm_config,
+                                     num_tokens=input_batch_size):
+                last_hidden_states, hidden_states = self.model(
+                    input_ids=self.input_ids[:input_batch_size],
+                    positions=self.positions[:input_batch_size],
+                    hidden_states=hidden_states[:input_batch_size],
                 )
-            logits = self.model.compute_logits(hidden_states_logits, None)
+            hidden_states = hidden_states[:batch_size]
+            logits = self.model.compute_logits(last_hidden_states[:batch_size],
+                                               None)
             draft_token_ids = logits.argmax(dim=-1)
             draft_token_ids_list.append(draft_token_ids)

@@ -227,13 +293,11 @@ def load_model(self, target_model: nn.Module) -> None:
         draft_model_cls, arch = ModelRegistry.resolve_model_cls(
             draft_model_config.architectures)
         self.model = draft_model_cls(
-            model_config=draft_model_config,
+            vllm_config=self.vllm_config,
             start_layer_id=target_layer_num).to(target_device)

         loaded_weights = self.model.load_weights(
-            loader.get_all_weights(
-                self.vllm_config.speculative_config.draft_model_config,
-                self.model))
+            loader.get_all_weights(draft_model_config, self.model))
         if self.vllm_config.speculative_config.method == "eagle3":
             if "model.embed_tokens.weight" not in loaded_weights:
                 logger.info(

@@ -243,6 +307,20 @@ def load_model(self, target_model: nn.Module) -> None:
             logger.info("Loading EAGLE LM head weights from the target model.")
         self.model.lm_head = target_model.lm_head

+    @torch.inference_mode()
+    def dummy_run(
+        self,
+        num_tokens: int,
+    ) -> None:
+        with set_forward_context(None, self.vllm_config,
+                                 num_tokens=num_tokens):
+            if self.method == 'eagle':
+                self.model(
+                    input_ids=self.input_ids[:num_tokens],
+                    positions=self.positions[:num_tokens],
+                    hidden_states=self.hidden_states[:num_tokens],
+                )
+

 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
 # to sample the draft tokens. We will use this after we find a way to manage
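
The new dummy_run hook gives the model runner a way to trace and capture the draft model ahead of time on the persistent buffers; the sixth changed file, not shown on this page, presumably wires it into the runner's warmup. A hedged sketch of how such a warmup loop could drive it; drafter and capture_sizes are illustrative names for this sketch, not vLLM APIs:

    def warm_up_drafter(drafter, capture_sizes):
        # Run the draft model once per cudagraph batch size, largest first,
        # so torch.compile specializes and cudagraphs can be captured before
        # serving real requests.
        for num_tokens in sorted(capture_sizes, reverse=True):
            drafter.dummy_run(num_tokens)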
