BuilderArgs and TokenizerArgs #191

Merged
merged 3 commits on Apr 15, 2024

builder.py (164 changes: 80 additions & 84 deletions)
@@ -38,6 +38,27 @@ class BuilderArgs:
setup_caches: bool = False
use_tp: bool = False

def __post_init__(self):
if not (
(self.checkpoint_path and self.checkpoint_path.is_file()) or
(self.checkpoint_dir and self.checkpoint_path.is_dir()) or
(self.gguf_path and self.gguf_path.is_file()) or
(self.dso_path and Path(self.dso_path).is_file()) or
(self.pte_path and Path(self.pte_path).is_file())
):
raise RuntimeError("need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path")

if (self.dso_path and self.pte_path):
raise RuntimeError("specify either DSO path or PTE path, but not both")

if (self.checkpoint_path and (self.dso_path or self.pte_path)):
print("Warning: checkpoint path ignored because an exported DSO or PTE path specified")
if (self.checkpoint_dir and (self.dso_path or self.pte_path)):
print("Warning: checkpoint dir ignored because an exported DSO or PTE path specified")
if (self.gguf_path and (self.dso_path or self.pte_path)):
print("Warning: GGUF path ignored because an exported DSO or PTE path specified")


@classmethod
def from_args(cls, args): # -> BuilderArgs:
return cls(
@@ -49,10 +70,22 @@ def from_args(cls, args): # -> BuilderArgs:
dso_path = args.dso_path,
pte_path = args.pte_path,
device = args.device,
precision = name_to_dtype(args.precision),
precision = name_to_dtype(args.dtype),
setup_caches = (args.output_dso_path or args.output_pte_path),
use_tp = False,
)

@classmethod
def from_speculative_args(cls, args): # -> BuilderArgs:
speculative_builder_args = BuilderArgs.from_args(args)
# let's limit multi-checkpoint to checker
speculative_builder_args.checkpoint_dir = None
speculative_builder_args.checkpoint_path = args.draft_checkpoint_path
speculative_builder_args.gguf_path = None
speculative_builder_args.dso_path = None
speculative_builder_args.pte_path = None
return speculative_builder_args


@dataclass
class TokenizerArgs:
@@ -62,23 +95,23 @@ class TokenizerArgs:

@classmethod
def from_args(cls, args): # -> TokenizerArgs:
is_Sentencepiece = True
is_SentencePiece = True
is_TikToken = False

if args.tokenizer_path:
tokenizer_path = args.tokenizer_path
elif argscheckpoint_path:
elif args.checkpoint_path:
tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
elif checkpoint_dir:
elif args.checkpoint_dir:
tokenizer_path = args.checkpoint_dir / "tokenizer.model"
else:
raise RuntimeError(f"cannot find tokenizer model")

if not tokenizer_path.is_file():
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")

if args.toktoken:
is_Sentencepiece = False
if args.tiktoken:
is_SentencePiece = False
is_TikToken = True

return cls(
@@ -87,13 +120,13 @@ def from_args(cls, args): # -> TokenizerArgs:
is_TikToken=is_TikToken
)

def _initialize_tokenizer(config: TokenizerArgs):
if is_SentencePiece:
return SentencePieceProcessor(model_file=str(tokenizer_path))
elif is_TikToken:
raise RUntimeError("TikToken not implemented yet!")
def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
if tokenizer_args.is_SentencePiece:
return SentencePieceProcessor(model_file=str(tokenizer_args.tokenizer_path))
elif tokenizer_args.is_TikToken:
raise RuntimeError("TikToken not implemented yet!")
else:
raise RUntimeError("must specify a valid tokenizer in TokenizerArgs")
raise RuntimeError("must specify a valid tokenizer in TokenizerArgs")


def device_sync(device):
@@ -115,38 +148,31 @@ def device_sync(device):
sys.path.append(str(wd))

def _load_model(
checkpoint_path,
checkpoint_dir,
params_path,
params_table,
gguf_path,
device,
precision,
use_tp # =False
builder_args
):
use_cuda = "cuda" in device
use_cuda = "cuda" in builder_args.device
with torch.device("meta"):
if params_path:
model = Transformer.from_params(params_path)
elif params_table:
model = Transformer.from_table(params_path)
elif gguf_path:
model = Transformer.from_gguf(gguf_path)
if builder_args.params_path:
model = Transformer.from_params(builder_args.params_path)
elif builder_args.params_table:
model = Transformer.from_table(builder_args.params_path)
elif builder_args.gguf_path:
model = Transformer.from_gguf(builder_args.gguf_path)
else:
model = Transformer.from_name(checkpoint_path.parent.name)
model = Transformer.from_name(builder_args.checkpoint_path.parent.name)

# checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
# checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
cps = []
if checkpoint_dir is not None:
if builder_args.checkpoint_dir is not None:
# Load multiple checkpoint; ignore the single path.
checkpoint_path = None
builder_args.checkpoint_path = None
for i in range(4):
cp_name = f"consolidated.{i}.pth"
print(f"Loading {cp_name}")
cps.append(
torch.load(
os.path.join(checkpoint_dir, cp_name),
map_location=device,
os.path.join(builder_args.checkpoint_dir, cp_name),
map_location=builder_args.device,
mmap=True,
)
)
@@ -162,69 +188,36 @@ def _load_model(
else:
checkpoint[key] = cps[0][key]
else:
checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True, weights_only=True)
checkpoint = torch.load(builder_args.checkpoint_path, map_location=builder_args.device, mmap=True, weights_only=True)

if "model" in checkpoint and "stories" in str(checkpoint_path):
if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
checkpoint = checkpoint["model"]

model.load_state_dict(checkpoint, assign=True)

if use_tp:
if builder_args.use_tp:
from tp import apply_tp

print("Applying tensor parallel to model ...")
apply_tp(model)

model = model.to(device=device, dtype=precision)
model = model.to(device=builder_args.device, dtype=builder_args.precision)
return model.eval()


def _initialize_model(
checkpoint_path,
checkpoint_dir,
params_path,
params_table,
gguf_path,
dso_path,
pte_path,
builder_args,
quantize,
device,
precision,
setup_caches,
use_tp # =False
):
assert (
(checkpoint_path and checkpoint_path.is_file()) or
(checkpoint_dir and checkpoint_path.is_dir()) or
(gguf_path and gguf_path.is_file()) or
(dso_path and Path(dso_path).is_file()) or
(pte_path and Path(pte_path).is_file())
), "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
assert not (dso_path and pte_path), "specify either DSO path or PTE path, but not both"

if (checkpoint_path and (dso_path or pte_path)):
print("Warning: checkpoint path ignored because an exported DSO or PTE path specified")
if (checkpoint_dir and (dso_path or pte_path)):
print("Warning: checkpoint dir ignored because an exported DSO or PTE path specified")
if (gguf_path and (dso_path or pte_path)):
print("Warning: GGUF path ignored because an exported DSO or PTE path specified")

print("Loading model ...")
t0 = time.time()
model_ = _load_model(
checkpoint_path,
checkpoint_dir,
params_path,
params_table,
gguf_path,
device,
precision,
use_tp
builder_args
)
device_sync(device=device) # MKG
device_sync(device=builder_args.device)
print(f"Time to load model: {time.time() - t0:.02f} seconds")

if dso_path:
if builder_args.dso_path:
# make sure user did not try to set dtype
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export."
@@ -236,33 +229,36 @@ def _initialize_model(
# attributes will NOT be seen on by AOTI-compiled forward
# function, e.g. calling model.setup_cache will NOT touch
# AOTI compiled and maintained model buffers such as kv_cache.
model.forward = torch._export.aot_load(str(dso_path.absolute()), device)
model.forward = torch._export.aot_load(str(builder_args.dso_path.absolute()), builder_args.device)
except:
raise RuntimeError(f"Failed to load AOTI compiled {dso_path}")
elif pte_path:
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
elif builder_args.pte_path:
# make sure user did not try to set dtype
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export."
try:
from model_et import PTEModel
model = PTEModel(model_.config, pte_path)
model = PTEModel(model_.config, builder_args.pte_path)
except Exception as e:
raise RuntimeError(f"Failed to load ET compiled {pte_path}")
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
else:
model = model_

if quantize:
t0q = time.time()
quantize_model(model, quantize)
device_sync(device=device) # MKG
device_sync(device=builder_args.device)
print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")

if setup_caches:
if builder_args.setup_caches:
max_seq_length = 350
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
with torch.device(builder_args.device):
model.setup_caches(
max_batch_size=1,
max_seq_length=max_seq_length
)

model.to(dtype=precision)
model.to(dtype=builder_args.precision)

return model

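Taken together, the builder.py changes collapse the long positional parameter lists of _load_model and _initialize_model into a single BuilderArgs value and move the path validation into BuilderArgs.__post_init__. Below is a minimal sketch of how a caller might drive the new entry points; the argparse.Namespace fields and the example paths are assumptions inferred from what from_args() reads, not something this diff confirms.

# Sketch only: field names are inferred from from_args()/_load_model(); the paths are hypothetical.
from argparse import Namespace
from pathlib import Path

from builder import BuilderArgs, TokenizerArgs, _initialize_model, _initialize_tokenizer

args = Namespace(
    checkpoint_path=Path("checkpoints/stories15M/stories15M.pt"),  # must exist for __post_init__ to pass
    checkpoint_dir=None,
    params_path=None,
    params_table=None,
    gguf_path=None,
    dso_path=None,
    pte_path=None,
    tokenizer_path=None,   # TokenizerArgs falls back to tokenizer.model next to the checkpoint
    device="cpu",
    dtype="float32",       # BuilderArgs.from_args() now reads args.dtype rather than args.precision
    tiktoken=False,
    output_dso_path=None,
    output_pte_path=None,
    quantize=None,
)

builder_args = BuilderArgs.from_args(args)      # __post_init__ validates the path combination
tokenizer_args = TokenizerArgs.from_args(args)  # resolves and checks tokenizer.model

tokenizer = _initialize_tokenizer(tokenizer_args)
model = _initialize_model(builder_args, args.quantize)

The eval.py changes below follow the same pattern: build the two dataclasses once from args, then pass them around instead of a dozen loose variables.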
eval.py (41 changes: 14 additions & 27 deletions)
@@ -31,7 +31,7 @@
except:
lm_eval_available = False

from builder import _initialize_model
from builder import _initialize_model, _initialize_tokenizer, BuilderArgs, TokenizerArgs
from generate import encode_tokens, model_forward

if lm_eval_available:
@@ -212,6 +212,9 @@ def eval_main(args) -> None:

"""

builder_args = BuilderArgs.from_args(args)
tokenizer_args = TokenizerArgs.from_args(args)

checkpoint_path = args.checkpoint_path
checkpoint_dir = args.checkpoint_dir
params_path = args.params_path
@@ -228,34 +231,18 @@ def eval_main(args) -> None:
max_seq_length = args.max_seq_length
use_tiktoken = args.tiktoken

if not tokenizer_path:
assert checkpoint_path, "either a tokenizer or a checkpoint path must be specified"
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path

print(f"Using device={device}")
precision = name_to_dtype(model_dtype)
set_precision(precision)

set_precision(buildeer_args.precision)

tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
builder_args.setup_caches = False
model = _initialize_model(
checkpoint_path,
checkpoint_dir,
params_path,
params_table,
gguf_path,
dso_path,
pte_path,
buildeer_args,
quantize,
device,
precision,
setup_caches=False,
use_tp=False
)

tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))

if compile:
assert not (dso_path or pte_path), "cannot compile exported model"
assert not (builder_args.dso_path or builder_args.pte_path), "cannot compile exported model"
global model_forward
model_forward = torch.compile(model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True)
torch._inductor.config.coordinate_descent_tuning = True
@@ -270,13 +257,13 @@ def eval_main(args) -> None:
)
print(f"Time to run eval: {time.time() - t1:.02f} seconds.")
if dso_path:
print(f"For model {dso_path}")
print(f"For model {builder_args.dso_path}")
elif pte_path:
print(f"For model {pte_path}")
print(f"For model {builder_args.pte_path}")
elif checkpoint_path:
print(f"For model {checkpoint_path}")
print(f"For model {builder_args.checkpoint_path}")
elif checkpoint_dir:
print(f"For model {checkpoint_dir}")
print(f"For model {builder_args.checkpoint_dir}")
else:
raise RuntimeError("Well That's Fine. How did we get here")

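One behavioral consequence of the refactor: because the path checks now live in BuilderArgs.__post_init__ rather than inside _initialize_model, an invalid combination of artifacts fails as soon as the dataclass is constructed, before any model loading or eval work starts. A small hedged illustration follows; it assumes the path fields checked in __post_init__ default to None, which the visible hunk implies but does not actually show.

# Sketch only: a bad artifact combination is now rejected at construction time.
from pathlib import Path

from builder import BuilderArgs

try:
    # Hypothetical artifacts; assumes the remaining BuilderArgs fields default to None.
    # Whether or not these files exist, __post_init__ raises a RuntimeError here:
    # either the paths fail validation, or DSO and PTE were supplied together.
    BuilderArgs(dso_path=Path("model.so"), pte_path=Path("model.pte"))
except RuntimeError as err:
    print(f"rejected at construction time: {err}")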