@@ -278,7 +278,7 @@ def _load_orig_llama_weight(input_ckpt_dir: epath.Path):
 
 def _load_hf_llama_weight(input_ckpt_dir: epath.Path):
   print(f"Loading checkpoint files from {input_ckpt_dir}.")
-  safetensors_files = input_ckpt_dir.glob("*.safetensors")
+  safetensors_files = list(input_ckpt_dir.glob("*.safetensors"))
   if len(list(safetensors_files)) == 0:
     raise ValueError(
         f"No *.safetensors found in the input dir {input_ckpt_dir}"
@@ -418,6 +418,12 @@ def _get_llama_state_dict(input_ckpt_dir):
   print(f"Merging weights takes {end - start} seconds")
   return state_dict, params
 
+def fix_json(text):
+  text = text.replace("'", '"')
+  lines = text.split('\n')
+  lines[-3] = lines[-3].replace(",", "")
+  return '\n'.join(lines)
+
 
 def _get_gemma_state_dict(input_ckpt_dir):
   ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
@@ -426,7 +432,9 @@ def _get_gemma_state_dict(input_ckpt_dir):
   state_dict = torch.load(str(ckpt_file), map_location=torch.device("cpu"))[
       "model_state_dict"
   ]
-  model_config = json.loads((input_ckpt_dir / "config.json").read_text())
+  config_text = fix_json((input_ckpt_dir / "config.json").read_text())
+  print('gemma config is', config_text)
+  model_config = json.loads(config_text)
   for key in list(state_dict.keys()):
     if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
       assert (
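
For context on the two fixes: wrapping the glob in list() presumably keeps the safetensors file set reusable, since an iterator would be exhausted by the len(list(...)) check before the files are actually loaded, and fix_json turns the Gemma config.json into strict JSON by swapping single quotes for double quotes and dropping the trailing comma on the third-to-last line. A minimal sketch of that behavior, using made-up config text rather than the contents of a real Gemma checkpoint:

import json


def fix_json(text):  # same logic as the helper added in the patch above
  text = text.replace("'", '"')
  lines = text.split('\n')
  lines[-3] = lines[-3].replace(",", "")
  return '\n'.join(lines)


# Illustrative config text only; a real Gemma config.json may differ.
# The trailing newline matters: it puts the last key/value pair, and its
# trailing comma, on lines[-3].
sample = "{\n    'dim': 2048,\n    'num_layers': 18,\n}\n"

print(json.loads(fix_json(sample)))
# -> {'dim': 2048, 'num_layers': 18}

The index -3 only lands on the right line when the file ends with a newline after the closing brace; without that trailing newline, the comma removal would hit the wrong line.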