Skip to content

Commit 8a22bab

Browse files
committed
Fix convert_checkpoint.py for hf and gemma
1 parent 4535bdf commit 8a22bab

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

convert_checkpoints.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def _load_orig_llama_weight(input_ckpt_dir: epath.Path):
278278

279279
def _load_hf_llama_weight(input_ckpt_dir: epath.Path):
280280
print(f"Loading checkpoint files from {input_ckpt_dir}.")
281-
safetensors_files = input_ckpt_dir.glob("*.safetensors")
281+
safetensors_files = list(input_ckpt_dir.glob("*.safetensors"))
282282
if len(list(safetensors_files)) == 0:
283283
raise ValueError(
284284
f"No *.safetensors found in the input dir {input_ckpt_dir}"
@@ -419,14 +419,22 @@ def _get_llama_state_dict(input_ckpt_dir):
419419
return state_dict, params
420420

421421

422+
def fix_json(text):
423+
text = text.replace("'", '"')
424+
lines = text.split("\n")
425+
lines[-3] = lines[-3].replace(",", "")
426+
return "\n".join(lines)
427+
428+
422429
def _get_gemma_state_dict(input_ckpt_dir):
423430
ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
424431
assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
425432
ckpt_file = ckpt_file[0]
426433
state_dict = torch.load(str(ckpt_file), map_location=torch.device("cpu"))[
427434
"model_state_dict"
428435
]
429-
model_config = json.loads((input_ckpt_dir / "config.json").read_text())
436+
config_text = fix_json((input_ckpt_dir / "config.json").read_text())
437+
model_config = json.loads(config_text)
430438
for key in list(state_dict.keys()):
431439
if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
432440
assert (

0 commit comments

Comments
 (0)