
Commit 4c23893

mikekgfb authored and malfet committed

Use BuilderArgs and TokenizerArgs (#191)

* Use BuilderArgs and TokenizerArgs
* Import BuilderArgs and related into eval
* Use builder_args for the draft model
1 parent 3f747e4 commit 4c23893
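
In practical terms, the commit collapses the long positional parameter lists of _load_model and _initialize_model into a single BuilderArgs value built from the parsed CLI arguments. A hedged before/after sketch of a call site, reconstructed from the removed and added lines in the diff below (not copied verbatim from the repository):

    # Before: every path and device/precision option threaded through by hand.
    model = _initialize_model(
        checkpoint_path, checkpoint_dir, params_path, params_table, gguf_path,
        dso_path, pte_path, quantize, device, precision,
        setup_caches=False, use_tp=False,
    )

    # After: one dataclass carries the configuration; __post_init__ validates it.
    builder_args = BuilderArgs.from_args(args)
    model = _initialize_model(builder_args, quantize)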

4 files changed: +172 −200 lines changed


builder.py

Lines changed: 80 additions & 84 deletions
@@ -38,6 +38,27 @@ class BuilderArgs:
     setup_caches: bool = False
     use_tp: bool = False
 
+    def __post_init__(self):
+        if not (
+            (self.checkpoint_path and self.checkpoint_path.is_file()) or
+            (self.checkpoint_dir and self.checkpoint_path.is_dir()) or
+            (self.gguf_path and self.gguf_path.is_file()) or
+            (self.dso_path and Path(self.dso_path).is_file()) or
+            (self.pte_path and Path(self.pte_path).is_file())
+        ):
+            raise RuntimeError("need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path")
+
+        if (self.dso_path and self.pte_path):
+            raise RuntimeError("specify either DSO path or PTE path, but not both")
+
+        if (self.checkpoint_path and (self.dso_path or self.pte_path)):
+            print("Warning: checkpoint path ignored because an exported DSO or PTE path specified")
+        if (self.checkpoint_dir and (self.dso_path or self.pte_path)):
+            print("Warning: checkpoint dir ignored because an exported DSO or PTE path specified")
+        if (self.gguf_path and (self.dso_path or self.pte_path)):
+            print("Warning: GGUF path ignored because an exported DSO or PTE path specified")
+
+
     @classmethod
     def from_args(cls, args): # -> BuilderArgs:
         return cls(
@@ -49,10 +70,22 @@ def from_args(cls, args): # -> BuilderArgs:
             dso_path = args.dso_path,
             pte_path = args.pte_path,
             device = args.device,
-            precision = name_to_dtype(args.precision),
+            precision = name_to_dtype(args.dtype),
             setup_caches = (args.output_dso_path or args.output_pte_path),
             use_tp = False,
         )
+
+    @classmethod
+    def from_speculative_args(cls, args): # -> BuilderArgs:
+        speculative_builder_args = BuilderArgs.from_args(args)
+        # let's limit multi-checkpoint to checker
+        speculative_builder_args.checkpoint_dir = None
+        speculative_builder_args.checkpoint_path = args.draft_checkpoint_path
+        speculative_builder_args.gguf_path = None
+        speculative_builder_args.dso_path = None
+        speculative_builder_args.pte_path = None
+        return speculative_builder_args
+
 
 @dataclass
 class TokenizerArgs:
@@ -62,23 +95,23 @@ class TokenizerArgs:
 
     @classmethod
     def from_args(cls, args): # -> TokenizerArgs:
-        is_Sentencepiece = True
+        is_SentencePiece = True
         is_TikToken = False
 
         if args.tokenizer_path:
             tokenizer_path = args.tokenizer_path
-        elif argscheckpoint_path:
+        elif args.checkpoint_path:
             tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
-        elif checkpoint_dir:
+        elif args.checkpoint_dir:
             tokenizer_path = args.checkpoint_dir / "tokenizer.model"
         else:
             raise RuntimeError(f"cannot find tokenizer model")
 
         if not tokenizer_path.is_file():
             raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")
 
-        if args.toktoken:
-            is_Sentencepiece = False
+        if args.tiktoken:
+            is_SentencePiece = False
             is_TikToken = True
 
         return cls(
@@ -87,13 +120,13 @@ def from_args(cls, args): # -> TokenizerArgs:
             is_TikToken=is_TikToken
         )
 
-def _initialize_tokenizer(config: TokenizerArgs):
-    if is_SentencePiece:
-        return SentencePieceProcessor(model_file=str(tokenizer_path))
-    elif is_TikToken:
-        raise RUntimeError("TikToken not implemented yet!")
+def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
+    if tokenizer_args.is_SentencePiece:
+        return SentencePieceProcessor(model_file=str(tokenizer_args.tokenizer_path))
+    elif tokenizer_args.is_TikToken:
+        raise RuntimeError("TikToken not implemented yet!")
     else:
-        raise RUntimeError("must specify a valid tokenizer in TokenizerArgs")
+        raise RuntimeError("must specify a valid tokenizer in TokenizerArgs")
 
 
 def device_sync(device):
@@ -115,38 +148,31 @@ def device_sync(device):
 sys.path.append(str(wd))
 
 def _load_model(
-    checkpoint_path,
-    checkpoint_dir,
-    params_path,
-    params_table,
-    gguf_path,
-    device,
-    precision,
-    use_tp # =False
+    builder_args
 ):
-    use_cuda = "cuda" in device
+    use_cuda = "cuda" in builder_args.device
     with torch.device("meta"):
-        if params_path:
-            model = Transformer.from_params(params_path)
-        elif params_table:
-            model = Transformer.from_table(params_path)
-        elif gguf_path:
-            model = Transformer.from_gguf(gguf_path)
+        if builder_args.params_path:
+            model = Transformer.from_params(builder_args.params_path)
+        elif builder_args.params_table:
+            model = Transformer.from_table(builder_args.params_path)
+        elif builder_args.gguf_path:
+            model = Transformer.from_gguf(builder_args.gguf_path)
         else:
-            model = Transformer.from_name(checkpoint_path.parent.name)
+            model = Transformer.from_name(builder_args.checkpoint_path.parent.name)
 
-    # checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    # checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
     cps = []
-    if checkpoint_dir is not None:
+    if builder_args.checkpoint_dir is not None:
         # Load multiple checkpoint; ignore the single path.
-        checkpoint_path = None
+        builder_args.checkpoint_path = None
         for i in range(4):
             cp_name = f"consolidated.{i}.pth"
             print(f"Loading {cp_name}")
             cps.append(
                 torch.load(
-                    os.path.join(checkpoint_dir, cp_name),
-                    map_location=device,
+                    os.path.join(builder_args.checkpoint_dir, cp_name),
+                    map_location=builder_args.device,
                     mmap=True,
                 )
             )
@@ -162,69 +188,36 @@ def _load_model(
             else:
                 checkpoint[key] = cps[0][key]
     else:
-        checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True, weights_only=True)
+        checkpoint = torch.load(builder_args.checkpoint_path, map_location=builder_args.device, mmap=True, weights_only=True)
 
-    if "model" in checkpoint and "stories" in str(checkpoint_path):
+    if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
         checkpoint = checkpoint["model"]
 
     model.load_state_dict(checkpoint, assign=True)
 
-    if use_tp:
+    if builder_args.use_tp:
         from tp import apply_tp
 
         print("Applying tensor parallel to model ...")
         apply_tp(model)
 
-    model = model.to(device=device, dtype=precision)
+    model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
 
 def _initialize_model(
-    checkpoint_path,
-    checkpoint_dir,
-    params_path,
-    params_table,
-    gguf_path,
-    dso_path,
-    pte_path,
+    builder_args,
     quantize,
-    device,
-    precision,
-    setup_caches,
-    use_tp # =False
 ):
-    assert (
-        (checkpoint_path and checkpoint_path.is_file()) or
-        (checkpoint_dir and checkpoint_path.is_dir()) or
-        (gguf_path and gguf_path.is_file()) or
-        (dso_path and Path(dso_path).is_file()) or
-        (pte_path and Path(pte_path).is_file())
-    ), "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
-    assert not (dso_path and pte_path), "specify either DSO path or PTE path, but not both"
-
-    if (checkpoint_path and (dso_path or pte_path)):
-        print("Warning: checkpoint path ignored because an exported DSO or PTE path specified")
-    if (checkpoint_dir and (dso_path or pte_path)):
-        print("Warning: checkpoint dir ignored because an exported DSO or PTE path specified")
-    if (gguf_path and (dso_path or pte_path)):
-        print("Warning: GGUF path ignored because an exported DSO or PTE path specified")
-
     print("Loading model ...")
     t0 = time.time()
     model_ = _load_model(
-        checkpoint_path,
-        checkpoint_dir,
-        params_path,
-        params_table,
-        gguf_path,
-        device,
-        precision,
-        use_tp
+        builder_args
    )
-    device_sync(device=device) # MKG
+    device_sync(device=builder_args.device)
    print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
-    if dso_path:
+    if builder_args.dso_path:
        # make sure user did not try to set dtype
        # assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
        assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export."
@@ -236,33 +229,36 @@ def _initialize_model(
             # attributes will NOT be seen on by AOTI-compiled forward
             # function, e.g. calling model.setup_cache will NOT touch
             # AOTI compiled and maintained model buffers such as kv_cache.
-            model.forward = torch._export.aot_load(str(dso_path.absolute()), device)
+            model.forward = torch._export.aot_load(str(builder_args.dso_path.absolute()), builder_args.device)
         except:
-            raise RuntimeError(f"Failed to load AOTI compiled {dso_path}")
-    elif pte_path:
+            raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
+    elif builder_args.pte_path:
         # make sure user did not try to set dtype
         # assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
         assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export."
         try:
             from model_et import PTEModel
-            model = PTEModel(model_.config, pte_path)
+            model = PTEModel(model_.config, builder_args.pte_path)
         except Exception as e:
-            raise RuntimeError(f"Failed to load ET compiled {pte_path}")
+            raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
     else:
         model = model_
 
     if quantize:
         t0q = time.time()
         quantize_model(model, quantize)
-        device_sync(device=device) # MKG
+        device_sync(device=builder_args.device)
         print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")
 
-    if setup_caches:
+    if builder_args.setup_caches:
         max_seq_length = 350
-        with torch.device(device):
-            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+        with torch.device(builder_args.device):
+            model.setup_caches(
+                max_batch_size=1,
+                max_seq_length=max_seq_length
+            )
 
-    model.to(dtype=precision)
+    model.to(dtype=builder_args.precision)
 
     return model
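
Because the path checks moved from asserts in _initialize_model into BuilderArgs.__post_init__, invalid configurations now fail as soon as the dataclass is constructed rather than at model-load time. A minimal sketch, assuming the path fields default to None (the None-tolerant checks in __post_init__ suggest they do):

    from builder import BuilderArgs

    # No checkpoint path, checkpoint dir, GGUF, DSO, or PTE path at all:
    # construction itself raises, before any model loading is attempted.
    try:
        BuilderArgs()
    except RuntimeError as e:
        print(e)  # "need to specified a valid checkpoint path, ..."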

eval.py

Lines changed: 14 additions & 27 deletions
@@ -31,7 +31,7 @@
 except:
     lm_eval_available = False
 
-from builder import _initialize_model
+from builder import _initialize_model, _initialize_tokenizer, BuilderArgs, TokenizerArgs
 from generate import encode_tokens, model_forward
 
 if lm_eval_available:
@@ -212,6 +212,9 @@ def eval_main(args) -> None:
 
     """
 
+    builder_args = BuilderArgs.from_args(args)
+    tokenizer_args = TokenizerArgs.from_args(args)
+
     checkpoint_path = args.checkpoint_path
     checkpoint_dir = args.checkpoint_dir
     params_path = args.params_path
@@ -228,34 +231,18 @@ def eval_main(args) -> None:
     max_seq_length = args.max_seq_length
     use_tiktoken = args.tiktoken
 
-    if not tokenizer_path:
-        assert checkpoint_path, "either a tokenizer or a checkpoint path must be specified"
-        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-    assert tokenizer_path.is_file(), tokenizer_path
-
     print(f"Using device={device}")
-    precision = name_to_dtype(model_dtype)
-    set_precision(precision)
-
+    set_precision(buildeer_args.precision)
+
+    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+    builder_args.setup_caches = False
     model = _initialize_model(
-        checkpoint_path,
-        checkpoint_dir,
-        params_path,
-        params_table,
-        gguf_path,
-        dso_path,
-        pte_path,
+        buildeer_args,
         quantize,
-        device,
-        precision,
-        setup_caches=False,
-        use_tp=False
     )
 
-    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
-
     if compile:
-        assert not (dso_path or pte_path), "cannot compile exported model"
+        assert not (builder_args.dso_path or builder_args.pte_path), "cannot compile exported model"
         global model_forward
         model_forward = torch.compile(model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True)
         torch._inductor.config.coordinate_descent_tuning = True
@@ -270,13 +257,13 @@ def eval_main(args) -> None:
     )
     print(f"Time to run eval: {time.time() - t1:.02f} seconds.")
     if dso_path:
-        print(f"For model {dso_path}")
+        print(f"For model {builder_args.dso_path}")
     elif pte_path:
-        print(f"For model {pte_path}")
+        print(f"For model {builder_args.pte_path}")
     elif checkpoint_path:
-        print(f"For model {checkpoint_path}")
+        print(f"For model {builder_args.checkpoint_path}")
     elif checkpoint_dir:
-        print(f"For model {checkpoint_dir}")
+        print(f"For model {builder_args.checkpoint_dir}")
     else:
         raise RuntimeError("Well That's Fine. How did we get here")
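
The last bullet of the commit message refers to the new BuilderArgs.from_speculative_args classmethod in builder.py, which reuses from_args, points checkpoint_path at args.draft_checkpoint_path, and clears the checkpoint-dir, GGUF, DSO, and PTE fields so the draft model always comes from a single checkpoint. A hedged sketch of how a speculative-decoding caller might use it (quantize stands in for whatever quantization config the caller parsed; this call site is not part of this diff):

    # Main model from the regular arguments.
    builder_args = BuilderArgs.from_args(args)
    model = _initialize_model(builder_args, quantize)

    # Draft model for speculative decoding, limited to a single checkpoint.
    if args.draft_checkpoint_path:
        speculative_builder_args = BuilderArgs.from_speculative_args(args)
        draft_model = _initialize_model(speculative_builder_args, quantize)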
