[Model] MLPSpeculator speculative decoding support #4947
The first changed file (diff hunk `@@ -1,20 +1,89 @@`) extends what appears to be the stock offline-inference example: it defines and registers the MLPSpeculator config, then runs generation with a speculative draft model attached.

```python
import time
from typing import List, Optional

from transformers import AutoConfig, PretrainedConfig

from vllm import LLM, SamplingParams


class MLPSpeculatorConfig(PretrainedConfig):
    model_type = "mlp_speculator"

    attribute_map = {
        "hidden_size": "emb_dim",
    }

    def __init__(self,
                 vocab_size: int = 32000,
                 emb_dim: int = 4096,
                 inner_dim: int = 0,
                 n_predict: int = 3,
                 top_k_tokens_per_head: Optional[List[int]] = None,
                 n_candidates: int = 5,
                 **kwargs):
        """
        Initialize an MLPSpeculatorConfig

        Args:
            vocab_size: int
                the model vocab size
            emb_dim: int
                the model embedding dimension
            inner_dim: int
                the inner dimension of the model. If 0, will be the emb_dim.
            n_predict: int
                the number of lookaheads for the speculator
            top_k_tokens_per_head: List[int]
                Number of tokens to consider from each head when forming the
                candidate tree.
                For each candidate branch in the tree, head n produces topk[n]
                additional sub-branches.
            n_candidates: int
                number of child candidates to create per sequence
        """
        if top_k_tokens_per_head is None:
            top_k_tokens_per_head = [5, 4, 3]
        assert len(top_k_tokens_per_head) == n_predict
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.inner_dim = inner_dim
        self.n_predict = n_predict
        self.top_k_tokens_per_head = top_k_tokens_per_head
        self.n_candidates = n_candidates
        super().__init__(**kwargs)


AutoConfig.register("mlp_speculator", MLPSpeculatorConfig)
```
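A minimal usage sketch of the config above. Direct construction always works; the `AutoConfig.from_pretrained` path (commented out) assumes the speculator checkpoint's `config.json` declares `model_type: "mlp_speculator"`, which is exactly what the `register()` call enables. This sketch is illustrative and not part of the PR.

```python
# Assumes MLPSpeculatorConfig and the AutoConfig.register() call above are in
# scope (e.g. this runs in the same script).
cfg = MLPSpeculatorConfig(n_predict=3, top_k_tokens_per_head=[5, 4, 3])
print(cfg.hidden_size)  # 4096 -- attribute_map aliases hidden_size to emb_dim

# If the speculator checkpoint's config.json sets model_type="mlp_speculator",
# the registration above lets AutoConfig resolve it to MLPSpeculatorConfig
# (assumption about the checkpoint contents, not verified here):
# cfg = AutoConfig.from_pretrained("ibm-granite/granite-7b-instruct-accelerator")
```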
The example script then continues:

```python
template = ("Below is an instruction that describes a task. Write a response "
            "that appropriately completes the request.\n\n### Instruction:\n{}"
            "\n\n### Response:")

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
prompts = [template.format(prompt) for prompt in prompts]

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM with an MLPSpeculator draft model attached.
# (The stock example used model="facebook/opt-125m".)
llm = LLM(model="ibm-granite/granite-7b-instruct",
          use_v2_block_manager=True,
          enforce_eager=True,
          speculative_model="ibm-granite/granite-7b-instruct-accelerator",
          num_speculative_tokens=5)

# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)  # first pass, also a warm-up
start = time.time()
outputs = llm.generate(prompts, sampling_params)  # timed pass
end = time.time()
print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
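The printed value is the average wall-clock seconds per generated token across the batch; if tokens per second is the more convenient unit, it is just the inverse. A tiny helper, purely illustrative and not part of the PR:

```python
def tokens_per_second(outputs, elapsed_seconds):
    # Total generated tokens across the batch divided by elapsed wall-clock time.
    generated = sum(len(o.outputs[0].token_ids) for o in outputs)
    return generated / elapsed_seconds
```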
The second file (diff hunk `@@ -0,0 +1,151 @@`, a new file) adds the MLPSpeculator model itself, starting with its normalization layer.

```python
import math
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn

from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput


class MLPSpeculatorLayerNorm(nn.Module):
    """
    An L2 normalization implementation.

    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    elementwise_scale_weight : torch.Tensor
        learned scaling term applied after normalization
    elementwise_shift_bias : torch.Tensor
        learned bias term applied after normalization
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.empty(normalized_shape))
        self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        # RMS-style normalization over the last dimension, followed by a
        # learned elementwise scale and shift.
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        x = self.weight * x
        x = x + self.bias
        return x
```
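A quick sanity check of what this normalization does, assuming the `MLPSpeculatorLayerNorm` class above is in scope: with an identity scale and zero shift, the output's mean square over the last dimension is approximately 1 regardless of the input scale. This snippet is illustrative only.

```python
import torch

# Assumes the MLPSpeculatorLayerNorm class defined above is in scope.
ln = MLPSpeculatorLayerNorm(8)
with torch.no_grad():
    ln.weight.fill_(1.0)  # identity scale
    ln.bias.fill_(0.0)    # no shift

x = torch.randn(2, 8) * 10.0  # arbitrary input scale
y = ln(x)
print(y.pow(2).mean(-1))      # ~1.0 per row, independent of the input scale
```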
The file continues with the speculator module:

```python
class MLPSpeculator(nn.Module):

    def __init__(self, config, **kwargs) -> None:
        super().__init__()
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.inner_dim = config.inner_dim if config.inner_dim != 0 \
            else config.emb_dim

        # One embedding, projection, head and layer norm per lookahead token.
        self.emb = nn.ModuleList([
            VocabParallelEmbedding(config.vocab_size,
                                   self.inner_dim,
                                   org_num_embeddings=config.vocab_size)
            for _ in range(config.n_predict)
        ])

        self.proj = nn.ModuleList([
            nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                      self.inner_dim,
                      bias=False) for i in range(config.n_predict)
        ])

        self.head = nn.ModuleList([
            nn.Linear(self.inner_dim, self.vocab_size, bias=False)
            for _ in range(config.n_predict)
        ])
        self.ln = nn.ModuleList([
            MLPSpeculatorLayerNorm(self.inner_dim)
            for _ in range(config.n_predict)
        ])

        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = Sampler()

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:

        # previous_hidden_state, first_decode_step and accepted_token_lengths
        # are set externally by the worker (see HiddenStatesWorker below).
        if self.first_decode_step:
            self.first_decode_step = False
        else:
            # Prune to the hidden state of each sequence's last accepted token.
            self.previous_hidden_state = self.previous_hidden_state.reshape(
                -1, self.n_predict + 1, self.previous_hidden_state.size(1))
            self.previous_hidden_state = self.previous_hidden_state.gather(
                1, (self.accepted_token_lengths - 1)[:, None, None].expand(
                    -1, 1,
                    self.previous_hidden_state.size(2))).squeeze(1)  # b x d

        # b x 1 x d
        self.previous_hidden_state = self.previous_hidden_state.reshape(
            self.previous_hidden_state.size(0), 1,
            self.previous_hidden_state.size(1))

        # b x 1
        last_tokens = input_ids.reshape(-1, 1)

        next_tokens = []

        for head_index in range(self.n_predict):

            # Project and predict
            z = self.emb[head_index](last_tokens)  # b k d
            state = self.proj[head_index](self.previous_hidden_state)

            # Weighted add of state_weight*state and emb_weight*z
            # Let subsequent LN take care of denominator
            # state_weight is close to 1, so shouldn't be any precision issues
            state = torch.add(state,
                              z,
                              alpha=self.emb_weight / self.state_weight)

            state = self.activation(self.ln[head_index](state))  # b k d
            # todo: not yet supporting top_k_tokens_per_head
            self.previous_hidden_state = state

            logits = self.logits_processor(self.head[head_index].weight, state,
                                           sampling_metadata)

            tmp = logits.reshape(-1, logits.size(2))
            output = self.sampler(tmp, sampling_metadata)
            last_tokens = output.sampled_token_ids
            next_tokens.append(output)

        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # Checkpoint parameter names carry a "speculator." prefix; strip it
            # to match this module's parameter names.
            param = params_dict[name.replace("speculator.", "")]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
```
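For intuition about the two constants set in `__init__`: `state_weight` decays the carried hidden state while `emb_weight` rescales the token embedding, and the forward pass folds both into a single `alpha` for `torch.add`, relying on the following layer norm to remove the overall scale. That reading is inferred from the code, not stated in the PR; a minimal sketch of the arithmetic with the default sizes used above:

```python
import math

n_predict = 3      # config.n_predict
inner_dim = 4096   # config.inner_dim (or emb_dim when inner_dim == 0)

state_weight = 0.5 ** (0.5 / n_predict)
emb_weight = math.sqrt((1 - state_weight ** 2) * (inner_dim / 2))

# generate_proposals() computes state + (emb_weight / state_weight) * z,
# i.e. it applies the ratio of the two weights as torch.add's alpha.
alpha = emb_weight / state_weight
print(f"{state_weight=:.4f} {emb_weight=:.2f} {alpha=:.2f}")
```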
The third file (diff hunk `@@ -0,0 +1,92 @@`, a new file) adds a worker subclass that captures the target model's hidden states for the speculator.

```python
from typing import List, Optional

import torch

from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.worker.worker import Worker


class HiddenStatesWorker(Worker):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.speculator = None
        self.prev_request_context_lengths = {}

    def _get_hidden_states(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ):

        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_input
         ) = self.model_runner.prepare_input_tensors(seq_group_metadata_list)

        if self.model_runner.lora_config:
            self.model_runner.set_active_loras(lora_requests, lora_mapping)

        # Currently cuda graph is only supported by the decode phase.
        prefill_meta = attn_metadata.prefill_metadata
        decode_meta = attn_metadata.decode_metadata
        if prefill_meta is None and decode_meta.use_cuda_graph:
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.model_runner.graph_runners[
                graph_batch_size]
        else:
            model_executable = self.model_runner.model
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})

        # save the previous hidden states for later use
        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model_runner.model.compute_logits(
            hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.model_runner.is_driver_worker:
            return None

        # Sample the next token.
        output = self.model_runner.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

        # we only need to pass hidden states of most recent token
        if seq_group_metadata_list[0].is_prompt:
            hidden_states = hidden_states.index_select(
                0, sampling_metadata.selected_token_indices)

        return output, hidden_states

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:

        if execute_model_req is None:
            return []

        sampler_output, hidden_states = self._get_hidden_states(
            execute_model_req.seq_group_metadata_list, self.gpu_cache)

        assert self.speculator is not None

        # if we are executing the prompt, we need to flag the first decode step
        # since pruning is handled differently
        if execute_model_req.seq_group_metadata_list[0].is_prompt:
            self.speculator.first_decode_step = True

        # Stash the hidden states on the speculator for the next proposal step.
        self.speculator.previous_hidden_state = hidden_states
        return [sampler_output]
```
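The worker stashes the target model's hidden states on the speculator; on the next decode step, `generate_proposals` prunes them down to each sequence's last accepted position before drafting new tokens. A small, self-contained sketch of that gather operation (shapes and values are illustrative, not taken from the PR):

```python
import torch

batch, n_predict, hidden = 2, 3, 4
# Hidden states for (n_predict + 1) candidate positions per sequence,
# flattened the way the speculator receives them.
prev = torch.arange(batch * (n_predict + 1) * hidden, dtype=torch.float32)
prev = prev.reshape(batch, n_predict + 1, hidden)
# Number of tokens accepted per sequence in the last verification step.
accepted = torch.tensor([1, 3])

# Keep only the hidden state at each sequence's last accepted position.
kept = prev.gather(
    1, (accepted - 1)[:, None, None].expand(-1, 1, hidden)).squeeze(1)
print(kept.shape)  # torch.Size([2, 4]): one hidden state per sequence
```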