@@ -278,7 +278,7 @@ def _load_orig_llama_weight(input_ckpt_dir: epath.Path):
 def _load_hf_llama_weight(input_ckpt_dir: epath.Path):
   print(f"Loading checkpoint files from {input_ckpt_dir}.")
-  safetensors_files = input_ckpt_dir.glob("*.safetensors")
+  safetensors_files = list(input_ckpt_dir.glob("*.safetensors"))
   if len(list(safetensors_files)) == 0:
     raise ValueError(
         f"No *.safetensors found in the input dir {input_ckpt_dir}"
@@ -419,14 +419,22 @@ def _get_llama_state_dict(input_ckpt_dir):
   return state_dict, params


+def fix_json(text):
+  text = text.replace("'", '"')
+  lines = text.split("\n")
+  lines[-3] = lines[-3].replace(",", "")
+  return "\n".join(lines)
+
+
 def _get_gemma_state_dict(input_ckpt_dir):
   ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
   assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
   ckpt_file = ckpt_file[0]
   state_dict = torch.load(str(ckpt_file), map_location=torch.device("cpu"))[
       "model_state_dict"
   ]
-  model_config = json.loads((input_ckpt_dir / "config.json").read_text())
+  config_text = fix_json((input_ckpt_dir / "config.json").read_text())
+  model_config = json.loads(config_text)
   for key in list(state_dict.keys()):
     if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
       assert (
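
fix_json is the more surprising addition: the Gemma checkpoint's config.json is apparently not valid JSON, and the helper rewrites it just enough for json.loads to accept it, swapping single quotes for double quotes and stripping the comma on the third-to-last line (the trailing comma before the closing brace, given a file that ends with a newline). A minimal sketch of the same transformation on a hypothetical config text (the key names are illustrative, not the real Gemma config):

import json

# Mimics a file ending in ",\n}\n"; json.loads rejects both the single
# quotes and the trailing comma.
raw = "{\n  'dim': 2048,\n  'num_layers': 18,\n}\n"

fixed = raw.replace("'", '"')           # normalize the quoting
lines = fixed.split("\n")
lines[-3] = lines[-3].replace(",", "")  # drop the trailing comma
fixed = "\n".join(lines)

print(json.loads(fixed))                # {'dim': 2048, 'num_layers': 18}

Note the two assumptions the helper leans on: the trailing comma always sits on the third-to-last line, and no string value in the config contains a single quote, since every ' is replaced globally.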