Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
efb0599
initial commit of mlp_speculator and hidden_states_worker to support …
JRosenkranz May 20, 2024
7a8eeff
Merge branch 'main' into mlp_speculator
JRosenkranz May 20, 2024
667ef88
removed fms_extras import
JRosenkranz May 20, 2024
d534ef2
updated with a working non-batch version - a lot hardcoded
JRosenkranz May 21, 2024
17541b6
updated experimental with working version - eager
JRosenkranz May 22, 2024
ac5a1da
fixed bug with speculator outputs
JRosenkranz May 22, 2024
6ba9a1e
removed comments; swapped to sampling in the example
JRosenkranz May 22, 2024
cb3aacf
Introduce MLPSpeculatorWorker and corresponding refactor
tdoublep May 27, 2024
bf2f102
Fix some issues with correctness + simplify API a bit
tdoublep May 27, 2024
6af4629
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill May 31, 2024
e0309a6
Fix typing and formatting
njhill May 31, 2024
abd42e7
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 1, 2024
314f2ae
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 5, 2024
9dd1c50
Remove separate MLPSpeculatorModelRunner and other cleanup
njhill Jun 5, 2024
0d43097
Use sample_len in mlp_speculator
njhill Jun 5, 2024
9dd1608
Some more rework/simplification, still in progress
njhill Jun 6, 2024
ea677bd
Config cleanup
njhill Jun 7, 2024
b39c94f
Ignore weird mypi error only happening in CI
njhill Jun 7, 2024
ab96c2a
Try again to ignore weird ruff error
njhill Jun 7, 2024
e9af7e5
Try to ignore both ruff and mypy errs
njhill Jun 7, 2024
30dc5e6
yapf
njhill Jun 7, 2024
3a61052
Fix leftover HiddenStatesWorker references
njhill Jun 7, 2024
cc05972
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 7, 2024
693974e
Fix AutoConfig import, mlp spec worker docstring
njhill Jun 7, 2024
f1bafba
Some cleanup/simplification
njhill Jun 7, 2024
455b9a9
Rework handling of accepted tokens
njhill Jun 7, 2024
e583ae9
Filter hidden states in Top1Proposer when needed
njhill Jun 9, 2024
7bff0d1
Enable bonus token
njhill Jun 9, 2024
3d04037
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 9, 2024
bea97d7
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 11, 2024
3012553
Move hidden state logic to separate class
njhill Jun 11, 2024
b116e02
Default num_speculative_tokens based on speculator model config
njhill Jun 15, 2024
e7742e7
Move offline_inference example to separate file
njhill Jun 15, 2024
ee83331
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 15, 2024
444a709
ruff
njhill Jun 15, 2024
bb9fd32
Add comment per review
njhill Jun 15, 2024
fcc6606
Some simplification to MLPSpeculatorWorker._prepare_input_tensors
njhill Jun 15, 2024
ffc0bcf
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 17, 2024
f3dc40a
Add check for TP == 1; TP support will be a fast-follow
njhill Jun 17, 2024
1b7e305
Fix test import
njhill Jun 17, 2024
46ceacd
Revert unrelated commit made by mistake
njhill Jun 17, 2024
d9ce339
Fix test mocks
njhill Jun 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 73 additions & 4 deletions examples/offline_inference.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,89 @@
import time
from typing import List, Optional

from transformers import AutoConfig, PretrainedConfig

from vllm import LLM, SamplingParams


class MLPSpeculatorConfig(PretrainedConfig):
model_type = "mlp_speculator"

attribute_map = {
"hidden_size": "emb_dim",
}

