@@ -38,6 +38,27 @@ class BuilderArgs:
     setup_caches: bool = False
     use_tp: bool = False

+    def __post_init__(self):
+        if not (
+            (self.checkpoint_path and self.checkpoint_path.is_file()) or
+            (self.checkpoint_dir and self.checkpoint_dir.is_dir()) or
+            (self.gguf_path and self.gguf_path.is_file()) or
+            (self.dso_path and Path(self.dso_path).is_file()) or
+            (self.pte_path and Path(self.pte_path).is_file())
+        ):
+            raise RuntimeError("need to specify a valid checkpoint path, checkpoint dir, GGUF path, DSO path, or PTE path")
+
+        if self.dso_path and self.pte_path:
+            raise RuntimeError("specify either DSO path or PTE path, but not both")
+
+        if self.checkpoint_path and (self.dso_path or self.pte_path):
+            print("Warning: checkpoint path ignored because an exported DSO or PTE path was specified")
+        if self.checkpoint_dir and (self.dso_path or self.pte_path):
+            print("Warning: checkpoint dir ignored because an exported DSO or PTE path was specified")
+        if self.gguf_path and (self.dso_path or self.pte_path):
+            print("Warning: GGUF path ignored because an exported DSO or PTE path was specified")
+
+
     @classmethod
     def from_args(cls, args):  # -> BuilderArgs:
         return cls(
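The `__post_init__` hook above turns what used to be call-site asserts into validation that runs the moment a `BuilderArgs` is constructed. A minimal sketch of the pattern, using a stripped-down stand-in rather than the real `BuilderArgs`:

```python
# Illustrative only: a two-field stand-in for BuilderArgs showing how
# dataclasses invoke __post_init__ automatically after field assignment.
from dataclasses import dataclass
from typing import Optional

@dataclass
class _Args:
    dso_path: Optional[str] = None
    pte_path: Optional[str] = None

    def __post_init__(self):
        # mirrors the mutual-exclusion check in the hunk above
        if self.dso_path and self.pte_path:
            raise RuntimeError("specify either DSO path or PTE path, but not both")

_Args(dso_path="model.so")  # constructs fine
try:
    _Args(dso_path="model.so", pte_path="model.pte")
except RuntimeError as e:
    print(e)  # validation fired at construction time
```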
@@ -49,10 +70,22 @@ def from_args(cls, args):  # -> BuilderArgs:
             dso_path=args.dso_path,
             pte_path=args.pte_path,
             device=args.device,
-            precision=name_to_dtype(args.precision),
+            precision=name_to_dtype(args.dtype),
             setup_caches=(args.output_dso_path or args.output_pte_path),
             use_tp=False,
         )
+
+    @classmethod
+    def from_speculative_args(cls, args):  # -> BuilderArgs:
+        speculative_builder_args = BuilderArgs.from_args(args)
+        # limit multi-checkpoint (sharded) loading to the checker model
+        speculative_builder_args.checkpoint_dir = None
+        speculative_builder_args.checkpoint_path = args.draft_checkpoint_path
+        speculative_builder_args.gguf_path = None
+        speculative_builder_args.dso_path = None
+        speculative_builder_args.pte_path = None
+        return speculative_builder_args
+

 @dataclass
 class TokenizerArgs:
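`from_speculative_args` builds a second `BuilderArgs` for the draft model in speculative decoding: it clones the main arguments, then forces the draft to load from a single eager checkpoint (`draft_checkpoint_path`), with sharded, GGUF, DSO, and PTE loading disabled. A hypothetical call site, assuming an argparse namespace that carries the fields both factories read:

```python
# Hypothetical usage; `args` is an argparse.Namespace with checkpoint_path,
# dtype, draft_checkpoint_path, etc. as read by the factories above.
builder_args = BuilderArgs.from_args(args)

speculative_builder_args = None
if args.draft_checkpoint_path:
    # draft model: same device/precision as the checker, but a single
    # eager checkpoint and no exported artifacts
    speculative_builder_args = BuilderArgs.from_speculative_args(args)
```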
@@ -62,23 +95,23 @@ class TokenizerArgs:

     @classmethod
     def from_args(cls, args):  # -> TokenizerArgs:
-        is_Sentencepiece = True
+        is_SentencePiece = True
         is_TikToken = False

         if args.tokenizer_path:
             tokenizer_path = args.tokenizer_path
-        elif argscheckpoint_path:
+        elif args.checkpoint_path:
             tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
-        elif checkpoint_dir:
+        elif args.checkpoint_dir:
             tokenizer_path = args.checkpoint_dir / "tokenizer.model"
         else:
             raise RuntimeError(f"cannot find tokenizer model")

         if not tokenizer_path.is_file():
             raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")

-        if args.toktoken:
-            is_Sentencepiece = False
+        if args.tiktoken:
+            is_SentencePiece = False
             is_TikToken = True

         return cls(
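The resolution order in `TokenizerArgs.from_args` is: explicit `tokenizer_path`, then a `tokenizer.model` next to the checkpoint file, then one inside the checkpoint directory. A standalone sketch of that fallback chain (the helper name is ours, not the module's):

```python
# Illustrative helper mirroring the fallback logic in from_args.
from pathlib import Path
from typing import Optional

def resolve_tokenizer_path(
    tokenizer_path: Optional[Path],
    checkpoint_path: Optional[Path],
    checkpoint_dir: Optional[Path],
) -> Path:
    if tokenizer_path:
        path = tokenizer_path
    elif checkpoint_path:
        # tokenizer.model conventionally sits next to the checkpoint file
        path = checkpoint_path.parent / "tokenizer.model"
    elif checkpoint_dir:
        path = checkpoint_dir / "tokenizer.model"
    else:
        raise RuntimeError("cannot find tokenizer model")
    if not path.is_file():
        raise RuntimeError(f"did not find tokenizer at {path}")
    return path
```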
@@ -87,13 +120,13 @@ def from_args(cls, args):  # -> TokenizerArgs:
             is_TikToken=is_TikToken
         )

-def _initialize_tokenizer(config: TokenizerArgs):
-    if is_SentencePiece:
-        return SentencePieceProcessor(model_file=str(tokenizer_path))
-    elif is_TikToken:
-        raise RUntimeError("TikToken not implemented yet!")
+def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
+    if tokenizer_args.is_SentencePiece:
+        return SentencePieceProcessor(model_file=str(tokenizer_args.tokenizer_path))
+    elif tokenizer_args.is_TikToken:
+        raise RuntimeError("TikToken not implemented yet!")
     else:
-        raise RUntimeError("must specify a valid tokenizer in TokenizerArgs")
+        raise RuntimeError("must specify a valid tokenizer in TokenizerArgs")


 def device_sync(device):
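`_initialize_tokenizer` now reads both flags and the path from the `TokenizerArgs` it receives instead of relying on enclosing scope. A hedged usage sketch, assuming `sentencepiece` is installed and `tokenizer.model` was found:

```python
# Hypothetical wiring; from_args has already resolved tokenizer_path and
# the is_SentencePiece / is_TikToken flags.
tokenizer_args = TokenizerArgs.from_args(args)
tokenizer = _initialize_tokenizer(tokenizer_args)  # SentencePieceProcessor today

# round-trip through the SentencePiece API
ids = tokenizer.encode("hello world")
text = tokenizer.decode(ids)
```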
@@ -115,38 +148,31 @@ def device_sync(device):
 sys.path.append(str(wd))

 def _load_model(
-    checkpoint_path,
-    checkpoint_dir,
-    params_path,
-    params_table,
-    gguf_path,
-    device,
-    precision,
-    use_tp  # =False
+    builder_args
 ):
-    use_cuda = "cuda" in device
+    use_cuda = "cuda" in builder_args.device
     with torch.device("meta"):
-        if params_path:
-            model = Transformer.from_params(params_path)
-        elif params_table:
-            model = Transformer.from_table(params_path)
-        elif gguf_path:
-            model = Transformer.from_gguf(gguf_path)
+        if builder_args.params_path:
+            model = Transformer.from_params(builder_args.params_path)
+        elif builder_args.params_table:
+            model = Transformer.from_table(builder_args.params_table)
+        elif builder_args.gguf_path:
+            model = Transformer.from_gguf(builder_args.gguf_path)
         else:
-            model = Transformer.from_name(checkpoint_path.parent.name)
+            model = Transformer.from_name(builder_args.checkpoint_path.parent.name)

-    # checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    # checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
     cps = []
-    if checkpoint_dir is not None:
+    if builder_args.checkpoint_dir is not None:
         # Load multiple checkpoints; ignore the single path.
-        checkpoint_path = None
+        builder_args.checkpoint_path = None
         for i in range(4):
             cp_name = f"consolidated.{i}.pth"
             print(f"Loading {cp_name}")
             cps.append(
                 torch.load(
-                    os.path.join(checkpoint_dir, cp_name),
-                    map_location=device,
+                    os.path.join(builder_args.checkpoint_dir, cp_name),
+                    map_location=builder_args.device,
                     mmap=True,
                 )
             )
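`_load_model` now reads everything from the single `builder_args` object. When `checkpoint_dir` is set it loads four llama-style shards; a sketch of that loop in isolation, with the shard count lifted into an assumed parameter:

```python
# Sketch of the sharded-checkpoint load above; num_shards=4 matches the
# hard-coded range(4) in the diff.
import os
import torch

def load_shards(checkpoint_dir: str, device: str, num_shards: int = 4):
    cps = []
    for i in range(num_shards):
        cp_name = f"consolidated.{i}.pth"
        # mmap=True (PyTorch >= 2.1) keeps shard data on disk and pages
        # tensors in lazily instead of reading each file up front
        cps.append(
            torch.load(
                os.path.join(checkpoint_dir, cp_name),
                map_location=device,
                mmap=True,
            )
        )
    return cps
```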
@@ -162,69 +188,36 @@ def _load_model(
             else:
                 checkpoint[key] = cps[0][key]
     else:
-        checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True, weights_only=True)
+        checkpoint = torch.load(builder_args.checkpoint_path, map_location=builder_args.device, mmap=True, weights_only=True)

-    if "model" in checkpoint and "stories" in str(checkpoint_path):
+    if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
         checkpoint = checkpoint["model"]

     model.load_state_dict(checkpoint, assign=True)

-    if use_tp:
+    if builder_args.use_tp:
         from tp import apply_tp

         print("Applying tensor parallel to model ...")
         apply_tp(model)

-    model = model.to(device=device, dtype=precision)
+    model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()


 def _initialize_model(
-    checkpoint_path,
-    checkpoint_dir,
-    params_path,
-    params_table,
-    gguf_path,
-    dso_path,
-    pte_path,
+    builder_args,
     quantize,
-    device,
-    precision,
-    setup_caches,
-    use_tp  # =False
 ):
-    assert (
-        (checkpoint_path and checkpoint_path.is_file()) or
-        (checkpoint_dir and checkpoint_path.is_dir()) or
-        (gguf_path and gguf_path.is_file()) or
-        (dso_path and Path(dso_path).is_file()) or
-        (pte_path and Path(pte_path).is_file())
-    ), "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
-    assert not (dso_path and pte_path), "specify either DSO path or PTE path, but not both"
-
-    if (checkpoint_path and (dso_path or pte_path)):
-        print("Warning: checkpoint path ignored because an exported DSO or PTE path specified")
-    if (checkpoint_dir and (dso_path or pte_path)):
-        print("Warning: checkpoint dir ignored because an exported DSO or PTE path specified")
-    if (gguf_path and (dso_path or pte_path)):
-        print("Warning: GGUF path ignored because an exported DSO or PTE path specified")
-
     print("Loading model ...")
     t0 = time.time()
     model_ = _load_model(
-        checkpoint_path,
-        checkpoint_dir,
-        params_path,
-        params_table,
-        gguf_path,
-        device,
-        precision,
-        use_tp
+        builder_args
     )
-    device_sync(device=device)  # MKG
+    device_sync(device=builder_args.device)
     print(f"Time to load model: {time.time() - t0:.02f} seconds")

-    if dso_path:
+    if builder_args.dso_path:
         # make sure user did not try to set dtype
         # assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
         assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export."
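After the refactor, `_initialize_model` takes just the `BuilderArgs` bundle and the quantization spec; the path-validity asserts it used to carry now live in `__post_init__`. A hypothetical call site (the `args.quantize` field is an assumption from the surrounding CLI):

```python
# Hypothetical wiring; args is the parsed CLI namespace.
builder_args = BuilderArgs.from_args(args)      # __post_init__ validates paths here
tokenizer_args = TokenizerArgs.from_args(args)

model = _initialize_model(builder_args, args.quantize)
tokenizer = _initialize_tokenizer(tokenizer_args)
```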
@@ -236,33 +229,36 @@ def _initialize_model(
             # attributes will NOT be seen by the AOTI-compiled forward
             # function, e.g. calling model.setup_cache will NOT touch
             # AOTI compiled and maintained model buffers such as kv_cache.
-            model.forward = torch._export.aot_load(str(dso_path.absolute()), device)
+            model.forward = torch._export.aot_load(str(builder_args.dso_path.absolute()), builder_args.device)
         except:
-            raise RuntimeError(f"Failed to load AOTI compiled {dso_path}")
-    elif pte_path:
+            raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
+    elif builder_args.pte_path:
         # make sure user did not try to set dtype
         # assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
         assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export."
         try:
             from model_et import PTEModel
-            model = PTEModel(model_.config, pte_path)
+            model = PTEModel(model_.config, builder_args.pte_path)
         except Exception as e:
-            raise RuntimeError(f"Failed to load ET compiled {pte_path}")
+            raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
     else:
         model = model_

     if quantize:
         t0q = time.time()
         quantize_model(model, quantize)
-        device_sync(device=device)  # MKG
+        device_sync(device=builder_args.device)
         print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")

-    if setup_caches:
+    if builder_args.setup_caches:
         max_seq_length = 350
-        with torch.device(device):
-            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+        with torch.device(builder_args.device):
+            model.setup_caches(
+                max_batch_size=1,
+                max_seq_length=max_seq_length
+            )

-    model.to(dtype=precision)
+    model.to(dtype=builder_args.precision)

     return model
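The DSO branch replaces the eager `forward` with an AOTInductor-compiled callable via `torch._export.aot_load`, which is why the comment warns that `model.setup_cache` no longer touches the buffers the compiled graph owns. A minimal sketch of that swap (path and device are placeholders):

```python
# Minimal sketch; model is an eager nn.Module whose graph was exported
# ahead of time to model.so with AOTInductor.
import torch

def attach_aoti_forward(model, dso_path: str, device: str):
    # aot_load returns a callable backed by the compiled artifact; once
    # bound, mutations to the eager module's buffers (e.g. kv_cache) are
    # invisible to the compiled forward
    model.forward = torch._export.aot_load(dso_path, device)
    return model
```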