
Commit d63169e

Fix convert_checkpoint.py for hf and gemma
1 parent 4535bdf commit d63169e

File tree

1 file changed (+11, -2 lines)


convert_checkpoints.py

Lines changed: 11 additions & 2 deletions
@@ -278,7 +278,7 @@ def _load_orig_llama_weight(input_ckpt_dir: epath.Path):
 
 def _load_hf_llama_weight(input_ckpt_dir: epath.Path):
   print(f"Loading checkpoint files from {input_ckpt_dir}.")
-  safetensors_files = input_ckpt_dir.glob("*.safetensors")
+  safetensors_files = list(input_ckpt_dir.glob("*.safetensors"))
   if len(list(safetensors_files)) == 0:
     raise ValueError(
         f"No *.safetensors found in the input dir {input_ckpt_dir}"
@@ -419,14 +419,23 @@ def _get_llama_state_dict(input_ckpt_dir):
   return state_dict, params
 
 
+def fix_json(text):
+  text = text.replace("'", '"')
+  lines = text.split("\n")
+  lines[-3] = lines[-3].replace(",", "")
+  return "\n".join(lines)
+
+
 def _get_gemma_state_dict(input_ckpt_dir):
   ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
   assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
   ckpt_file = ckpt_file[0]
   state_dict = torch.load(str(ckpt_file), map_location=torch.device("cpu"))[
       "model_state_dict"
   ]
-  model_config = json.loads((input_ckpt_dir / "config.json").read_text())
+  config_text = fix_json((input_ckpt_dir / "config.json").read_text())
+  print("gemma config is", config_text)
+  model_config = json.loads(config_text)
   for key in list(state_dict.keys()):
     if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
       assert (
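
A note on the second hunk: the Gemma config.json apparently is not strict JSON, so json.loads failed on it before this change. Judging only from what fix_json does, the file presumably uses Python-style single quotes and leaves a trailing comma on its third-to-last line; the helper rewrites both before parsing. A small self-contained sketch, with a made-up config snippet standing in for the real file:

import json

def fix_json(text):
  # Same transform as the patch: single quotes -> double quotes,
  # then drop the trailing comma on the third-to-last line.
  text = text.replace("'", '"')
  lines = text.split("\n")
  lines[-3] = lines[-3].replace(",", "")
  return "\n".join(lines)

# Hypothetical near-JSON text in the shape the patch seems to expect.
raw = "{\n  'vocab_size': 256000,\n  'num_hidden_layers': 18,\n}\n"
print(json.loads(fix_json(raw)))  # {'vocab_size': 256000, 'num_hidden_layers': 18}

The transform is deliberately narrow: it assumes the stray comma always sits on the third-to-last line and that no string value contains an apostrophe, so it only suits this particular checkpoint's config file.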