def __init__(self,
vocab_size: int = 32000,
emb_dim: int = 4096,
inner_dim: int = 0,
n_predict: int = 3,
top_k_tokens_per_head: Optional[List[int]] = None,
n_candidates: int = 5,
**kwargs):
"""
Initialize an MLPSpeculatorConfig
Args:
vocab_size: int
the model vocab size
emb_dim: int
the model embedding dimension
inner_dim: int
the inner dimension of the model. If 0, will be the emb_dim.
n_predict: int
the number of lookaheads for the speculator
top_k_tokens_per_head: List[int]
Number of tokens to consider from each head when forming the
candidate tree.
For each candidate branch in the tree, head n produces topk[n]
additional sub-branches.
n_candidates: int
number of child candidates to create per sequence
"""
if top_k_tokens_per_head is None:
top_k_tokens_per_head = [5, 4, 3]
assert len(top_k_tokens_per_head) == n_predict
self.vocab_size = vocab_size
self.emb_dim = emb_dim
self.inner_dim = inner_dim
self.n_predict = n_predict
self.top_k_tokens_per_head = top_k_tokens_per_head
self.n_candidates = n_candidates
super().__init__(**kwargs)


AutoConfig.register("mlp_speculator", MLPSpeculatorConfig)

template = ("Below is an instruction that describes a task. Write a response "
"that appropriately completes the request.\n\n### Instruction:\n{}"
"\n\n### Response:")

# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
# "The president of the United States is",
# "The capital of France is",
# "The future of AI is",
]
prompts = [template.format(prompt) for prompt in prompts]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="ibm-granite/granite-7b-instruct",
use_v2_block_manager=True,
enforce_eager=True,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does MLPSpeculator work with Cudagraph now?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LiuXiaoxuanPKU I'm actually not sure ... it's still pretty fast without it. Was thinking to look at cudagraph as a follow-on.

speculative_model="ibm-granite/granite-7b-instruct-accelerator",
num_speculative_tokens=5)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.

outputs = llm.generate(prompts, sampling_params)
start = time.time()
outputs = llm.generate(prompts, sampling_params)
end = time.time()
print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
# Print the outputs.
for output in outputs:
prompt = output.prompt
Expand Down
4 changes: 3 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,9 @@ def _verify_args(self) -> None:
raise ValueError("Expected num_speculative_tokens to be greater "
f"than zero ({self.num_speculative_tokens}).")

if self.draft_model_config:
if (self.draft_model_config
and self.draft_model_config.hf_config.model_type !=
"mlp_speculator"):
self.draft_model_config.verify_with_parallel_config(
self.draft_parallel_config)

Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

_EMBEDDING_MODELS = {
Expand Down
151 changes: 151 additions & 0 deletions vllm/model_executor/models/mlp_speculator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import math
from typing import Iterable, Tuple

import torch
import torch.nn as nn

from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader


class MLPSpeculatorLayerNorm(nn.Module):
"""
A L2 normalization implementation
...
Args
----
normalized_shape : int
Dimensionality of input data (size of final tensor axis)
elementwise_scale_weight : torch.Tensor
learned scaling term after normalization?
elementwise_shift_bias : torch.Tensor
learned bias term after normalization?
eps : float
Safety term to prevent division by zero. Make sure the chosen value
fits in the range of your encoding scheme
(i.e. fp16 requires eps >= 6e-8).
"""

def __init__(
self,
normalized_shape,
eps=1e-06,
):
super(MLPSpeculatorLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.empty(normalized_shape))
self.bias = nn.Parameter(torch.empty(normalized_shape))
self.eps = eps

def forward(self, x):
xf = x
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
x = xf.type_as(x)
x = self.weight * x
x = x + self.bias
return x


class MLPSpeculator(nn.Module):

def __init__(self, config, **kwargs) -> None:
super().__init__()
self.n_predict = config.n_predict
self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim
self.inner_dim = config.inner_dim if config.inner_dim != 0 \
else config.emb_dim
self.emb = nn.ModuleList([
VocabParallelEmbedding(config.vocab_size,
self.inner_dim,
org_num_embeddings=config.vocab_size)
for _ in range(config.n_predict)
])

self.proj = nn.ModuleList([
nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
self.inner_dim,
bias=False) for i in range(config.n_predict)
])

self.head = nn.ModuleList([
nn.Linear(self.inner_dim, self.vocab_size, bias=False)
for _ in range(config.n_predict)
])
self.ln = nn.ModuleList([
MLPSpeculatorLayerNorm(self.inner_dim)
for _ in range(config.n_predict)
])

