
[WIP] gpt_fused #189


Closed · wants to merge 9 commits
253 changes: 253 additions & 0 deletions benchmarks/gpt_fused/generate.py
@@ -0,0 +1,253 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import sys
import time
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch._dynamo.config
import torch._inductor.config

from torchao.prototype.models.gpt_fused.model import Transformer
from tokenizer import get_tokenizer

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))


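# Note: this is the exponential-race (Gumbel-max style) trick: dividing the
# probabilities by i.i.d. Exponential(1) noise and taking the argmax samples
# index i with probability proportional to probs[i], avoiding the host/device
# synchronization that torch.multinomial would trigger.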
def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
q = torch.empty_like(probs_sort).exponential_(1)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
logits = logits / max(temperature, 1e-5)

if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v.select(-1, -1).unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
probs = logits_to_probs(logits[0, -1], temperature, top_k)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs

def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
# input_pos: [B, S]
logits = model(x, input_pos)
return sample(logits, **sampling_kwargs)[0]

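# The per-token decode step is compiled as a single graph (fullgraph=True,
# max-autotune) so Inductor can fuse the whole step; prefill above stays eager
# since it only runs once per prompt.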
@torch.compile(mode='max-autotune', fullgraph=True)
def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [B, 1]
assert input_pos.shape[-1] == 1
logits = model(x, input_pos)
return sample(logits, **sampling_kwargs)

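# Autoregressive loop: each iteration feeds the previously sampled token plus a
# single-position input_pos, relying on the KV caches set up in generate().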
def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, **sampling_kwargs):
new_tokens, new_probs = [], []
for i in range(num_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=True): # Actually better for Inductor to codegen attention here
# with torch.autograd.profiler.record_function(f"generate token {i}"):
# torch.cuda.synchronize()
next_token, next_prob = decode_one_token(
model, cur_token, input_pos, **sampling_kwargs
)
input_pos += 1
new_tokens.append(next_token.clone())
new_probs.append(next_prob.clone())
cur_token = next_token.view(1, -1)
# torch.cuda.synchronize()

return new_tokens, new_probs


def model_forward(model, x, input_pos):
return model(x, input_pos)

@torch.no_grad()
def generate(
model: Transformer,
prompt: torch.Tensor,
max_new_tokens: int,
**sampling_kwargs
) -> torch.Tensor:
"""
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
"""

# create an empty tensor of the expected final shape and fill in the current tokens
T = prompt.size(0)
T_new = T + max_new_tokens
max_seq_length = min(T_new, model.config.block_size)

device, dtype = prompt.device, prompt.dtype
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

from torchao.quantization import apply_weight_only_int8_quant
apply_weight_only_int8_quant(model)

# create an empty tensor of the expected final shape and fill in the current tokens
empty = torch.empty(T_new, dtype=dtype, device=device)
empty[:T] = prompt
seq = empty
input_pos = torch.arange(0, T, device=device)

next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
seq[T] = next_token

input_pos = torch.tensor([T], device=device, dtype=torch.int)
torch.cuda.synchronize(device)
t0 = time.perf_counter()
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs)
seq[T + 1:] = torch.cat(generated_tokens)
torch.cuda.synchronize(device)
t = time.perf_counter() - t0

return seq, t

def encode_tokens(tokenizer, string, bos=True, device=default_device):
tokens = tokenizer.encode(string)
if bos:
tokens = [tokenizer.bos_id()] + tokens
return torch.tensor(tokens, dtype=torch.int, device=device)

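# Builds the model on the meta device, optionally swaps in int8/int4 quantized
# modules based on the checkpoint filename, then loads weights with assign=True
# so the state dict tensors are used directly instead of being copied.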
def _load_model(checkpoint_path, device, precision):
use_cuda = 'cuda' in device
with torch.device('meta'):
model = Transformer.from_name(checkpoint_path.parent.name)

if "int8" in str(checkpoint_path):
print("Using int8 weight-only quantization!")
from quantize import WeightOnlyInt8QuantHandler
simple_quantizer = WeightOnlyInt8QuantHandler(model)
model = simple_quantizer.convert_for_runtime()

if "int4" in str(checkpoint_path):
print("Using int4 weight-only quantization!")
path_comps = checkpoint_path.name.split(".")
assert path_comps[-3].startswith("g")
assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!"
groupsize = int(path_comps[-3][1:])
from quantize import WeightOnlyInt4QuantHandler
simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
model = simple_quantizer.convert_for_runtime(use_cuda)

checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
if "model" in checkpoint and "stories" in str(checkpoint_path):
checkpoint = checkpoint["model"]
model.load_state_dict(checkpoint, assign=True)

model = model.to(device=device, dtype=precision)
return model.eval()

B_INST, E_INST = "[INST]", "[/INST]"

def profiler_runner(path, fn, *args, **kwargs):
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
record_shapes=True) as prof:
result = fn(*args, **kwargs)
prof.export_chrome_trace(path)
return result

def main(
prompt: str = "Hello, my name is",
num_samples: int = 5,
max_new_tokens: int = 100,
top_k: int = 200,
temperature: float = 0.8,
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
device=default_device,
) -> None:
"""Generates text samples based on a pre-trained Transformer model and tokenizer.
"""
assert checkpoint_path.is_file(), checkpoint_path

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), str(tokenizer_path)

print(f"Using device={device}")
precision = torch.bfloat16

print("Loading model ...")
    t0 = time.perf_counter()
model = _load_model(checkpoint_path, device, precision)

torch.cuda.synchronize(device)
    print(f"Time to load model: {time.perf_counter() - t0:.02f} seconds")

tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
prompt_length = encoded.size(0)

torch.manual_seed(1234)
model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])

aggregate_metrics = {
'tokens_per_sec': [],
}

for i in range(num_samples):
with torch.autograd.profiler.record_function(f"timed region for inference {i}"):
y, t = generate(
model,
encoded,
max_new_tokens,
temperature=temperature,
top_k=top_k,
)
if i == 0:
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

# print(tokenizer.decode(y.tolist()))
tokens_generated = y.size(0) - prompt_length
tokens_sec = tokens_generated / t
print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
if i > 0:
aggregate_metrics['tokens_per_sec'].append(tokens_sec)
else:
print("Don't count first inference run.")
print("==========")
print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")


if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Your CLI description.')

parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('--device', type=str, default=default_device, help='Device to use')

args = parser.parse_args()
# profiler_runner("profile.json.gz", main,
main(
args.prompt, args.num_samples, args.max_new_tokens, args.top_k,
args.temperature, args.checkpoint_path,
args.device
)
111 changes: 111 additions & 0 deletions benchmarks/gpt_fused/tokenizer.py
@@ -0,0 +1,111 @@
import os
import sentencepiece as spm
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from pathlib import Path
from typing import Dict

class TokenizerInterface:
def __init__(self, model_path):
self.model_path = model_path

def encode(self, text):
raise NotImplementedError("This method should be overridden by subclasses.")

def decode(self, tokens):
raise NotImplementedError("This method should be overridden by subclasses.")

def bos_id(self):
raise NotImplementedError("This method should be overridden by subclasses.")

def eos_id(self):
raise NotImplementedError("This method should be overridden by subclasses.")

class SentencePieceWrapper(TokenizerInterface):
def __init__(self, model_path):
super().__init__(model_path)
self.processor = spm.SentencePieceProcessor(str(model_path))

def encode(self, text):
return self.processor.EncodeAsIds(text)

def decode(self, tokens):
return self.processor.DecodeIds(tokens)

def bos_id(self):
return self.processor.bos_id()

def eos_id(self):
return self.processor.eos_id()

class TiktokenWrapper(TokenizerInterface):
"""
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""

special_tokens: Dict[str, int]

num_reserved_special_tokens = 256

pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501

def __init__(self, model_path):
super().__init__(model_path)
assert os.path.isfile(model_path), str(model_path)
mergeable_ranks = load_tiktoken_bpe(str(model_path))
num_base_tokens = len(mergeable_ranks)
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
] + [
f"<|reserved_special_token_{i}|>"
for i in range(5, self.num_reserved_special_tokens - 5)
]
self.special_tokens = {
token: num_base_tokens + i for i, token in enumerate(special_tokens)
}
self.model = tiktoken.Encoding(
name=Path(model_path).name,
pat_str=self.pat_str,
mergeable_ranks=mergeable_ranks,
special_tokens=self.special_tokens,
)
# BOS / EOS token IDs
self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
self._eos_id: int = self.special_tokens["<|end_of_text|>"]

def encode(self, text):
return self.model.encode(text)

def decode(self, tokens):
return self.model.decode(tokens)

def bos_id(self):
return self._bos_id

def eos_id(self):
return self._eos_id

def get_tokenizer(tokenizer_model_path, model_name):
"""
Factory function to get the appropriate tokenizer based on the model name.

Args:
- tokenizer_model_path (str): The file path to the tokenizer model.
- model_name (str): The name of the model, used to determine the tokenizer type.

Returns:
- TokenizerInterface: An instance of a tokenizer.
"""
if "Llama-3" in str(model_name):
return TiktokenWrapper(tokenizer_model_path)
else:
return SentencePieceWrapper(tokenizer_model_path)
13 changes: 13 additions & 0 deletions torchao/prototype/models/gpt_fused/README.md
@@ -0,0 +1,13 @@
## gpt-fused

A more handwritten version of [gpt-fast](https://github.com/pytorch-labs/gpt-fast)'s model.py for us to experiment with.
Member: wdym by more handwritten?

Contributor Author: We could use this to try various fused kernels (Triton or CUDA).
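To make that concrete, here is a minimal sketch (not part of this PR) of the kind of swap gpt_fused is meant to make easy to try; the SwiGLU shape and the w1/w3/w2 projection weights mirror gpt-fast's FeedForward and are assumptions:

```
import torch
import torch.nn.functional as F

# A fused SwiGLU feed-forward step: torch.compile/Inductor can fuse the SiLU,
# the elementwise multiply, and the surrounding matmuls into fewer kernels.
# A handwritten Triton or CUDA kernel could be slotted in here instead,
# without touching the rest of the model.
@torch.compile(mode="max-autotune", fullgraph=True)
def fused_swiglu(x, w1, w3, w2):
    return F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)
```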


Requires gpt-fast to run, but is a drop-in replacement for its model.py.

To use it, set the PYTHONPATH environment variable to `PYTHONPATH=<path to ao repository>/ao/torchao/prototype/models/gpt_fused` and delete gpt-fast's model.py.

For example

```
PYTHONPATH=/home/cpuhrsch/local/ao/torchao/prototype/models/gpt_fused CUDA_VISIBLE_DEVICES=0 numactl --membind 0 --cpubind 0 python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is"
```

Member: what's going on here lol, why do i need to set the python path?

Contributor Author: So that the import statements in gpt-fast pick up on the location of model.py in torchao.
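Roughly, it is plain module resolution; a sketch of what happens (the path below is hypothetical):

```
import sys

# gpt-fast's generate.py does `from model import Transformer`. With the
# gpt_fused directory on sys.path (via PYTHONPATH) and gpt-fast's own
# model.py deleted, that import resolves to torchao's fused model.py.
sys.path.insert(0, "/path/to/ao/torchao/prototype/models/gpt_fused")  # hypothetical path
from model import Transformer  # now picks up the gpt_fused version
```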
