Skip to content

Commit 6d8eb3f

Browse files
committed
Fix convert_checkpoint.py for hf and gemma
1 parent 4535bdf commit 6d8eb3f

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

convert_checkpoints.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def _load_orig_llama_weight(input_ckpt_dir: epath.Path):
278278

279279
def _load_hf_llama_weight(input_ckpt_dir: epath.Path):
280280
print(f"Loading checkpoint files from {input_ckpt_dir}.")
281-
safetensors_files = input_ckpt_dir.glob("*.safetensors")
281+
safetensors_files = list(input_ckpt_dir.glob("*.safetensors"))
282282
if len(list(safetensors_files)) == 0:
283283
raise ValueError(
284284
f"No *.safetensors found in the input dir {input_ckpt_dir}"
@@ -418,6 +418,12 @@ def _get_llama_state_dict(input_ckpt_dir):
418418
print(f"Merging weights takes {end - start} seconds")
419419
return state_dict, params
420420

421+
def fix_json(text):
422+
text = text.replace("'", '"')
423+
lines = text.split('\n')
424+
lines[-3] = lines[-3].replace(",", "")
425+
return '\n'.join(lines)
426+
421427

422428
def _get_gemma_state_dict(input_ckpt_dir):
423429
ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
@@ -426,7 +432,9 @@ def _get_gemma_state_dict(input_ckpt_dir):
426432
state_dict = torch.load(str(ckpt_file), map_location=torch.device("cpu"))[
427433
"model_state_dict"
428434
]
429-
model_config = json.loads((input_ckpt_dir / "config.json").read_text())
435+
config_text = fix_json((input_ckpt_dir / "config.json").read_text())
436+
print('gemma config is', config_text)
437+
model_config = json.loads(config_text)
430438
for key in list(state_dict.keys()):
431439
if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
432440
assert (

0 commit comments

Comments
 (0)