self.state_weight = 0.5**(0.5 / config.n_predict)
self.emb_weight = math.sqrt(
(1 - self.state_weight**2) * (self.inner_dim / 2))
self.activation = nn.GELU()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
config.vocab_size, 1.0)
self.sampler = Sampler()

def generate_proposals(
self,
input_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:

if self.first_decode_step:
self.first_decode_step = False
else:
self.previous_hidden_state = self.previous_hidden_state.reshape(
-1, self.n_predict + 1, self.previous_hidden_state.size(1))
self.previous_hidden_state = self.previous_hidden_state.gather(
1, (self.accepted_token_lengths - 1)[:, None, None].expand(
-1, 1,
self.previous_hidden_state.size(2))).squeeze(1) # b x d

# b x 1 x d
self.previous_hidden_state = self.previous_hidden_state.reshape(
self.previous_hidden_state.size(0), 1,
self.previous_hidden_state.size(1))

# b x 1
last_tokens = input_ids.reshape(-1, 1)

next_tokens = []

for head_index in range(self.n_predict):

# Project and predict
z = self.emb[head_index](last_tokens) # b k d
state = self.proj[head_index](self.previous_hidden_state)

# Weighted add of state_weight*state and emb_weight*z
# Let subsequent LN take care of denominator
# state_weight is close to 1, so shouldn't be any precision issues
state = torch.add(state,
z,
alpha=self.emb_weight / self.state_weight)

state = self.activation(self.ln[head_index](state)) # b k d
# todo: not yet supporting top_k_tokens_per_head
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this, we are assuming the set of top_k_tokens_per_head are each 1 for every head. We may want to see how this affects performance (acceptance rate) as typically we create a tree and traverse the tree to find the optimal candidate.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, n_candidates is also not used? (it is assumed to be equal 1?)

self.previous_hidden_state = state

logits = self.logits_processor(self.head[head_index].weight, state,
sampling_metadata)

tmp = logits.reshape(-1, logits.size(2))
output = self.sampler(tmp, sampling_metadata)
last_tokens = output.sampled_token_ids
next_tokens.append(output)

return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
param = params_dict[name.replace("speculator.", "")]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
92 changes: 92 additions & 0 deletions vllm/spec_decode/hidden_states_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import List, Optional

import torch

from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.worker.worker import Worker


class HiddenStatesWorker(Worker):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.speculator = None
self.prev_request_context_lengths = {}

def _get_hidden_states(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[torch.Tensor],
):

(input_tokens, input_positions, attn_metadata, sampling_metadata,
lora_requests, lora_mapping, multi_modal_input
) = self.model_runner.prepare_input_tensors(seq_group_metadata_list)

if self.model_runner.lora_config:
self.model_runner.set_active_loras(lora_requests, lora_mapping)

# Currently cuda graph is only supported by the decode phase.
prefill_meta = attn_metadata.prefill_metadata
decode_meta = attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
model_executable = self.model_runner.graph_runners[
graph_batch_size]
else:
model_executable = self.model_runner.model
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
}
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})

# save the previous hidden states for later use
hidden_states = model_executable(**execute_model_kwargs)

# Compute the logits.
logits = self.model_runner.model.compute_logits(
hidden_states, sampling_metadata)

# Only perform sampling in the driver worker.
if not self.model_runner.is_driver_worker:
return None

# Sample the next token.
output = self.model_runner.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)

# we only need to pass hidden states of most recent token
if seq_group_metadata_list[0].is_prompt:
hidden_states = hidden_states.index_select(
0, sampling_metadata.selected_token_indices)

return output, hidden_states

@torch.inference_mode()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:

if execute_model_req is None:
return []

sampler_output, hidden_states = self._get_hidden_states(
execute_model_req.seq_group_metadata_list, self.gpu_cache)

assert self.speculator is not None

# if we are executing the prompt, we need to flag the first decode step
# since pruning is handled differently
if execute_model_req.seq_group_metadata_list[0].is_prompt:
self.speculator.first_decode_step = True

self.speculator.previous_hidden_state = hidden_states
return [sampler_output]
Loading