116 changes: 57 additions & 59 deletions builder.py
@@ -38,6 +38,27 @@ class BuilderArgs:
setup_caches: bool = False
use_tp: bool = False

def __post_init__(self):
if not (
(self.checkpoint_path and self.checkpoint_path.is_file()) or
(self.checkpoint_dir and self.checkpoint_dir.is_dir()) or
(self.gguf_path and self.gguf_path.is_file()) or
(self.dso_path and Path(self.dso_path).is_file()) or
(self.pte_path and Path(self.pte_path).is_file())
):
raise RuntimeError("need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path")

if (self.dso_path and self.pte_path):
raise RuntimeError("specify either DSO path or PTE path, but not both")

if (self.checkpoint_path and (self.dso_path or self.pte_path)):
print("Warning: checkpoint path ignored because an exported DSO or PTE path was specified")
if (self.checkpoint_dir and (self.dso_path or self.pte_path)):
print("Warning: checkpoint dir ignored because an exported DSO or PTE path was specified")
if (self.gguf_path and (self.dso_path or self.pte_path)):
print("Warning: GGUF path ignored because an exported DSO or PTE path was specified")


@classmethod
def from_args(cls, args): # -> BuilderArgs:
return cls(
@@ -49,7 +70,7 @@ def from_args(cls, args): # -> BuilderArgs:
dso_path = args.dso_path,
pte_path = args.pte_path,
device = args.device,
precision = name_to_dtype(args.precision),
precision = name_to_dtype(args.dtype),
setup_caches = (args.output_dso_path or args.output_pte_path),
use_tp = False,
)
@@ -62,23 +83,23 @@ class TokenizerArgs:

@classmethod
def from_args(cls, args): # -> TokenizerArgs:
is_Sentencepiece = True
is_SentencePiece = True
is_TikToken = False

if args.tokenizer_path:
tokenizer_path = args.tokenizer_path
elif argscheckpoint_path:
elif args.checkpoint_path:
tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
elif checkpoint_dir:
elif args.checkpoint_dir:
tokenizer_path = args.checkpoint_dir / "tokenizer.model"
else:
raise RuntimeError(f"cannot find tokenizer model")

if not tokenizer_path.is_file():
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")

if args.toktoken:
is_Sentencepiece = False
if args.tiktoken:
is_SentencePiece = False
is_TikToken = True

return cls(
@@ -87,13 +108,13 @@ def from_args(cls, args): # -> TokenizerArgs:
is_TikToken=is_TikToken
)

def _initialize_tokenizer(config: TokenizerArgs):
if is_SentencePiece:
return SentencePieceProcessor(model_file=str(tokenizer_path))
elif is_TikToken:
raise RUntimeError("TikToken not implemented yet!")
def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
if tokenizer_args.is_SentencePiece:
return SentencePieceProcessor(model_file=str(tokenizer_args.tokenizer_path))
elif tokenizer_args.is_TikToken:
raise RuntimeError("TikToken not implemented yet!")
else:
raise RUntimeError("must specify a valid tokenizer in TokenizerArgs")
raise RuntimeError("must specify a valid tokenizer in TokenizerArgs")


def device_sync(device):
@@ -180,51 +201,25 @@ def _load_model(


def _initialize_model(
checkpoint_path,
checkpoint_dir,
params_path,
params_table,
gguf_path,
dso_path,
pte_path,
builder_args,
quantize,
device,
precision,
setup_caches,
use_tp # =False
):
assert (
(checkpoint_path and checkpoint_path.is_file()) or
(checkpoint_dir and checkpoint_path.is_dir()) or
(gguf_path and gguf_path.is_file()) or
(dso_path and Path(dso_path).is_file()) or
(pte_path and Path(pte_path).is_file())
), "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
assert not (dso_path and pte_path), "specify either DSO path or PTE path, but not both"

if (checkpoint_path and (dso_path or pte_path)):
print("Warning: checkpoint path ignored because an exported DSO or PTE path specified")
if (checkpoint_dir and (dso_path or pte_path)):
print("Warning: checkpoint dir ignored because an exported DSO or PTE path specified")
if (gguf_path and (dso_path or pte_path)):
print("Warning: GGUF path ignored because an exported DSO or PTE path specified")

print("Loading model ...")
t0 = time.time()
model_ = _load_model(
checkpoint_path,
checkpoint_dir,
params_path,
params_table,
gguf_path,
device,
precision,
use_tp
builder_args.checkpoint_path,
builder_args.checkpoint_dir,
builder_args.params_path,
builder_args.params_table,
builder_args.gguf_path,
builder_args.device,
builder_args.precision,
builder_args.use_tp
)
device_sync(device=device) # MKG
device_sync(device=builder_args.device)
print(f"Time to load model: {time.time() - t0:.02f} seconds")

if dso_path:
if builder_args.dso_path:
# make sure user did not try to set dtype
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export."
@@ -236,33 +231,36 @@ def _initialize_model(
# attributes will NOT be seen on by AOTI-compiled forward
# function, e.g. calling model.setup_cache will NOT touch
# AOTI compiled and maintained model buffers such as kv_cache.
model.forward = torch._export.aot_load(str(dso_path.absolute()), device)
model.forward = torch._export.aot_load(str(builder_args.dso_path.absolute()), builder_args.device)
except:
raise RuntimeError(f"Failed to load AOTI compiled {dso_path}")
elif pte_path:
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
elif builder_args.pte_path:
# make sure user did not try to set dtype
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export."
try:
from model_et import PTEModel
model = PTEModel(model_.config, pte_path)
model = PTEModel(model_.config, builder_args.pte_path)
except Exception as e:
raise RuntimeError(f"Failed to load ET compiled {pte_path}")
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
else:
model = model_

if quantize:
t0q = time.time()
quantize_model(model, quantize)
device_sync(device=device) # MKG
device_sync(device=builder_args.device)
print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")

if setup_caches:
if builder_args.setup_caches:
max_seq_length = 350
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
with torch.device(builder_args.device):
model.setup_caches(
max_batch_size=1,
max_seq_length=max_seq_length
)

model.to(dtype=precision)
model.to(dtype=builder_args.precision)

return model

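Taken together, the builder.py changes collapse every call site down to two dataclasses plus a quantize spec. Below is a minimal sketch of the resulting calling convention, assuming an argparse namespace carrying the fields referenced above (checkpoint_path, dtype, quantize, ...); the wrapper name load_model_and_tokenizer is illustrative, not part of the diff:

from builder import BuilderArgs, TokenizerArgs, _initialize_model, _initialize_tokenizer

def load_model_and_tokenizer(args):
    # BuilderArgs.from_args validates the path combination in __post_init__
    builder_args = BuilderArgs.from_args(args)
    # TokenizerArgs.from_args falls back to tokenizer.model next to the checkpoint
    tokenizer_args = TokenizerArgs.from_args(args)
    model = _initialize_model(builder_args, args.quantize)
    tokenizer = _initialize_tokenizer(tokenizer_args)
    return model, tokenizer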
39 changes: 13 additions & 26 deletions eval.py
@@ -212,6 +212,9 @@ def eval_main(args) -> None:

"""

builder_args = BuilderArgs.from_args(args)
tokenizer_args = TokenizerArgs.from_args(args)

checkpoint_path = args.checkpoint_path
checkpoint_dir = args.checkpoint_dir
params_path = args.params_path
@@ -228,34 +231,18 @@ def eval_main(args) -> None:
max_seq_length = args.max_seq_length
use_tiktoken = args.tiktoken

if not tokenizer_path:
assert checkpoint_path, "either a tokenizer or a checkpoint path must be specified"
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path

print(f"Using device={device}")
precision = name_to_dtype(model_dtype)
set_precision(precision)

set_precision(builder_args.precision)

tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
builder_args.setup_caches = False
model = _initialize_model(
checkpoint_path,
checkpoint_dir,
params_path,
params_table,
gguf_path,
dso_path,
pte_path,
builder_args,
quantize,
device,
precision,
setup_caches=False,
use_tp=False
)

tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))

if compile:
assert not (dso_path or pte_path), "cannot compile exported model"
assert not (builder_args.dso_path or builder_args.pte_path), "cannot compile exported model"
global model_forward
model_forward = torch.compile(model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True)
torch._inductor.config.coordinate_descent_tuning = True
@@ -270,13 +257,13 @@ def eval_main(args) -> None:
)
print(f"Time to run eval: {time.time() - t1:.02f} seconds.")
if dso_path:
print(f"For model {dso_path}")
print(f"For model {builder_args.dso_path}")
elif pte_path:
print(f"For model {pte_path}")
print(f"For model {builder_args.pte_path}")
elif checkpoint_path:
print(f"For model {checkpoint_path}")
print(f"For model {builder_args.checkpoint_path}")
elif checkpoint_dir:
print(f"For model {checkpoint_dir}")
print(f"For model {builder_args.checkpoint_dir}")
else:
raise RuntimeError("Well That's Fine. How did we get here")

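The eval.py changes rely on BuilderArgs being a plain, mutable dataclass: __post_init__ validation runs once at construction time, so a call site can flip individual fields afterwards without re-triggering the path checks. A condensed sketch of the pattern used above:

builder_args = BuilderArgs.from_args(args)
builder_args.setup_caches = False  # eval never needs kv-caches pre-allocated
model = _initialize_model(builder_args, args.quantize)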
39 changes: 13 additions & 26 deletions export.py
@@ -19,13 +19,13 @@
executorch_export_available = True
from export_et import export_model as export_model_et
except Exception as e:
print("ET EXPORT EXCEPTION: ", e) # TODO: remove
executorch_exception = f"ET EXPORT EXCEPTION: {e}"
executorch_export_available = False

from export_aoti import export_model as export_model_aoti

from model import Transformer
from builder import _initialize_model
from builder import _initialize_model, BuilderArgs, TokenizerArgs
from generate import decode_one_token
from quantize import quantize_model, name_to_dtype
from torch._export import capture_pre_autograd_graph
@@ -44,35 +44,21 @@ def device_sync(device):


def main(args):
checkpoint_path = args.checkpoint_path
device = args.device
builder_args = BuilderArgs.from_args(args)
tokenizer_args = TokenizerArgs.from_args(args)
quantize = args.quantize

assert checkpoint_path.is_file(), checkpoint_path
print(f"Using device={builder_args.device}")
set_precision(builder_args.precision)

print(f"Using device={device}")
precision = name_to_dtype(args.dtype) # torch.float # bfloat16
set_precision(precision)

builder_args.dso_path = None
builder_args.pte_path = None
builder_args.setup_caches = True
model = _initialize_model(
args.checkpoint_path,
args.checkpoint_dir,
args.params_path,
args.params_table,
args.gguf_path,
None, # dso_path - cannot re-export exported model
None, # pte_path - cannot re-export exported model
builder_args,
quantize,
device,
precision,
setup_caches=True,
use_tp=False
)

# dtype:
# if args.dtype:
# model.to(dtype=name_to_dtype(args.dtype))

output_pte_path = args.output_pte_path
output_dso_path = args.output_dso_path

@@ -82,13 +68,14 @@ def main(args):
print(f">{output_pte_path}<")
if executorch_export_available:
print(f"Exporting model using Executorch to {output_pte_path}")
export_model_et(model, device, args.output_pte_path, args)
export_model_et(model, builder_args.device, args.output_pte_path, args)
else:
print(f"Export with executorch requested but Executorch could not be loaded")
print(executorch_exception)
if output_dso_path:
output_dso_path = str(os.path.abspath(output_dso_path))
print(f"Exporting model using AOT Inductor to {output_dso_path}")
export_model_aoti(model, device, output_dso_path, args)
export_model_aoti(model, builder_args.device, output_dso_path, args)


def cli():
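export.py uses the same mutation pattern with the opposite toggles: the DSO/PTE fields are cleared because an already-exported artifact cannot be re-exported, and setup_caches is enabled ahead of export. A hedged alternative sketch, not taken from the diff: since BuilderArgs is a dataclass, dataclasses.replace would yield an export-time copy and re-run __post_init__, so clearing the exported-artifact paths is re-validated against the remaining checkpoint/GGUF source:

import dataclasses

# hypothetical: derive an export-time copy instead of mutating in place
export_args = dataclasses.replace(
    builder_args,
    dso_path=None,        # an exported model cannot be re-exported
    pte_path=None,
    setup_caches=True,    # export needs caches set up ahead of time
)
model = _initialize_model(export_args, args.quantize)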