[WIP] gpt_fused #189
Changes from all commits
3824a28
480d83e
d362042
f730525
620e208
a994000
12535d8
cc1f244
de8400d
@@ -0,0 +1,253 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import sys
import time
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch._dynamo.config
import torch._inductor.config

from torchao.prototype.models.gpt_fused.model import Transformer
from tokenizer import get_tokenizer

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))


def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
    # input_pos: [B, S]
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)[0]

@torch.compile(mode='max-autotune', fullgraph=True)
def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    # input_pos: [B, 1]
    assert input_pos.shape[-1] == 1
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)

def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, **sampling_kwargs):
    new_tokens, new_probs = [], []
    for i in range(num_new_tokens):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):  # Actually better for Inductor to codegen attention here
            # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=True):  # Actually better for Inductor to codegen attention here
            # with torch.autograd.profiler.record_function(f"generate token {i}"):
            # torch.cuda.synchronize()
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, **sampling_kwargs
            )
            input_pos += 1
            new_tokens.append(next_token.clone())
            new_probs.append(next_prob.clone())
            cur_token = next_token.view(1, -1)
            # torch.cuda.synchronize()

    return new_tokens, new_probs


def model_forward(model, x, input_pos):
    return model(x, input_pos)

@torch.no_grad()
def generate(
    model: Transformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    **sampling_kwargs
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(0)
    T_new = T + max_new_tokens
    max_seq_length = min(T_new, model.config.block_size)

    device, dtype = prompt.device, prompt.dtype
    max_seq_length = max_seq_length
    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    from torchao.quantization import apply_weight_only_int8_quant
    apply_weight_only_int8_quant(model)

    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
    seq[T] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    torch.cuda.synchronize(device)
    t0 = time.perf_counter()
    generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs)
    seq[T + 1:] = torch.cat(generated_tokens)
    torch.cuda.synchronize(device)
    t = time.perf_counter() - t0

    return seq, t

def encode_tokens(tokenizer, string, bos=True, device=default_device):
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)

def _load_model(checkpoint_path, device, precision):
    use_cuda = 'cuda' in device
    with torch.device('meta'):
        model = Transformer.from_name(checkpoint_path.parent.name)

    if "int8" in str(checkpoint_path):
        print("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler
        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        print("Using int4 weight-only quantization!")
        path_comps = checkpoint_path.name.split(".")
        assert path_comps[-3].startswith("g")
        assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!"
        groupsize = int(path_comps[-3][1:])
        from quantize import WeightOnlyInt4QuantHandler
        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime(use_cuda)

    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    if "model" in checkpoint and "stories" in str(checkpoint_path):
        checkpoint = checkpoint["model"]
    model.load_state_dict(checkpoint, assign=True)

    model = model.to(device=device, dtype=precision)
    return model.eval()

B_INST, E_INST = "[INST]", "[/INST]"

def profiler_runner(path, fn, *args, **kwargs):
    with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True) as prof:
        result = fn(*args, **kwargs)
    prof.export_chrome_trace(path)
    return result

def main(
    prompt: str = "Hello, my name is",
    num_samples: int = 5,
    max_new_tokens: int = 100,
    top_k: int = 200,
    temperature: float = 0.8,
    checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
    device=default_device,
) -> None:
    """Generates text samples based on a pre-trained Transformer model and tokenizer.
    """
    assert checkpoint_path.is_file(), checkpoint_path

    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
    assert tokenizer_path.is_file(), str(tokenizer_path)

    print(f"Using device={device}")
    precision = torch.bfloat16

    print("Loading model ...")
    t0 = time.time()
    model = _load_model(checkpoint_path, device, precision)

    torch.cuda.synchronize(device)
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
    prompt_length = encoded.size(0)

    torch.manual_seed(1234)
    model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])

    aggregate_metrics = {
        'tokens_per_sec': [],
    }

    for i in range(num_samples):
        with torch.autograd.profiler.record_function(f"timed region for inference {i}"):
            y, t = generate(
                model,
                encoded,
                max_new_tokens,
                temperature=temperature,
                top_k=top_k,
            )
            if i == 0:
                print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

        # print(tokenizer.decode(y.tolist()))
        tokens_generated = y.size(0) - prompt_length
        tokens_sec = tokens_generated / t
        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
        if i > 0:
            aggregate_metrics['tokens_per_sec'].append(tokens_sec)
        else:
            print("Don't count first inference run.")
    print("==========")
    print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Your CLI description.')

    parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
    parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
    parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
    parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
    parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
    parser.add_argument('--device', type=str, default=default_device, help='Device to use')

    args = parser.parse_args()
    # profiler_runner("profile.json.gz", main,
    main(
        args.prompt, args.num_samples, args.max_new_tokens, args.top_k,
        args.temperature, args.checkpoint_path,
        args.device
    )
```
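As an aside, the sampling helper `multinomial_sample_one_no_sync` above relies on the exponential-race trick: drawing `q ~ Exp(1)` per category and taking `argmax(probs / q)` picks an index with probability proportional to `probs`, without the host-device synchronization that `torch.multinomial` would trigger. A quick, illustrative sanity check (not part of this PR):

```python
import torch

# Empirically compare the exponential-race trick against the target distribution.
torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
counts = torch.zeros(3)
for _ in range(10_000):
    q = torch.empty_like(probs).exponential_(1)
    counts[torch.argmax(probs / q)] += 1
print(counts / counts.sum())  # roughly tensor([0.1, 0.2, 0.7])
```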
@@ -0,0 +1,111 @@
```python
import os
import sentencepiece as spm
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from pathlib import Path
from typing import Dict

class TokenizerInterface:
    def __init__(self, model_path):
        self.model_path = model_path

    def encode(self, text):
        raise NotImplementedError("This method should be overridden by subclasses.")

    def decode(self, tokens):
        raise NotImplementedError("This method should be overridden by subclasses.")

    def bos_id(self):
        raise NotImplementedError("This method should be overridden by subclasses.")

    def eos_id(self):
        raise NotImplementedError("This method should be overridden by subclasses.")

class SentencePieceWrapper(TokenizerInterface):
    def __init__(self, model_path):
        super().__init__(model_path)
        self.processor = spm.SentencePieceProcessor(str(model_path))

    def encode(self, text):
        return self.processor.EncodeAsIds(text)

    def decode(self, tokens):
        return self.processor.DecodeIds(tokens)

    def bos_id(self):
        return self.processor.bos_id()

    def eos_id(self):
        return self.processor.eos_id()

class TiktokenWrapper(TokenizerInterface):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
    """

    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 256

    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path):
        super().__init__(model_path)
        assert os.path.isfile(model_path), str(model_path)
        mergeable_ranks = load_tiktoken_bpe(str(model_path))
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        # BOS / EOS token IDs
        self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self._eos_id: int = self.special_tokens["<|end_of_text|>"]

    def encode(self, text):
        return self.model.encode(text)

    def decode(self, tokens):
        return self.model.decode(tokens)

    def bos_id(self):
        return self._bos_id

    def eos_id(self):
        return self._eos_id

def get_tokenizer(tokenizer_model_path, model_name):
    """
    Factory function to get the appropriate tokenizer based on the model name.

    Args:
    - tokenizer_model_path (str): The file path to the tokenizer model.
    - model_name (str): The name of the model, used to determine the tokenizer type.

    Returns:
    - TokenizerInterface: An instance of a tokenizer.
    """
    if "Llama-3" in str(model_name):
        return TiktokenWrapper(tokenizer_model_path)
    else:
        return SentencePieceWrapper(tokenizer_model_path)
```
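For context, a minimal usage sketch of the factory above; the checkpoint layout below is an assumption for illustration, not taken from this PR:

```python
from pathlib import Path
from tokenizer import get_tokenizer

# Hypothetical checkpoint layout: tokenizer.model sits next to model.pth.
checkpoint_path = Path("checkpoints/meta-llama/Meta-Llama-3-8B/model.pth")
tokenizer = get_tokenizer(checkpoint_path.parent / "tokenizer.model", checkpoint_path)

# "Llama-3" in the model name selects TiktokenWrapper; anything else falls back to SentencePiece.
ids = tokenizer.encode("Hello, my name is")
print(tokenizer.decode(ids))
```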
@@ -0,0 +1,13 @@
## gpt-fused

A more handwritten version of [gpt-fast](https://github.com/pytorch-labs/gpt-fast)'s model.py for us to experiment with.

Requires gpt-fast to run, but serves as a drop-in replacement for its model.py.

To use it, set the PYTHONPATH environment variable to `<path to ao repository>/ao/torchao/prototype/models/gpt_fused` and delete gpt-fast's model.py.

For example:

```
PYTHONPATH=/home/cpuhrsch/local/ao/torchao/prototype/models/gpt_fused CUDA_VISIBLE_DEVICES=0 numactl --membind 0 --cpubind 0 python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is"
```

Review comment: what's going on here lol, why do i need to set the python path?

Reply: So that the import statements in gpt-fast pick up on the location of […]
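A quick way to check that the PYTHONPATH override behaves as the reply describes, assuming gpt-fast locates its model through a bare `import model` (a sketch, not part of this PR):

```python
# Run from the gpt-fast checkout with PYTHONPATH set as in the README example.
import importlib

model_module = importlib.import_module("model")
print(model_module.__file__)
# Expected to end in .../torchao/prototype/models/gpt_fused/model.py; if it still
# resolves to gpt-fast's own model.py, the PYTHONPATH entry is missing or shadowed.
```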
Review comment: wdym by more handwritten?
Reply: We could use this to try various fused kernels (Triton or CUDA).
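Not part of this PR, but to make that suggestion concrete: a minimal Triton kernel of the kind that could be dropped into gpt_fused's model.py for experimentation. The operation (a fused multiply-add) and all names are illustrative assumptions.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def fused_mul_add_kernel(x_ptr, y_ptr, z_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    z = tl.load(z_ptr + offsets, mask=mask)
    # Fusion: the multiply and add share one kernel launch, so the intermediate
    # x * y never round-trips through global memory.
    tl.store(out_ptr + offsets, x * y + z, mask=mask)

def fused_mul_add(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    fused_mul_add_kernel[grid](x, y, z, out, n_elements, BLOCK_SIZE=1024)
    return out
```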