Using ZeRO stage 3 to finetune SD2, a dimension error occurs #1865


Closed
Nipi64310 opened this issue Dec 30, 2022 · 5 comments · Fixed by #3076
Labels: bug (Something isn't working), stale (Issues that haven't received updates)

Comments


Nipi64310 commented Dec 30, 2022

Describe the bug

An error is reported when using DeepSpeed's ZeRO stage 3 to finetune the diffusers/examples/text_to_image/train_text_to_image.py script. My machine has 4×2080Ti GPUs, and because a single GPU cannot hold all of the SD2 parameters, the DeepSpeed ZeRO stage 3 strategy must be used.

Reproduction

accelerate.yaml

compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: /home/kas/zero_stage3_offload_config.json
  zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
use_cpu: false

/home/kas/zero_stage3_offload_config.json

{
  "train_micro_batch_size_per_gpu": 16,
"gradient_accumulation_steps":2,
  "train_batch_size":128,
  "steps_per_print": 2,
  "gradient_clipping": 1,
  "zero_optimization": {
    "stage": 3,
    "allgather_partitions": false,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 2e8,
    "contiguous_gradients": true,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "stage3_max_live_parameters" : 2e8,
    "stage3_max_reuse_distance" : 2e8,
    "stage3_prefetch_bucket_size": 2e8,
    "stage3_param_persistence_threshold": 2e8,
    "sub_group_size" : 2e8,
    "round_robin_gradients": true
  },
  "bf16": {
    "enabled": true
  }
}
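
For reference, DeepSpeed requires the global batch size to be consistent with the per-GPU micro-batch, the accumulation steps, and the world size; the numbers above satisfy that relation for the 4-GPU setup. A minimal sanity check (plain Python, values copied from the JSON and accelerate.yaml above):

# DeepSpeed enforces:
#   train_batch_size == train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size
micro_batch_per_gpu = 16  # "train_micro_batch_size_per_gpu"
grad_accum_steps = 2      # "gradient_accumulation_steps"
world_size = 4            # num_processes in accelerate.yaml
assert micro_batch_per_gpu * grad_accum_steps * world_size == 128  # "train_batch_size"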

launch script

pip install deepspeed
export MODEL_NAME="stabilityai/stable-diffusion-2"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch --config_file ./accelerate.yaml --mixed_precision="fp16"  train_text_to_image.py  \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=224 --center_crop --random_flip \
  --train_batch_size=16 \
  --gradient_accumulation_steps=2 \
  --gradient_checkpointing \
  --max_train_steps=500 \
  --learning_rate=6e-5 \
  --max_grad_norm=1 \
  --lr_scheduler="constant_with_warmup" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model"

Logs

0%| | 0/500 [00:00<?, ?it/s] Steps: 0%| | 0/500 [00:00<?, ?it/s]Traceback (most recent call last):
File "train_text_to_image.py ", line 718, in <module>
main()
File "train_text_to_image.py ", line 648, in main
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
File "/opt/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/miniconda3/lib/python3.7/site-packages/transformers/models/clip/modeling_clip.py", line 739, in forward
return_dict=return_dict,
File "/opt/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/miniconda3/lib/python3.7/site-packages/transformers/models/clip/modeling_clip.py", line 636, in forward
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
File "/opt/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/miniconda3/lib/python3.7/site-packages/transformers/models/clip/modeling_clip.py", line 165, in forward
inputs_embeds = self.token_embedding(input_ids)
File "/opt/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/miniconda3/lib/python3.7/site-packages/torch/nn/modules/sparse.py", line 160, in forward
self.norm_type, self.scale_grad_by_freq, self.sparse)
File "/opt/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 2183, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D
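
Context for the traceback: under ZeRO stage 3, the parameters of any module not managed by the DeepSpeed engine remain partitioned across ranks, so the CLIP token_embedding.weight that each rank sees is a flattened 1-D shard, and torch.embedding rejects it with "'weight' must be 2-D". A hedged sketch of one possible workaround (not necessarily the fix that later landed in #3076) is to gather the frozen text encoder's parameters for the duration of its forward pass with deepspeed.zero.GatheredParameters; text_encoder and batch below are names taken from the training script:

# Sketch only: gather the partitioned parameters of the frozen text encoder
# around its forward pass so the embedding weight is a full 2-D tensor again.
# modifier_rank=None means no rank's modifications are persisted after the block.
import deepspeed

with deepspeed.zero.GatheredParameters(list(text_encoder.parameters()), modifier_rank=None):
    encoder_hidden_states = text_encoder(batch["input_ids"])[0]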

System Info

  • diffusers version: 0.11.1
  • Platform: Linux-4.15.0-29-generic-x86_64-with-debian-buster-sid
  • Python version: 3.7.7
  • PyTorch version (GPU?): 1.11.0+cu113 (True)
  • Huggingface_hub version: 0.10.1
  • Transformers version: 4.23.1
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: yes
Nipi64310 added the bug (Something isn't working) label on Dec 30, 2022
@patrickvonplaten
Contributor

@patil-suraj
Contributor

I don't really know much about deepspeed; maybe @williamberman knows more here. As far as I have tried, the script works well with stage 2 using CPU offloading, which should help fit the model on the 2080Ti.
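
If going the stage-2 route, a possible starting config is sketched below. It is derived from the stage-3 JSON in the report (not a tested drop-in): "stage" is set to 2, the optimizer CPU offload is kept, and "offload_param" is dropped because parameter offload is a ZeRO-3 feature; fp16 is enabled instead of bf16, since the 2080Ti (Turing) has no bf16 support. Written as a small Python helper that writes the JSON file:

# Hedged sketch of a ZeRO stage-2 config with optimizer CPU offload,
# following the suggestion above; values mirror the original stage-3 config.
import json

zero_stage2_offload_config = {
    "train_micro_batch_size_per_gpu": 16,
    "gradient_accumulation_steps": 2,
    "train_batch_size": 128,
    "gradient_clipping": 1.0,
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": 2e8,
        "allgather_bucket_size": 2e8,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    },
    "fp16": {"enabled": True},
}

with open("zero_stage2_offload_config.json", "w") as f:
    json.dump(zero_stage2_offload_config, f, indent=2)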

@williamberman
Contributor

I don't know off the top of my head. Someone opened an issue with the same error message using deepspeed stage 3 with transformers: deepspeedai/DeepSpeed#2746. Maybe best to see if the deepspeed team knows before digging in :)

@stas00
Contributor

stas00 commented Apr 12, 2023

Fixed in #3076 (but zero-3 support is only partial)

Please carefully read the OP of the PR for details.

@github-actions
Contributor

github-actions bot commented May 7, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
