RuntimeError: 'weight' must be 2-D while training Flan-T5 models with stage 3 #2746

@smitanannaware

Description

I am using the Hugging Face `Seq2SeqTrainer` to fine-tune a Flan-T5-XL model with DeepSpeed ZeRO stage 3.

trainer = Seq2SeqTrainer(
    # model_init=self.model_init,
    model=self.model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=self.tokenizer,
    data_collator=self.data_collator,
    compute_metrics=self.compute_metrics,
)

trainer.train()
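For reference, a minimal sketch of how a DeepSpeed config is usually wired into the Trainer through the training arguments; the filename and the `deepspeed=` hand-off are assumptions on my part, since that part of the script is not shown above:

```python
import json

# Hypothetical filename; the real script's path may differ.
DS_CONFIG_PATH = "ds_config_zero3.json"

# Minimal ZeRO stage 3 fragment, just to have something to write out.
ds_config = {
    "zero_optimization": {"stage": 3},
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}

with open(DS_CONFIG_PATH, "w") as f:
    json.dump(ds_config, f, indent=4)

# With transformers, the path is then handed to the training arguments,
# which lets the Trainer build the DeepSpeed engine itself instead of the
# model being wrapped some other way:
#
#   training_args = Seq2SeqTrainingArguments(..., deepspeed=DS_CONFIG_PATH)
#
# and the script is started with the deepspeed launcher, e.g.:
#
#   deepspeed train.py --deepspeed ds_config_zero3.json

# Sanity check: the file round-trips and requests stage 3.
with open(DS_CONFIG_PATH) as f:
    loaded = json.load(f)
print(loaded["zero_optimization"]["stage"])
```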

I am stuck on the error below:

  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/transformers/trainer.py", line 1527, in train
    return inner_training_loop(
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/transformers/trainer.py", line 1773, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/transformers/trainer.py", line 2523, in training_step
    loss = self.compute_loss(model, inputs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/transformers/trainer.py", line 2555, in compute_loss
    outputs = model(**inputs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1158, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1111, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 1611, in forward
    encoder_outputs = self.encoder(
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 941, in forward
    inputs_embeds = self.embed_tokens(input_ids)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D

The code works with a ZeRO stage 2 config but fails with stage 3. I have tried a couple of settings, but no luck. My stage 3 config:

{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_fp16_weights_on_model_save": true
    },

    "gradient_accumulation_steps": 8,
    "gradient_clipping": "auto",
    "steps_per_print": 10,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
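As a quick way to rule out typos in a config like the one above, it can be re-validated programmatically after loading; a sketch (plain JSON checks, nothing DeepSpeed-specific):

```python
import json

# A subset of the stage 3 config posted above, loaded from a string
# to mimic reading the file from disk.
config_text = """
{
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": true},
        "offload_param": {"device": "cpu", "pin_memory": true}
    },
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto"
}
"""

cfg = json.loads(config_text)
zero = cfg["zero_optimization"]

# Both parameters and optimizer state are offloaded to CPU under stage 3.
assert zero["stage"] == 3
assert zero["offload_param"]["device"] == "cpu"
assert zero["offload_optimizer"]["device"] == "cpu"

# "auto" placeholders are resolved by the Trainer at runtime, so the raw
# file still contains the literal string here.
assert cfg["train_batch_size"] == "auto"
print("config OK")
```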

Any help would be appreciated.

Labels: bug (Something isn't working), training
