diff --git a/benchmarks/gpt_fused/generate.py b/benchmarks/gpt_fused/generate.py
new file mode 100644
index 0000000000..a9a246e3e2
--- /dev/null
+++ b/benchmarks/gpt_fused/generate.py
@@ -0,0 +1,253 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import itertools
+import sys
+import time
+from pathlib import Path
+from typing import Optional, Tuple
+
+import torch
+import torch._dynamo.config
+import torch._inductor.config
+
+from torchao.prototype.models.gpt_fused.model import Transformer
+from tokenizer import get_tokenizer
+
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.triton.unique_kernel_names = True
+
+default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+# support running without installing as a package
+wd = Path(__file__).parent.parent.resolve()
+sys.path.append(str(wd))
+
+
+def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
+    q = torch.empty_like(probs_sort).exponential_(1)
+    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+    logits = logits / max(temperature, 1e-5)
+
+    if top_k is not None:
+        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+        pivot = v.select(-1, -1).unsqueeze(-1)
+        logits = torch.where(logits < pivot, -float("Inf"), logits)
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    return probs
+
+def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+    probs = logits_to_probs(logits[0, -1], temperature, top_k)
+    idx_next = multinomial_sample_one_no_sync(probs)
+    return idx_next, probs
+
+def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
+    # input_pos: [B, S]
+    logits = model(x, input_pos)
+    return sample(logits, **sampling_kwargs)[0]
+
+@torch.compile(mode='max-autotune', fullgraph=True)
+def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+    # input_pos: [B, 1]
+    assert input_pos.shape[-1] == 1
+    logits = model(x, input_pos)
+    return sample(logits, **sampling_kwargs)
+
+def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, **sampling_kwargs):
+    new_tokens, new_probs = [], []
+    for i in range(num_new_tokens):
+        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):  # Actually better for Inductor to codegen attention here
+        # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=True):  # Actually better for Inductor to codegen attention here
+        # with torch.autograd.profiler.record_function(f"generate token {i}"):
+            # torch.cuda.synchronize()
+            next_token, next_prob = decode_one_token(
+                model, cur_token, input_pos, **sampling_kwargs
+            )
+            input_pos += 1
+            new_tokens.append(next_token.clone())
+            new_probs.append(next_prob.clone())
+            cur_token = next_token.view(1, -1)
+            # torch.cuda.synchronize()
+
+    return new_tokens, new_probs
+
+
+def model_forward(model, x, input_pos):
+    return model(x, input_pos)
+
+@torch.no_grad()
+def generate(
+    model: Transformer,
+    prompt: torch.Tensor,
+    max_new_tokens: int,
+    **sampling_kwargs
+) -> Tuple[torch.Tensor, float]:
+    """
+    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+    """
+
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    T = prompt.size(0)
+    T_new = T + max_new_tokens
+    max_seq_length = min(T_new, model.config.block_size)
+
+    device, dtype = prompt.device, prompt.dtype
+    max_seq_length = max_seq_length
+    with torch.device(device):
+        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+
+    from torchao.quantization import apply_weight_only_int8_quant
+    apply_weight_only_int8_quant(model)
+
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    empty = torch.empty(T_new, dtype=dtype, device=device)
+    empty[:T] = prompt
+    seq = empty
+    input_pos = torch.arange(0, T, device=device)
+
+    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
+    seq[T] = next_token
+
+    input_pos = torch.tensor([T], device=device, dtype=torch.int)
+    torch.cuda.synchronize(device)
+    t0 = time.perf_counter()
+    generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs)
+    seq[T + 1:] = torch.cat(generated_tokens)
+    torch.cuda.synchronize(device)
+    t = time.perf_counter() - t0
+
+    return seq, t
+
+def encode_tokens(tokenizer, string, bos=True, device=default_device):
+    tokens = tokenizer.encode(string)
+    if bos:
+        tokens = [tokenizer.bos_id()] + tokens
+    return torch.tensor(tokens, dtype=torch.int, device=device)
+
+def _load_model(checkpoint_path, device, precision):
+    use_cuda = 'cuda' in device
+    with torch.device('meta'):
+        model = Transformer.from_name(checkpoint_path.parent.name)
+
+    if "int8" in str(checkpoint_path):
+        print("Using int8 weight-only quantization!")
+        from quantize import WeightOnlyInt8QuantHandler
+        simple_quantizer = WeightOnlyInt8QuantHandler(model)
+        model = simple_quantizer.convert_for_runtime()
+
+    if "int4" in str(checkpoint_path):
+        print("Using int4 weight-only quantization!")
+        path_comps = checkpoint_path.name.split(".")
+        assert path_comps[-3].startswith("g")
+        assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!"
+        groupsize = int(path_comps[-3][1:])
+        from quantize import WeightOnlyInt4QuantHandler
+        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
+        model = simple_quantizer.convert_for_runtime(use_cuda)
+
+    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    if "model" in checkpoint and "stories" in str(checkpoint_path):
+        checkpoint = checkpoint["model"]
+    model.load_state_dict(checkpoint, assign=True)
+
+    model = model.to(device=device, dtype=precision)
+    return model.eval()
+
+B_INST, E_INST = "[INST]", "[/INST]"
+
+def profiler_runner(path, fn, *args, **kwargs):
+    with torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.CUDA],
+            record_shapes=True) as prof:
+        result = fn(*args, **kwargs)
+    prof.export_chrome_trace(path)
+    return result
+
+def main(
+    prompt: str = "Hello, my name is",
+    num_samples: int = 5,
+    max_new_tokens: int = 100,
+    top_k: int = 200,
+    temperature: float = 0.8,
+    checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
+    device=default_device,
+) -> None:
+    """Generates text samples based on a pre-trained Transformer model and tokenizer.
+    """
+    assert checkpoint_path.is_file(), checkpoint_path
+
+    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+    assert tokenizer_path.is_file(), str(tokenizer_path)
+
+    print(f"Using device={device}")
+    precision = torch.bfloat16
+
+    print("Loading model ...")
+    t0 = time.time()
+    model = _load_model(checkpoint_path, device, precision)
+
+    torch.cuda.synchronize(device)
+    print(f"Time to load model: {time.time() - t0:.02f} seconds")
+
+    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
+
+    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
+    prompt_length = encoded.size(0)
+
+    torch.manual_seed(1234)
+    model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
+
+    aggregate_metrics = {
+        'tokens_per_sec': [],
+    }
+
+    for i in range(num_samples):
+        with torch.autograd.profiler.record_function(f"timed region for inference {i}"):
+            y, t = generate(
+                model,
+                encoded,
+                max_new_tokens,
+                temperature=temperature,
+                top_k=top_k,
+            )
+            if i == 0:
+                print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+
+        # print(tokenizer.decode(y.tolist()))
+        tokens_generated = y.size(0) - prompt_length
+        tokens_sec = tokens_generated / t
+        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
+        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+        if i > 0:
+            aggregate_metrics['tokens_per_sec'].append(tokens_sec)
+        else:
+            print("Don't count first inference run.")
+    print("==========")
+    print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
+    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+
+
+if __name__ == '__main__':
+    import argparse
+    parser = argparse.ArgumentParser(description='Your CLI description.')
+
+    parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
+    parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
+    parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
+    parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
+    parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
+    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
+    parser.add_argument('--device', type=str, default=default_device, help='Device to use')
+
+    args = parser.parse_args()
+    # profiler_runner("profile.json.gz", main,
+    main(
+        args.prompt, args.num_samples, args.max_new_tokens, args.top_k,
+        args.temperature, args.checkpoint_path,
+        args.device
+    )
diff --git a/benchmarks/gpt_fused/tokenizer.py b/benchmarks/gpt_fused/tokenizer.py
new file mode 100644
index 0000000000..c62a0c5b3a
--- /dev/null
+++ b/benchmarks/gpt_fused/tokenizer.py
@@ -0,0 +1,111 @@
+import os
+import sentencepiece as spm
+import tiktoken
+from tiktoken.load import load_tiktoken_bpe
+from pathlib import Path
+from typing import Dict
+
+class TokenizerInterface:
+    def __init__(self, model_path):
+        self.model_path = model_path
+
+    def encode(self, text):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def decode(self, tokens):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def bos_id(self):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+    def eos_id(self):
+        raise NotImplementedError("This method should be overridden by subclasses.")
+
+class SentencePieceWrapper(TokenizerInterface):
+    def __init__(self, model_path):
+        super().__init__(model_path)
+        self.processor = spm.SentencePieceProcessor(str(model_path))
+
+    def encode(self, text):
+        return self.processor.EncodeAsIds(text)
+
+    def decode(self, tokens):
+        return self.processor.DecodeIds(tokens)
+
+    def bos_id(self):
+        return self.processor.bos_id()
+
+    def eos_id(self):
+        return self.processor.eos_id()
+
+class TiktokenWrapper(TokenizerInterface):
+    """
+    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
+    """
+
+    special_tokens: Dict[str, int]
+
+    num_reserved_special_tokens = 256
+
+    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501
+
+    def __init__(self, model_path):
+        super().__init__(model_path)
+        assert os.path.isfile(model_path), str(model_path)
+        mergeable_ranks = load_tiktoken_bpe(str(model_path))
+        num_base_tokens = len(mergeable_ranks)
+        special_tokens = [
+            "<|begin_of_text|>",
+            "<|end_of_text|>",
+            "<|reserved_special_token_0|>",
+            "<|reserved_special_token_1|>",
+            "<|reserved_special_token_2|>",
+            "<|reserved_special_token_3|>",
+            "<|start_header_id|>",
+            "<|end_header_id|>",
+            "<|reserved_special_token_4|>",
+            "<|eot_id|>",  # end of turn
+        ] + [
+            f"<|reserved_special_token_{i}|>"
+            for i in range(5, self.num_reserved_special_tokens - 5)
+        ]
+        self.special_tokens = {
+            token: num_base_tokens + i for i, token in enumerate(special_tokens)
+        }
+        self.model = tiktoken.Encoding(
+            name=Path(model_path).name,
+            pat_str=self.pat_str,
+            mergeable_ranks=mergeable_ranks,
+            special_tokens=self.special_tokens,
+        )
+        # BOS / EOS token IDs
+        self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
+        self._eos_id: int = self.special_tokens["<|end_of_text|>"]
+
+    def encode(self, text):
+        return self.model.encode(text)
+
+    def decode(self, tokens):
+        return self.model.decode(tokens)
+
+    def bos_id(self):
+        return self._bos_id
+
+    def eos_id(self):
+        return self._eos_id
+
+def get_tokenizer(tokenizer_model_path, model_name):
+    """
+    Factory function to get the appropriate tokenizer based on the model name.
+
+    Args:
+    - tokenizer_model_path (str): The file path to the tokenizer model.
+    - model_name (str): The name of the model, used to determine the tokenizer type.
+
+    Returns:
+    - TokenizerInterface: An instance of a tokenizer.
+    """
+    if "Llama-3" in str(model_name):
+        return TiktokenWrapper(tokenizer_model_path)
+    else:
+        return SentencePieceWrapper(tokenizer_model_path)
diff --git a/torchao/prototype/models/gpt_fused/README.md b/torchao/prototype/models/gpt_fused/README.md
new file mode 100644
index 0000000000..40eab89f5b
--- /dev/null
+++ b/torchao/prototype/models/gpt_fused/README.md
@@ -0,0 +1,13 @@
+## gpt-fused
+
+A more handwritten version of [gpt-fast](https://github.com/pytorch-labs/gpt-fast)'s model.py for us to experiment with.
+
+It requires gpt-fast to run, but serves as a drop-in replacement for gpt-fast's model.py.
+
+To use it, set the PYTHONPATH environment variable to `PYTHONPATH=/ao/torchao/prototype/models/gpt_fused` and delete gpt-fast's model.py.
+
+For example
+
+```
+PYTHONPATH=/home/cpuhrsch/local/ao/torchao/prototype/models/gpt_fused CUDA_VISIBLE_DEVICES=0 numactl --membind 0 --cpubind 0 python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is"
+```
diff --git a/torchao/prototype/models/gpt_fused/model.py b/torchao/prototype/models/gpt_fused/model.py
new file mode 100644
index 0000000000..a303b19ff5
--- /dev/null
+++ b/torchao/prototype/models/gpt_fused/model.py
@@ -0,0 +1,263 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+
+
+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+@dataclass
+class ModelArgs:
+    block_size: int = 2048
+    vocab_size: int = 32000
+    n_layer: int = 32
+    n_head: int = 32
+    dim: int = 4096
+    intermediate_size: int = None
+    n_local_heads: int = -1
+    head_dim: int = 64
+    rope_base: float = 10000
+    norm_eps: float = 1e-5
+
+    def __post_init__(self):
+        if self.n_local_heads == -1:
+            self.n_local_heads = self.n_head
+        if self.intermediate_size is None:
+            hidden_dim = 4 * self.dim
+            n_hidden = int(2 * hidden_dim / 3)
+            self.intermediate_size = find_multiple(n_hidden, 256)
+        self.head_dim = self.dim // self.n_head
+
+    @classmethod
+    def from_name(cls, name: str):
+        if name in transformer_configs:
+            return cls(**transformer_configs[name])
+        # fuzzy search
+        config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
+
+        # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
+        # take the longer name (as it has more symbols matched)
+        if len(config) > 1:
+            config.sort(key=len, reverse=True)
+            assert len(config[0]) != len(config[1]), name  # make sure only one 'best' match
+
+        return cls(**transformer_configs[config[0]])
+
+
+transformer_configs = {
+    "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000),
+    "7B": dict(n_layer=32, n_head=32, dim=4096),
+    "13B": dict(n_layer=40, n_head=40, dim=5120),
+    "30B": dict(n_layer=60, n_head=52, dim=6656),
+    "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000),  # CodeLlama-34B-Python-hf
+    "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
+    "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
+    "stories15M": dict(n_layer=6, n_head=6, dim=288),
+    "stories110M": dict(n_layer=12, n_head=12, dim=768),
+    "Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256),
+}
+
+class KVCache(nn.Module):
+    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
+        super().__init__()
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
+        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
+
+    def update(self, input_pos, k_val, v_val):
+        # input_pos: [S], k_val: [B, H, S, D]
+        assert input_pos.shape[0] == k_val.shape[2]
+
+        k_out = self.k_cache
+        v_out = self.v_cache
+        k_out[:, :, input_pos] = k_val
+        v_out[:, :, input_pos] = v_val
+
+        return k_out, v_out
+
+class Transformer(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.config = config
+
+        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
+        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+
+        self.freqs_cis: Optional[Tensor] = None
+        self.mask_cache: Optional[Tensor] = None
+        self.max_batch_size = -1
+        self.max_seq_length = -1
+
+    def setup_caches(self, max_batch_size, max_seq_length):
+        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+            return
+        head_dim = self.config.dim // self.config.n_head
+        max_seq_length = find_multiple(max_seq_length, 8)
+        self.max_seq_length = max_seq_length
+        self.max_batch_size = max_batch_size
+        dtype = self.output.weight.dtype
+        # For quantized layers, dtype is encoded in scales
+        if hasattr(self.output, "scales"):
+            dtype = self.output.scales.dtype
+        elif hasattr(self.output, "scales_and_zeros"):
+            dtype = self.output.scales_and_zeros.dtype
+        for b in self.layers:
+            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
+
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
+        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
+
+    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+        assert self.freqs_cis is not None, "Caches must be initialized first"
+        mask = self.causal_mask[None, None, input_pos]
+        freqs_cis = self.freqs_cis[input_pos]
+        x = self.tok_embeddings(idx)
+
+        for i, layer in enumerate(self.layers):
+            x = layer(x, input_pos, freqs_cis, mask)
+        x = self.norm(x)
+        logits = self.output(x)
+        return logits
+
+    @classmethod
+    def from_name(cls, name: str):
+        return cls(ModelArgs.from_name(name))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.attention = Attention(config)
+        self.feed_forward = FeedForward(config)
+        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
+        self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+
+    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
+        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out
+
+
+class Attention(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        assert config.dim % config.n_head == 0
+
+        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+        # key, query, value projections for all heads, but in a batch
+        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.kv_cache = None
+
+        self.n_head = config.n_head
+        self.head_dim = config.head_dim
+        self.n_local_heads = config.n_local_heads
+        self.dim = config.dim
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(self, state_dict, prefix, *args):
+        if prefix + "wq.weight" in state_dict:
+            wq = state_dict.pop(prefix + "wq.weight")
+            wk = state_dict.pop(prefix + "wk.weight")
+            wv = state_dict.pop(prefix + "wv.weight")
+            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+        bsz, seqlen, _ = x.shape
+
+        kv_size = self.n_local_heads * self.head_dim
+        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+        q = apply_rotary_emb(q, freqs_cis)
+        k = apply_rotary_emb(k, freqs_cis)
+
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+        if self.kv_cache is not None:
+            k, v = self.kv_cache.update(input_pos, k, v)
+
+        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+        y = self.wo(y)
+        return y
+
+
+class FeedForward(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.w13 = nn.Linear(config.dim, 2 * config.intermediate_size, bias=False)
+        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+        self.dim = config.intermediate_size
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(self, state_dict, prefix, *args):
+        if prefix + "w1.weight" in state_dict:
+            w1 = state_dict.pop(prefix + "w1.weight")
+            w3 = state_dict.pop(prefix + "w3.weight")
+            state_dict[prefix + "w13.weight"] = torch.cat([w1, w3])
+
+    def forward(self, x: Tensor) -> Tensor:
+        x1, x3 = self.w13(x).split([self.dim, self.dim], dim=-1)
+        return self.w2(F.silu(x1) * x3)
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+    def forward(self, x: Tensor) -> Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+def precompute_freqs_cis(
+    seq_len: int, n_elem: int, base: int = 10000,
+    dtype: torch.dtype = torch.bfloat16
+) -> Tensor:
+    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    t = torch.arange(seq_len, device=freqs.device)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+    return cache.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+        ],
+        -1,
+    )
+
+    x_out2 = x_out2.flatten(3)
+    return x_out2.type_as(x)
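
A note on the sampling path in generate.py above: `multinomial_sample_one_no_sync` avoids the device-to-host synchronization that `torch.multinomial` incurs by using an exponential-race formulation; since `-log(q)` for `q ~ Exponential(1)` is Gumbel-distributed, `argmax(probs / q)` is the Gumbel-max trick and draws an index with probability proportional to `probs`. The snippet below is a minimal standalone sketch (not part of the diff; the `sample_no_sync` helper name is made up for illustration) that checks the empirical frequencies against the target distribution.

```python
import torch

def sample_no_sync(probs: torch.Tensor, n: int) -> torch.Tensor:
    # Draw n independent samples at once: argmax(probs / q) with q ~ Exponential(1)
    # is the Gumbel-max trick in disguise, so no call to torch.multinomial is needed.
    q = torch.empty(n, probs.numel()).exponential_(1)
    return torch.argmax(probs / q, dim=-1)

if __name__ == "__main__":
    torch.manual_seed(0)
    probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
    idx = sample_no_sync(probs, 200_000)
    freqs = torch.bincount(idx, minlength=probs.numel()).float() / idx.numel()
    print(freqs)  # expected to be close to [0.1, 0.2, 0.3, 0.4]
```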