
Tighten compilation cache invariants around eagle #17662


Closed
wants to merge 1 commit into from

Conversation

zou3519
Collaborator

@zou3519 zou3519 commented May 5, 2025

I'm recording down my understanding of how eagle and the compilation cache works after discussing
#17211 with @luyuzhe111 and @WoosukKwon.

In the future we will likely want to torch.compile multiple pieces of code (e.g. the decoder and encoder separately), and then we'll need to refactor the system to support that (each compiled region needs its own cache directory with its own hash). But until then, the current design seems fine.
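For illustration, here is a minimal sketch of what per-region cache separation could look like after such a refactor. The helper make_cache_dir and its arguments are hypothetical and not part of vLLM's actual API:

import hashlib
import os

def make_cache_dir(base_dir: str, region_name: str, region_config: dict) -> str:
    # Hypothetical sketch: give each torch.compile'd region (e.g. decoder,
    # encoder, eagle head) its own cache directory keyed by a hash of that
    # region's own config, so changing one region only invalidates its cache.
    config_bytes = repr(sorted(region_config.items())).encode()
    region_hash = hashlib.sha256(config_bytes).hexdigest()[:16]
    cache_dir = os.path.join(base_dir, region_name, region_hash)
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir

# Usage sketch (config keys are illustrative only):
# decoder_dir = make_cache_dir("/tmp/vllm_cache", "decoder", {"hidden_size": 8192})
# eagle_dir = make_cache_dir("/tmp/vllm_cache", "eagle_head", {"hidden_size": 6144})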


github-actions bot commented May 5, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@zou3519 zou3519 requested a review from WoosukKwon May 5, 2025 15:11
@zou3519 zou3519 marked this pull request as ready for review May 5, 2025 15:12
@zou3519 zou3519 requested a review from houseroad May 5, 2025 15:12
# calls in a single model, please open an issue and let's discuss.
speculative_config = self.vllm_config.speculative_config
assert speculative_config is not None
assert speculative_config.method in ("eagle", "eagle3")
Collaborator

nit:

Suggested change
assert speculative_config.method in ("eagle", "eagle3")
assert speculative_config.use_eagle()

Comment on lines 421 to 423
# The eagle head does not need its own hash; we assume
# the hash of the original model entirely determines the config of
# the eagle head.
Collaborator

One small concern is that the eagle3 head often has a different hidden size from the original model.
For example, the hidden size of Llama 3.3 70B is 8192, while the hidden size of its eagle3 head (from the eagle3 author) is 6144 (https://huggingface.co/yuhuili/EAGLE3-LLaMA3.3-Instruct-70B).
So, technically, an eagle3 head can define its own hidden size.
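To make that concrete, here is a hedged sketch of how the draft head's own config could be folded into the cache hash if this assumption ever breaks. compute_cache_hash is hypothetical and not vLLM's actual hashing code:

import hashlib
from typing import Optional

def compute_cache_hash(target_model_hash: str,
                       draft_hidden_size: Optional[int] = None) -> str:
    # Hypothetical sketch: include the eagle head's own hidden size as an
    # extra hash factor so that, e.g., a Llama 3.3 70B target (8192) paired
    # with the public eagle3 head (6144) does not share a cache entry with
    # a differently-sized head trained elsewhere.
    factors = [target_model_hash]
    if draft_hidden_size is not None:
        factors.append(f"draft_hidden_size={draft_hidden_size}")
    return hashlib.sha256("|".join(factors).encode()).hexdigest()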

Collaborator

There's only one public eagle3 head per model, so this assumption works for those public heads. I'm a little bit concerned this might not be the case for internal models/heads.

Collaborator Author

What do you mean by "internal models/heads"? Internal to vLLM or Meta or something else?

Collaborator

@zou3519 Internal to Meta or other companies.

Collaborator

@zou3519 I don't mean to block this PR. I think this PR should be shipped (once the CI passes). I just wanted to give a heads up about the edge case. Sorry for the confusion!

Collaborator Author

@zou3519 Internal to Meta or other companies.

That makes sense to me

@zou3519 I don't mean to block this PR. I think this PR should be shipped (once the CI passes). I just wanted to give a heads up about the edge case. Sorry for the confusion!

Oh I was asking so that I can drop in some more comments here about the current state. I'll update this PR with your comments, thanks for the discussions!

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label May 8, 2025
Signed-off-by: rzou <[email protected]>
@zou3519
Collaborator Author

zou3519 commented May 10, 2025

Looks like the assumptions are wrong (the asserts are triggering in the tests), so we need some fixes. I have some idea of how to do this; it'll be a bigger refactor.

[2025-05-09T20:34:59Z] if compilation_counter.num_graphs_seen > 0:
[2025-05-09T20:34:59Z]     # NOTE: Eagle compilation
[2025-05-09T20:34:59Z]     # The eagle head is a separate model that gets run, so it needs
[2025-05-09T20:34:59Z]     # its own cache dir (each cache dir is 1:1 with a model.forward).
[2025-05-09T20:34:59Z]     #
[2025-05-09T20:34:59Z]     # We currently assume that the eagle head does not need its own
[2025-05-09T20:34:59Z]     # hash: in the vLLM repo, the hash of the original model currently
[2025-05-09T20:34:59Z]     # entirely determines the config of the eagle head.
[2025-05-09T20:34:59Z]     # It's very possible that this assumption will change in the
[2025-05-09T20:34:59Z]     # future and we'll need to update this code.
[2025-05-09T20:34:59Z]     #
[2025-05-09T20:34:59Z]     # If you are here because you are using multiple torch.compile
[2025-05-09T20:34:59Z]     # calls in a single model, please open an issue and let's discuss.
[2025-05-09T20:34:59Z]     speculative_config = self.vllm_config.speculative_config
[2025-05-09T20:34:59Z] >   assert speculative_config is not None
[2025-05-09T20:34:59Z] E torch._dynamo.exc.BackendCompilerFailed: backend='<vllm.compilation.backends.VllmBackend object at 0x7f84b548a1e0>' raised:
[2025-05-09T20:34:59Z] E AssertionError:
[2025-05-09T20:34:59Z] E
[2025-05-09T20:34:59Z] E Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
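One possible direction for that refactor, sketched against the same fragment above (it reuses self.vllm_config and compilation_counter from the snippet and is not the actual fix): only treat a second compiled graph as an eagle head when an eagle speculative config is actually present, and raise an actionable error otherwise.

# Hypothetical sketch, not the shipped fix: avoid the bare AssertionError by
# checking the speculative config explicitly before assuming the second
# compiled graph is an eagle head.
speculative_config = self.vllm_config.speculative_config
is_eagle_head = (
    speculative_config is not None
    and speculative_config.method in ("eagle", "eagle3")
)
if compilation_counter.num_graphs_seen > 0 and not is_eagle_head:
    raise RuntimeError(
        "Saw multiple torch.compile'd graphs but no eagle speculative "
        "config; multi-graph compilation caching is not supported yet. "
        "Please open an issue to discuss."
    )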

@zou3519 zou3519 closed this Jun 3, 2025
Labels
ready ONLY add when PR is ready to merge/full CI is needed