Skip to content

error when launching reward model training with the weights produced in SFT stage #11

@javismiles

Description

@javismiles

I followed your instructions for the SFT training and obtained the trained .pt weights;
then I ran:

python train_rm.py -b 2 -n experiment_name -p "./runs/sft_javSFT2_202405021344/sft_javSFT2_202405021344_step4000.pt"

and I got this error at line 87 of train_rm.py:

File "/xxxxxxxxxxxxxxxxxxxx/python3.9/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GPT:
Missing key(s) in state_dict: "transformer.decoder_blocks.0.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.0.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.0.mmsa.output_projection.lora_A", "transformer.decoder_blocks.0.mmsa.output_projection.lora_B", "transformer.decoder_blocks.0.ffn.fc1.lora_A", "transformer.decoder_blocks.0.ffn.fc1.lora_B", "transformer.decoder_blocks.0.ffn.fc2.lora_A", "transformer.decoder_blocks.0.ffn.fc2.lora_B", "transformer.decoder_blocks.1.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.1.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.1.mmsa.output_projection.lora_A", "transformer.decoder_blocks.1.mmsa.output_projection.lora_B", "transformer.decoder_blocks.1.ffn.fc1.lora_A", "transformer.decoder_blocks.1.ffn.fc1.lora_B", "transformer.decoder_blocks.1.ffn.fc2.lora_A", "transformer.decoder_blocks.1.ffn.fc2.lora_B", "transformer.decoder_blocks.2.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.2.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.2.mmsa.output_projection.lora_A", "transformer.decoder_blocks.2.mmsa.output_projection.lora_B", "transformer.decoder_blocks.2.ffn.fc1.lora_A", "transformer.decoder_blocks.2.ffn.fc1.lora_B", "transformer.decoder_blocks.2.ffn.fc2.lora_A", "transformer.decoder_blocks.2.ffn.fc2.lora_B", "transformer.decoder_blocks.3.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.3.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.3.mmsa.output_projection.lora_A", "transformer.decoder_blocks.3.mmsa.output_projection.lora_B", "transformer.decoder_blocks.3.ffn.fc1.lora_A", "transformer.decoder_blocks.3.ffn.fc1.lora_B", "transformer.decoder_blocks.3.ffn.fc2.lora_A", "transformer.decoder_blocks.3.ffn.fc2.lora_B", "transformer.decoder_blocks.4.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.4.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.4.mmsa.output_projection.lora_A", "transformer.decoder_blocks.4.mmsa.output_projection.lora_B", 
"transformer.decoder_blocks.4.ffn.fc1.lora_A", "transformer.decoder_blocks.4.ffn.fc1.lora_B", "transformer.decoder_blocks.4.ffn.fc2.lora_A", "transformer.decoder_blocks.4.ffn.fc2.lora_B", "transformer.decoder_blocks.5.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.5.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.5.mmsa.output_projection.lora_A", "transformer.decoder_blocks.5.mmsa.output_projection.lora_B", "transformer.decoder_blocks.5.ffn.fc1.lora_A", "transformer.decoder_blocks.5.ffn.fc1.lora_B", "transformer.decoder_blocks.5.ffn.fc2.lora_A", "transformer.decoder_blocks.5.ffn.fc2.lora_B", "transformer.decoder_blocks.6.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.6.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.6.mmsa.output_projection.lora_A", "transformer.decoder_blocks.6.mmsa.output_projection.lora_B", "transformer.decoder_blocks.6.ffn.fc1.lora_A", "transformer.decoder_blocks.6.ffn.fc1.lora_B", "transformer.decoder_blocks.6.ffn.fc2.lora_A", "transformer.decoder_blocks.6.ffn.fc2.lora_B", "transformer.decoder_blocks.7.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.7.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.7.mmsa.output_projection.lora_A", "transformer.decoder_blocks.7.mmsa.output_projection.lora_B", "transformer.decoder_blocks.7.ffn.fc1.lora_A", "transformer.decoder_blocks.7.ffn.fc1.lora_B", "transformer.decoder_blocks.7.ffn.fc2.lora_A", "transformer.decoder_blocks.7.ffn.fc2.lora_B", "transformer.decoder_blocks.8.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.8.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.8.mmsa.output_projection.lora_A", "transformer.decoder_blocks.8.mmsa.output_projection.lora_B", "transformer.decoder_blocks.8.ffn.fc1.lora_A", "transformer.decoder_blocks.8.ffn.fc1.lora_B", "transformer.decoder_blocks.8.ffn.fc2.lora_A", "transformer.decoder_blocks.8.ffn.fc2.lora_B", "transformer.decoder_blocks.9.mmsa.qkv_projection.lora_A", 
"transformer.decoder_blocks.9.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.9.mmsa.output_projection.lora_A", "transformer.decoder_blocks.9.mmsa.output_projection.lora_B", "transformer.decoder_blocks.9.ffn.fc1.lora_A", "transformer.decoder_blocks.9.ffn.fc1.lora_B", "transformer.decoder_blocks.9.ffn.fc2.lora_A", "transformer.decoder_blocks.9.ffn.fc2.lora_B", "transformer.decoder_blocks.10.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.10.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.10.mmsa.output_projection.lora_A", "transformer.decoder_blocks.10.mmsa.output_projection.lora_B", "transformer.decoder_blocks.10.ffn.fc1.lora_A", "transformer.decoder_blocks.10.ffn.fc1.lora_B", "transformer.decoder_blocks.10.ffn.fc2.lora_A", "transformer.decoder_blocks.10.ffn.fc2.lora_B", "transformer.decoder_blocks.11.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.11.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.11.mmsa.output_projection.lora_A", "transformer.decoder_blocks.11.mmsa.output_projection.lora_B", "transformer.decoder_blocks.11.ffn.fc1.lora_A", "transformer.decoder_blocks.11.ffn.fc1.lora_B", "transformer.decoder_blocks.11.ffn.fc2.lora_A", "transformer.decoder_blocks.11.ffn.fc2.lora_B", "transformer.decoder_blocks.12.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.12.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.12.mmsa.output_projection.lora_A", "transformer.decoder_blocks.12.mmsa.output_projection.lora_B", "transformer.decoder_blocks.12.ffn.fc1.lora_A", "transformer.decoder_blocks.12.ffn.fc1.lora_B", "transformer.decoder_blocks.12.ffn.fc2.lora_A", "transformer.decoder_blocks.12.ffn.fc2.lora_B", "transformer.decoder_blocks.13.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.13.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.13.mmsa.output_projection.lora_A", "transformer.decoder_blocks.13.mmsa.output_projection.lora_B", "transformer.decoder_blocks.13.ffn.fc1.lora_A", 
"transformer.decoder_blocks.13.ffn.fc1.lora_B", "transformer.decoder_blocks.13.ffn.fc2.lora_A", "transformer.decoder_blocks.13.ffn.fc2.lora_B", "transformer.decoder_blocks.14.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.14.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.14.mmsa.output_projection.lora_A", "transformer.decoder_blocks.14.mmsa.output_projection.lora_B", "transformer.decoder_blocks.14.ffn.fc1.lora_A", "transformer.decoder_blocks.14.ffn.fc1.lora_B", "transformer.decoder_blocks.14.ffn.fc2.lora_A", "transformer.decoder_blocks.14.ffn.fc2.lora_B", "transformer.decoder_blocks.15.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.15.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.15.mmsa.output_projection.lora_A", "transformer.decoder_blocks.15.mmsa.output_projection.lora_B", "transformer.decoder_blocks.15.ffn.fc1.lora_A", "transformer.decoder_blocks.15.ffn.fc1.lora_B", "transformer.decoder_blocks.15.ffn.fc2.lora_A", "transformer.decoder_blocks.15.ffn.fc2.lora_B", "transformer.decoder_blocks.16.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.16.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.16.mmsa.output_projection.lora_A", "transformer.decoder_blocks.16.mmsa.output_projection.lora_B", "transformer.decoder_blocks.16.ffn.fc1.lora_A", "transformer.decoder_blocks.16.ffn.fc1.lora_B", "transformer.decoder_blocks.16.ffn.fc2.lora_A", "transformer.decoder_blocks.16.ffn.fc2.lora_B", "transformer.decoder_blocks.17.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.17.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.17.mmsa.output_projection.lora_A", "transformer.decoder_blocks.17.mmsa.output_projection.lora_B", "transformer.decoder_blocks.17.ffn.fc1.lora_A", "transformer.decoder_blocks.17.ffn.fc1.lora_B", "transformer.decoder_blocks.17.ffn.fc2.lora_A", "transformer.decoder_blocks.17.ffn.fc2.lora_B", "transformer.decoder_blocks.18.mmsa.qkv_projection.lora_A", 
"transformer.decoder_blocks.18.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.18.mmsa.output_projection.lora_A", "transformer.decoder_blocks.18.mmsa.output_projection.lora_B", "transformer.decoder_blocks.18.ffn.fc1.lora_A", "transformer.decoder_blocks.18.ffn.fc1.lora_B", "transformer.decoder_blocks.18.ffn.fc2.lora_A", "transformer.decoder_blocks.18.ffn.fc2.lora_B", "transformer.decoder_blocks.19.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.19.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.19.mmsa.output_projection.lora_A", "transformer.decoder_blocks.19.mmsa.output_projection.lora_B", "transformer.decoder_blocks.19.ffn.fc1.lora_A", "transformer.decoder_blocks.19.ffn.fc1.lora_B", "transformer.decoder_blocks.19.ffn.fc2.lora_A", "transformer.decoder_blocks.19.ffn.fc2.lora_B", "transformer.decoder_blocks.20.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.20.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.20.mmsa.output_projection.lora_A", "transformer.decoder_blocks.20.mmsa.output_projection.lora_B", "transformer.decoder_blocks.20.ffn.fc1.lora_A", "transformer.decoder_blocks.20.ffn.fc1.lora_B", "transformer.decoder_blocks.20.ffn.fc2.lora_A", "transformer.decoder_blocks.20.ffn.fc2.lora_B", "transformer.decoder_blocks.21.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.21.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.21.mmsa.output_projection.lora_A", "transformer.decoder_blocks.21.mmsa.output_projection.lora_B", "transformer.decoder_blocks.21.ffn.fc1.lora_A", "transformer.decoder_blocks.21.ffn.fc1.lora_B", "transformer.decoder_blocks.21.ffn.fc2.lora_A", "transformer.decoder_blocks.21.ffn.fc2.lora_B", "transformer.decoder_blocks.22.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.22.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.22.mmsa.output_projection.lora_A", "transformer.decoder_blocks.22.mmsa.output_projection.lora_B", "transformer.decoder_blocks.22.ffn.fc1.lora_A", 
"transformer.decoder_blocks.22.ffn.fc1.lora_B", "transformer.decoder_blocks.22.ffn.fc2.lora_A", "transformer.decoder_blocks.22.ffn.fc2.lora_B", "transformer.decoder_blocks.23.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.23.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.23.mmsa.output_projection.lora_A", "transformer.decoder_blocks.23.mmsa.output_projection.lora_B", "transformer.decoder_blocks.23.ffn.fc1.lora_A", "transformer.decoder_blocks.23.ffn.fc1.lora_B", "transformer.decoder_blocks.23.ffn.fc2.lora_A", "transformer.decoder_blocks.23.ffn.fc2.lora_B", "lm_head.lora_A", "lm_head.lora_B".

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions