@@ -37,6 +37,7 @@
from jetstream_pt.config import FLAGS
from jetstream_pt.third_party.gemma import model as gemma_model
from jetstream_pt.third_party.llama import model_exportable as llama_model
+ from safetensors import safe_open
from safetensors.torch import save_file

_INPUT_CHECKPOINT_DIR = epath.DEFINE_path(
@@ -69,6 +70,12 @@
    "When set to true, save to HugginFace SafeTensors format",
)

+ _FROM_HF = flags.DEFINE_bool(
+     "from_hf",
+     False,
+     "Set to True if the input is a HuggingFace checkpoint.",
+ )
+

def _find_scale_name(name, map):
  for key, val in map.items():
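For context, a minimal standalone sketch of how the new boolean flag would be exercised through absl flags. The script name and the path flag names in the comment are assumptions for illustration, not taken from this diff.

# Hypothetical invocation (script and path flag names are assumed):
#   python convert_checkpoints.py \
#       --input_checkpoint_dir=/tmp/llama-2-7b-hf \
#       --output_checkpoint_dir=/tmp/llama-2-7b-jetstream \
#       --from_hf=True
from absl import app, flags

_FROM_HF = flags.DEFINE_bool(
    "from_hf", False, "Set to True if the input is a HuggingFace checkpoint."
)


def main(argv):
  del argv  # Unused.
  if _FROM_HF.value:
    print("Would load a HuggingFace safetensors checkpoint.")
  else:
    print("Would load an original Meta Llama checkpoint.")


if __name__ == "__main__":
  app.run(main)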
@@ -252,7 +259,7 @@ def _load_from_gcs(input_ckpt_dir: epath.Path):
  return checkpoints, params


- def _load_from_local(input_ckpt_dir: epath.Path):
+ def _load_orig_llama_weight(input_ckpt_dir: epath.Path):
  checkpoints = []
  params = json.loads((input_ckpt_dir / "params.json").read_text())
@@ -268,6 +275,84 @@ def _load_from_local(input_ckpt_dir: epath.Path):
  return checkpoints, params


+ def _load_hf_llama_weight(input_ckpt_dir: epath.Path):
+   print(f"Loading checkpoint files from {input_ckpt_dir}.")
+   safetensors_files = input_ckpt_dir.glob("*.safetensors")
+   if len(list(safetensors_files)) == 0:
+     raise ValueError(
+         f"No *.safetensors found in the input dir {input_ckpt_dir}"
+     )
+   checkpoint = {}
+   for st_f in safetensors_files:
+     with safe_open(st_f, framework="pt", device="cpu") as f:
+       for key in f.keys():
+         if "inv_freq" in key:
+           # Don't include 'rotary_emb.inv_freq' in the converted
+           # checkpoint, because in JetStream implementation we
+           # precompute it during weight loading.
+           continue
+         new_key = key
+         # Remove 'model.' prefix for all weights.
+         prefix_to_remove = "model."
+         if key.startswith(prefix_to_remove):
+           new_key = new_key.removeprefix(prefix_to_remove)
+
+         # Weight name substring mapping between hf and jetstream.
+         _load_hf_llama_weight.hf_to_jetstream_keys_mapping = {
+             "lm_head": "output",
+             "embed_tokens": "tok_embeddings",
+             "input_layernorm": "attention_norm",
+             "post_attention_layernorm": "ffn_norm",
+             "self_attn.q_proj": "attention.wq",
+             "self_attn.k_proj": "attention.wk",
+             "self_attn.v_proj": "attention.wv",
+             "self_attn.o_proj": "attention.wo",
+             "mlp.gate_proj": "feed_forward.w1",
+             "mlp.down_proj": "feed_forward.w2",
+             "mlp.up_proj": "feed_forward.w3",
+             "model.norm.weight": "norm.weight",
+         }
+         found_substute = False
+         for (
+             hf_weight_key
+         ) in _load_hf_llama_weight.hf_to_jetstream_keys_mapping.keys():
+           if hf_weight_key in key:
+             jet_stream_key = _load_hf_llama_weight.hf_to_jetstream_keys_mapping[
+                 hf_weight_key
+             ]
+             new_key = new_key.replace(hf_weight_key, jet_stream_key)
+             found_substute = True
+             break
+         assert found_substute, f"No substitute name found for {key}."
+         print(f"convert weight name {key} to {new_key}.")
+         weight_tensor = f.get_tensor(key)
+         if weight_tensor.dtype == torch.float16:
+           # JetStream expects bf16 weight, since activation is in bf16
+           # float16 x bf16 will hit mix precision assertion.
+           weight_tensor = weight_tensor.to(torch.bfloat16)
+           print(f"convert weight name {new_key} from float16 to bfloat16.")
+         if "wq" in new_key or "wk" in new_key:
+           # In HF weight, wq and wk are interleaved differently
+           weight_shape = weight_tensor.shape
+           weight_tensor = (
+               weight_tensor.reshape(-1, 2, 64, weight_shape[1])
+               .transpose(1, 2)
+               .reshape(weight_shape)
+           )
+         checkpoint[new_key] = weight_tensor
+   return [checkpoint], None
+
+
+ def _load_from_local(input_ckpt_dir: epath.Path):
+   if not _FROM_HF.value:
+     return _load_orig_llama_weight(input_ckpt_dir)
+   else:
+     assert (
+         not FLAGS.quantize_weights
+     ), "Quantization not supported for HF checkpoint."
+     return _load_hf_llama_weight(input_ckpt_dir)
+
+
def _export_to_gcs(output_ckpt_dir: epath.Path, params, state_dict):
  # pylint: disable-next=all
  bucket_name, output_ckpt = str(output_ckpt_dir).split("//")[-1].split("/", 1)
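To make the renaming above concrete, here is a minimal standalone sketch of the same substring mapping applied to one illustrative HuggingFace weight name; the sample key and the helper function are illustrative only, not part of this change.

# Standalone version of the renaming logic above, for illustration only.
HF_TO_JETSTREAM = {
    "lm_head": "output",
    "embed_tokens": "tok_embeddings",
    "input_layernorm": "attention_norm",
    "post_attention_layernorm": "ffn_norm",
    "self_attn.q_proj": "attention.wq",
    "self_attn.k_proj": "attention.wk",
    "self_attn.v_proj": "attention.wv",
    "self_attn.o_proj": "attention.wo",
    "mlp.gate_proj": "feed_forward.w1",
    "mlp.down_proj": "feed_forward.w2",
    "mlp.up_proj": "feed_forward.w3",
    "model.norm.weight": "norm.weight",
}


def convert_key(key: str) -> str:
  """Maps a HuggingFace weight name to its JetStream equivalent."""
  new_key = key.removeprefix("model.")  # strip the HF 'model.' prefix if present
  for hf_sub, js_sub in HF_TO_JETSTREAM.items():
    if hf_sub in key:
      return new_key.replace(hf_sub, js_sub)
  raise ValueError(f"No substitute name found for {key}.")


# "model.layers.0.self_attn.q_proj.weight" -> "layers.0.attention.wq.weight"
print(convert_key("model.layers.0.self_attn.q_proj.weight"))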
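The reshape/transpose applied to wq and wk above changes how the rows holding the two rotary-embedding halves are ordered. Below is a small sketch of that permutation on a toy tensor, assuming head_dim is 128 (so the hard-coded 64 above is head_dim // 2); none of the sizes come from this diff.

import torch

# Toy stand-in for a wq/wk weight of shape (num_heads * head_dim, hidden),
# with head_dim assumed to be 128; the sizes are illustrative only.
weight = torch.randn(256, 32)

shape = weight.shape
permuted = (
    weight.reshape(-1, 2, 64, shape[1])  # view each 128-row block as two 64-row halves
    .transpose(1, 2)                     # interleave the two halves row by row
    .reshape(shape)                      # back to the original 2-D layout
)
assert permuted.shape == shape  # same shape, rows re-ordered within each block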
@@ -276,11 +361,12 @@ def _export_to_gcs(output_ckpt_dir: epath.Path, params, state_dict):
  bucket = storage_client.bucket(bucket_name)

  ckpt_blob = bucket.blob(os.path.join(output_ckpt, "consolidated.00.pth"))
-   param_blob = bucket.blob(os.path.join(output_ckpt, "params.json"))
  checklist_blob = bucket.blob(os.path.join(output_ckpt, "checklist.chk"))
-   with param_blob.open("w") as f:
-     f.write(json.dumps(params))
-     f.close()
+   if params is not None:
+     param_blob = bucket.blob(os.path.join(output_ckpt, "params.json"))
+     with param_blob.open("w") as f:
+       f.write(json.dumps(params))
+       f.close()
  with ckpt_blob.open("w") as f:
    torch.save(state_dict, f)
    f.close()
@@ -291,7 +377,8 @@ def _export_to_gcs(output_ckpt_dir: epath.Path, params, state_dict):
def _export_to_local(output_ckpt_dir: epath.Path, params, state_dict):
  output_ckpt_dir.mkdir(parents=True, exist_ok=True)
-   (output_ckpt_dir / "params.json").write_text(json.dumps(params))
+   if params is not None:
+     (output_ckpt_dir / "params.json").write_text(json.dumps(params))
  if _OUTPUT_SAFETENSORS.value:
    # safetensors.torch.save_file expects tensor to be contiguous.
    state_dict = pytree.tree_map_only(