
run train_dreambooth_lora.py failed with accelerate #3284


Closed

webliupeng opened this issue Apr 29, 2023 · 23 comments
Labels
bug Something isn't working stale Issues that haven't received updates

Comments

@webliupeng

webliupeng commented Apr 29, 2023

Describe the bug

Thanks for this awesome project!
When I run the script "train_dreambooth_lora.py" without accelerate, it works fine. But when I launch it with accelerate, it fails as soon as the step count reaches "checkpointing_steps".
I am running the script in a Docker container with 4 * 3090 vGPUs, and accelerate test succeeded.
I am new to this and would appreciate any guidance or suggestions you can offer.

Reproduction

export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="/diffusers/examples/dreambooth/dunhuang512"
export OUTPUT_DIR="path-to-save-model"
cd /diffusers/examples/dreambooth/
accelerate launch train_dreambooth_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --logging_dir='./logs' \
  --instance_prompt="dhstyle_test" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --checkpointing_steps=100 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="dhstyle_test" \
  --validation_epochs=50 \
  --seed="0"\
  --enable_xformers_memory_efficient_attention \
  --use_8bit_adam

Logs

  File "/diffusers/examples/dreambooth/train_dreambooth_lora.py", line 1093, in <module>
    main(args)
  File "/diffusers/examples/dreambooth/train_dreambooth_lora.py", line 972, in main
    LoraLoaderMixin.save_lora_weights(
  File "/diffusers/src/diffusers/loaders.py", line 1111, in save_lora_weights
    for module_name, param in unet_lora_layers.state_dict().items()
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1818, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1820, in state_dict
    hook_result = hook(self, destination, prefix, local_metadata)
  File "/diffusers/src/diffusers/loaders.py", line 74, in map_to
    num = int(key.split(".")[1])  # 0 is always "layers"
ValueError: invalid literal for int() with base 10: 'layers'
Steps:  20%|████████████████████▊                                                                                   | 100/500 [03:35<14:20,  2.15s/it, loss=0.217, lr=0.0001]
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 63642 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 63643 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 63644 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 63641) of binary: /usr/local/bin/python
Traceback (most recent call last):
  File "/usr/local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 45, in main
    args.func(args)
  File "/usr/local/lib/python3.10/site-packages/accelerate/commands/launch.py", line 914, in launch_command
    multi_gpu_launcher(args)
  File "/usr/local/lib/python3.10/site-packages/accelerate/commands/launch.py", line 603, in multi_gpu_launcher
    distrib_run.run(args)
  File "/usr/local/lib/python3.10/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
train_dreambooth_lora.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-04-29_00:59:00
  host      : sd-5b564dfd58-7v76h
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 63641)
  error_file: <N/A>

System Info

  • diffusers version: 0.17.0.dev0

  • Platform: Linux-5.4.0-146-generic-x86_64-with-glibc2.31

  • Python version: 3.10.9

  • PyTorch version (GPU?): 2.0.0+cu117 (True)

  • Huggingface_hub version: 0.14.0

  • Transformers version: 4.25.1

  • Accelerate version: 0.18.0

  • xFormers version: 0.0.19

  • Using GPU in script?:

  • Using distributed or parallel set-up in script?:

  • Accelerate default config:
    - compute_environment: LOCAL_MACHINE
    - distributed_type: MULTI_GPU
    - mixed_precision: no
    - use_cpu: False
    - num_processes: 4
    - machine_rank: 0
    - num_machines: 1
    - gpu_ids: all
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []

@webliupeng webliupeng added the bug Something isn't working label Apr 29, 2023
@SeunghyunSEO

SeunghyunSEO commented Apr 29, 2023

I got the same error.

I also hit a problem when I run train_custom_diffusion.py.

It only happens with accelerate; when I run it with plain python train_custom_diffusion.py, the error does not occur.

│ /workspace/diffusers/src/diffusers/models/attention_processor.py:546 in __call__                 │                                
│                                                                                                  │                                
│    543 │   │   key = attn.head_to_batch_dim(key)                                                 │                                
│    544 │   │   value = attn.head_to_batch_dim(value)                                             │                                
│    545 │   │                                                                                     │                                
│ ❱  546 │   │   attention_probs = attn.get_attention_scores(query, key, attention_mask)           │                                
│    547 │   │   hidden_states = torch.bmm(attention_probs, value)                                 │                                
│    548 │   │   hidden_states = attn.batch_to_head_dim(hidden_states)                             │                                
│    549                                                                                           │                                
│                                                                                                  │                               
│ /workspace/diffusers/src/diffusers/models/attention_processor.py:308 in get_attention_scores     │                               
│                                                                                                  │                               
│    305 │   │   │   baddbmm_input = attention_mask                                                │                               
│    306 │   │   │   beta = 1                                                                      │                               
│    307 │   │                                                                                     │                               
│ ❱  308 │   │   attention_scores = torch.baddbmm(                                                 │                               
│    309 │   │   │   baddbmm_input,                                                                │                               
│    310 │   │   │   query,                                                                        │                               
│    311 │   │   │   key.transpose(-1, -2),                                                        │                               
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯                               
RuntimeError: expected scalar type Half but found Float
Steps:   0%|    

@SeunghyunSEO

SeunghyunSEO commented Apr 29, 2023

I'm not sure why the error occurs with accelerate, but it can be fixed by modifying the following lines here:

class AttnProcsLayers(torch.nn.Module):
    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # we add a hook to state_dict() and load_state_dict() so that the
        # naming fits with `unet.attn_processors`
        def map_to(module, state_dict, *args, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                layer_index = 2 if 'module' in key else 1 ## you should add this line
                num = int(key.split(".")[layer_index])  # 0 is always "layers"
                new_key = key.replace(f"layers.{num}", module.mapping[num])
                new_state_dict[new_key] = value

            return new_state_dict

This is because the state_dict keys look like module.layers.0.to_q_lora.down.weight rather than layers.0.to_q_lora.down.weight, so the second element of the key is "layers" and cannot be converted to int.
I guess the latest version was not tested against this code path with accelerate.
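To see where the extra prefix comes from, here is a minimal, self-contained illustration (my own sketch with a stand-in wrapper class, not real DDP): any parent module that stores the model under an attribute named module produces exactly this key shape.

import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

class Wrapper(torch.nn.Module):
    # Stand-in for the DDP wrapper applied by accelerator.prepare();
    # DDP stores the wrapped model under the attribute name "module".
    def __init__(self, inner):
        super().__init__()
        self.module = inner

print(list(Inner().state_dict().keys()))
# ['layers.0.weight', 'layers.0.bias']
print(list(Wrapper(Inner()).state_dict().keys()))
# ['module.layers.0.weight', 'module.layers.0.bias']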

This hook is triggered from here when we run the code with accelerate:

                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        # We combine the text encoder and UNet LoRA parameters with a simple
                        # custom logic. `accelerator.save_state()` won't know that. So,
                        # use `LoraLoaderMixin.save_lora_weights()`.
                        LoraLoaderMixin.save_lora_weights(
                            save_directory=save_path,
                            unet_lora_layers=unet_lora_layers,
                            text_encoder_lora_layers=text_encoder_lora_layers,
                        )

The map_to hook should instead be:

        def map_to(module, state_dict, *args, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                # num = int(key.split(".")[layer_index])  # 0 is always "layers"
                # new_key = key.replace(f"layers.{num}", module.mapping[num])
                if 'module' in key:
                    num = int(key.split(".")[2]) 
                    replace_key = f"module.layers.{num}"
                else: 
                    num = int(key.split(".")[1]) 
                    replace_key = f"layers.{num}"
                new_key = key.replace(replace_key, module.mapping[num])
                new_state_dict[new_key] = value

            return new_state_dict

With that, you can load pytorch_lora_weights.bin correctly.
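For what it's worth, an alternative workaround could be to avoid the prefixed keys altogether by unwrapping the prepared module before saving. This is only a sketch under the assumption that unet_lora_layers was wrapped by accelerator.prepare, and it reuses the script's existing names (accelerator, args, global_step, LoraLoaderMixin); it has not been verified in this thread.

                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        # Sketch: unwrap the accelerate-prepared AttnProcsLayers so that its
                        # state_dict() keys read "layers.0...." instead of "module.layers.0....".
                        # accelerator.unwrap_model() is part of the accelerate API.
                        LoraLoaderMixin.save_lora_weights(
                            save_directory=save_path,
                            unet_lora_layers=accelerator.unwrap_model(unet_lora_layers),
                            text_encoder_lora_layers=text_encoder_lora_layers,
                        )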

@webliupeng
Author

(quoting @SeunghyunSEO's fix above)

Nice, I have made this change and that error is gone.
But it still can't train successfully, @SeunghyunSEO.

These are the errors:

Traceback (most recent call last):
  File "/data/diffusers/examples/dreambooth/train_dreambooth_lora.py", line 1093, in <module>
    main(args)
  File "/data/diffusers/examples/dreambooth/train_dreambooth_lora.py", line 1048, in main
    pipeline.load_lora_weights(args.output_dir)
  File "/data/diffusers/src/diffusers/loaders.py", line 851, in load_lora_weights
    self.unet.load_attn_procs(unet_lora_state_dict)
  File "/data/diffusers/src/diffusers/loaders.py", line 305, in load_attn_procs
    self.set_attn_processor(attn_processors)
  File "/data/diffusers/src/diffusers/models/unet_2d_condition.py", line 533, in set_attn_processor
    fn_recursive_attn_processor(name, module, processor)
  File "/data/diffusers/src/diffusers/models/unet_2d_condition.py", line 530, in fn_recursive_attn_processor
    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
  File "/data/diffusers/src/diffusers/models/unet_2d_condition.py", line 530, in fn_recursive_attn_processor
    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
  File "/data/diffusers/src/diffusers/models/unet_2d_condition.py", line 530, in fn_recursive_attn_processor
    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
  [Previous line repeated 3 more times]
  File "/data/diffusers/src/diffusers/models/unet_2d_condition.py", line 527, in fn_recursive_attn_processor
    module.set_processor(processor.pop(f"{name}.processor"))
KeyError: 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor'
Steps: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 

@SeunghyunSEO

(quoting @webliupeng's reply above)

As I said, you should also edit replace_key. Check my last code block.

@webliupeng
Author

@SeunghyunSEO I mean another error occurred.

@SeunghyunSEO

@SeunghyunSEO I mean another error occurred.

That's the error I encountered without this code:

        def map_to(module, state_dict, *args, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                # num = int(key.split(".")[layer_index])  # 0 is always "layers"
                # new_key = key.replace(f"layers.{num}", module.mapping[num])
                if 'module' in key:
                    num = int(key.split(".")[2]) 
                    replace_key = f"module.layers.{num}"
                else: 
                    num = int(key.split(".")[1]) 
                    replace_key = f"layers.{num}"
                new_key = key.replace(replace_key, module.mapping[num])
                new_state_dict[new_key] = value

            return new_state_dict

Please check the dictionary keys of the LoRA weights with torch.load; you will probably notice the keys look like module.layers.0.…, which is what I mentioned. With my code, I have no problem saving or loading the weights.
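For example, something along these lines can be used to inspect the keys (the checkpoint path below is only illustrative; point it at your own output_dir):

import torch

# Adjust the path to your own checkpoint directory.
state_dict = torch.load("path-to-save-model/checkpoint-100/pytorch_lora_weights.bin", map_location="cpu")
for key in list(state_dict)[:5]:
    # Keys starting with "module.layers.0." instead of "layers.0." mean the
    # DDP-wrapped state_dict was saved.
    print(key)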

@sayakpaul
Member

We need to be able to reproduce this error.

I conducted some training runs with the following setups. First, I installed diffusers from source: pip install -U git+https://github.com/huggingface/diffusers.

  1. Without training the text encoder LoRA:
accelerate launch train_dreambooth_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --checkpointing_steps=100 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=50 \
  --seed="0" \
  --push_to_hub

Final checkpoints: https://huggingface.co/sayakpaul/dog-test-lora.

  2. With text encoder LoRA along with the UNet LoRA:
accelerate launch train_dreambooth_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --checkpointing_steps=100 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=50 \
  --seed="0" \
  --train_text_encoder \
  --push_to_hub

Final checkpoints: https://huggingface.co/sayakpaul/dreambooth-text-encoder-test.

I was able to conduct the above training runs without any failures.

What am I missing out on?

@webliupeng
Author

(quoting @sayakpaul's comment above)

@sayakpaul
do you use multi-GPU for training?
I followed the examples and tried to run train_dreambooth_lora.py in two different environments, both with 4 GPUs: one is a Docker container, the other a virtual host. However, I encountered the same error in both environments.

@SeunghyunSEO

(quoting @sayakpaul's comment above)

I installed the main branch of diffusers from source.

@sayakpaul
Member

For train_dreambooth_lora.py, a single GPU should be fine. Since we're only training a handful of parameters, a single GPU is enough.

@capntrips

I'm running into the same thing with a single GPU using DeepSpeed. The above fix worked for me in that it prevented the initial error, but I get the same error as @webliupeng in load_lora_weights.

If I put a print statement in map_to after applying the fix, I see the keys are updated on every call to save_lora_weights, but the incorrect keys are still written to pytorch_lora_weights.bin.
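As a diagnostic (my own sketch, not from the thread), one way to check whether the saved keys line up with what the UNet expects is to compare them against unet.attn_processors, whose keys are the processor names that load_attn_procs looks for. The model ID and checkpoint path below are only examples, and the check assumes only UNet LoRA layers were saved (text-encoder LoRA keys would be flagged too).

import torch
from diffusers import UNet2DConditionModel

# Example model ID and checkpoint path; adjust to your setup.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
expected = set(unet.attn_processors.keys())  # e.g. "down_blocks.0.attentions.0....attn1.processor"

state_dict = torch.load("output/checkpoint-100/pytorch_lora_weights.bin", map_location="cpu")
# Every LoRA key should embed one of the expected processor names; list the ones that don't.
bad = [k for k in state_dict if not any(name in k for name in expected)]
print(f"{len(bad)} of {len(state_dict)} keys do not reference a known attention processor")
print(bad[:5])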

@sayakpaul
Member

Could you try updating your local clone of diffusers (git checkout main && git pull) and then retry?

As mentioned in #3284 (comment), I need to be able to reproduce the error minimally.

i installed main branch of diffusers from source

@SeunghyunSEO, could you also update your local clone of diffusers (git checkout main && git pull)?

@capntrips

I did a git pull to efc48da and the issue persists. As for reproducing it, I downloaded the example dog dataset from the DreamBooth training example README here, enabled DeepSpeed per the same doc here, and ran the training script from the same section (using train_dreambooth_lora.py minus the class arguments and with checkpointing_steps set to 20 to speed up the process). Below are my accelerate config and modified training script:

default_config.yaml.gz
dog.sh.gz

The fix still allows the LoRAs to be saved, though still with the incorrect keys.

I'm happy to provide system info, if any of that is useful.

@sayakpaul
Member

sayakpaul commented May 3, 2023

What happens when you disable DeepSpeed?

Also, which method are you using to load the LoRA parameters obtained after training?

@capntrips

What happens when you disable DeepSpeed?

It completes without error with the correct keys.

Also, which method are you using to load the LoRA parameters obtained after training?

The error on load I referenced is from the end of train_dreambooth_lora.py here. Resuming from a checkpoint here is also affected. The failure to load is probably no surprise given the incorrect keys.

@sayakpaul
Member

Okay, so, probably narrowing it down -- it pops up after using DeepSpeed. Ccing @williamberman here.

@patrickvonplaten
Contributor

Note here that DeepSpeed is still very much experimental and we probably won't have time to look into this more specifically. More than happy to review a PR though

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label May 29, 2023
@github-actions github-actions bot closed this as completed Jun 6, 2023
@markrmiller

No DeepSpeed, and the issue persists. Pain in the behind, because every time I step away for a while, I end up losing the hacked-in fix and have to rediscover all of this.

@sayakpaul sayakpaul reopened this Jul 2, 2023
@sayakpaul
Member

Hi @markrmiller. I can understand the frustration. If you happen to know any fixes that have worked for you, happy to review any PR from you so that we can get it fixed.

But currently, we stand here: #3284 (comment).

@markrmiller

Unfortunately, I don't know what's behind it. I'm not using DeepSpeed, though, and never have. I've also tried a large mix of accelerate and diffusers versions, and am currently on the latest source builds for both. I'm using the latest dreambooth-lora script; it fails as soon as it tries to save a checkpoint.

I can add that the previous fix above didn't actually work for me this time, and I had to expand it to this:

def map_to(module, state_dict, *args, **kwargs):
    new_state_dict = {}
    for key, value in state_dict.items():
        print("key:" + key)
        key_parts = key.split(".")
        if key_parts[0] == '_orig_mod':
            num = int(key_parts[2])
            replace_key = f"{key_parts[0]}.{key_parts[1]}.{num}"
        elif 'module' in key:
            num = int(key_parts[2])
            replace_key = f"{key_parts[0]}.{key_parts[1]}.{num}"
        else:
            num = int(key_parts[1])
            replace_key = f"{key_parts[0]}.{num}"
        print("replace_key:" + replace_key)
        new_key = key.replace(replace_key, module.mapping[num])
        print("new_key:" + new_key)
        new_state_dict[new_key] = value

    return new_state_dict

So it's related to these key mappings, but why I see it and someone else doesn't, I have no clue. As mentioned, if I don't run with accelerate, it doesn't happen.
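A possibly more general variant of the same workaround (my own sketch, not a vetted fix) is to strip known wrapper prefixes such as "module." (DDP) and "_orig_mod." (torch.compile) before splitting the key, so the hook doesn't need a separate branch for every wrapper:

# Sketch only: normalize known wrapper prefixes before indexing into "layers.<num>".
WRAPPER_PREFIXES = ("module.", "_orig_mod.")

def map_to(module, state_dict, *args, **kwargs):
    new_state_dict = {}
    for key, value in state_dict.items():
        # Strip any (possibly nested) wrapper prefixes from the key.
        stripped = key
        while stripped.startswith(WRAPPER_PREFIXES):
            for prefix in WRAPPER_PREFIXES:
                if stripped.startswith(prefix):
                    stripped = stripped[len(prefix):]
        num = int(stripped.split(".")[1])  # stripped key now looks like "layers.<num>.<...>"
        # Map "layers.<num>" back to the original attention-processor name.
        new_key = stripped.replace(f"layers.{num}", module.mapping[num], 1)
        new_state_dict[new_key] = value
    return new_state_dict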

@sayakpaul
Member

The checkpointing part works fine with that script under regular accelerate. This I can confirm because I ran it yesterday.

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
