Commit 0512f5a

Tighten compilation cache invariants around eagle
I'm recording my understanding of how eagle and the compilation cache work after discussing vllm-project#17211 with @luyuzhe111 and @WoosukKwon. In the future we will likely want to torch.compile multiple pieces of code (e.g. the decoder and the encoder separately), and then we'll need to refactor the system to support that (each compiled region needs its own cache directory with its own hash). Until then, the current design seems fine.

Signed-off-by: rzou <[email protected]>
1 parent 22481fb commit 0512f5a

File tree

1 file changed: 16 additions, 0 deletions

vllm/compilation/backends.py

Lines changed: 16 additions & 0 deletions
@@ -417,6 +417,22 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             self.compilation_config.cache_dir = cache_dir
 
         if compilation_counter.num_graphs_seen > 0:
+            # NOTE: Eagle compilation
+            # The eagle head is a separate model that gets run, so it needs
+            # its own cache dir (each cache dir is 1:1 with a model.forward).
+            #
+            # We currently assume that the eagle head does not need its own
+            # hash: in the vLLM repo, the hash of the original model currently
+            # entirely determines the config of the eagle head.
+            # It's very possible that this assumption will change in the
+            # future and we'll need to update this code.
+            #
+            # If you are here because you are using multiple torch.compile
+            # calls in a single model, please open an issue and let's discuss.
+            speculative_config = self.vllm_config.speculative_config
+            assert speculative_config is not None
+            assert speculative_config.method.use_eagle()
             cache_dir = self.compilation_config.cache_dir + \
                 f'-{compilation_counter.num_graphs_seen}'
         else:
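The naming scheme the diff relies on can be sketched as follows. This is a simplified illustration, not vLLM's actual code: the helper name `cache_dir_for_graph` is hypothetical, but it captures the invariant the commit documents, i.e. the first compiled graph (the target model) owns the base cache directory, and each additional compiled graph (such as the eagle head) gets its own suffixed directory so compiled artifacts never collide.

```python
# Hypothetical sketch of the per-graph cache-dir scheme described in the
# commit; the function name is invented for illustration.

def cache_dir_for_graph(base_cache_dir: str, num_graphs_seen: int) -> str:
    """Return the cache directory for the num_graphs_seen-th compiled graph.

    Graph 0 (the main model) uses the base directory unchanged; every
    later graph gets its own '-<index>'-suffixed directory.
    """
    if num_graphs_seen == 0:
        return base_cache_dir
    return f"{base_cache_dir}-{num_graphs_seen}"


print(cache_dir_for_graph("/tmp/vllm_cache/abc123", 0))  # /tmp/vllm_cache/abc123
print(cache_dir_for_graph("/tmp/vllm_cache/abc123", 1))  # /tmp/vllm_cache/abc123-1
```

Note that the suffix reuses the graph index rather than a new hash, which is exactly the assumption called out in the diff's comment: the eagle head's config is entirely determined by the original model's hash, so a distinct directory (not a distinct hash) is sufficient for now.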
