[Feature] Finetune text encoder in train_text_to_image_lora #3912

Closed
wants to merge 15 commits

Conversation

@okotaku (Contributor) commented Jun 30, 2023

What does this PR do?

Fixes #3418

Adds a --train_text_encoder argument to train_text_to_image_lora.py so that the text encoder can also be fine-tuned with LoRA.
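For context, a minimal sketch of what such a flag typically looks like in the diffusers example scripts (hypothetical; the exact wiring in this PR may differ):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_text_encoder",
    action="store_true",
    help="Whether to additionally train LoRA layers for the text encoder.",
)
args = parser.parse_args([])  # e.g. pass ["--train_text_encoder"] to enable it

When the flag is set, the text encoder's LoRA parameters are typically added to the optimizer alongside the UNet's.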

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@patrickvonplaten (Contributor)

cc @sayakpaul here

@sayakpaul (Member)

Thanks for adding this pipeline. Could you also share some comparative results with and without text encoder LoRA fine-tuning? Just trying to gauge the effectiveness.

@okotaku (Contributor, Author) commented Jul 6, 2023

@sayakpaul I used this dataset for fine-tuning.

Only UNet:

test prompt = '1girl, X X'
[image: tmp9]

test prompt = '1girl, >_<'
[image: tmp8]

test prompt = '1girl, @_@'
[image: tmp11]

test prompt = '1girl, =_='
[image: tmp12]

With text encoder:

test prompt = '1girl, X X'
[image: tmp]

test prompt = '1girl, >_<'
[image: tmp2]

test prompt = '1girl, @_@'
[image: tmp3]

test prompt = '1girl, =_='
[image: tmp4]

test prompt = '1girl, Jitome'
[image: tmp5]

test prompt = '1girl, :I'
[image: tmp6]

test prompt = '1girl, ._.'
[image: tmp7]

@sayakpaul (Member)

Thanks for your hard work!

Maybe let's wait for #3778 to get merged, as there are some refactoring-related changes that will make this simpler.

Is that okay with you?

@okotaku (Contributor, Author) commented Jul 7, 2023

@sayakpaul No problem.

@patrickvonplaten (Contributor)

Gentle ping @sayakpaul - do you think we could merge this?

@okotaku (Contributor, Author) commented Jul 13, 2023

@patrickvonplaten Should we update the code based on that PR?

@sayakpaul (Member)

@okotaku my apologies for forgetting to ping you here.

Yes, let's refactor this PR based on the changes introduced in #3778. Happy to help you out in any way :)

@okotaku (Contributor, Author) commented Jul 15, 2023

@sayakpaul I updated the code, but it now raises the following error.
It looks like this issue.
Do you have any ideas on how to solve it?

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/ubuntu/Desktop/cartoon-name-generation/tools/train_text_to_image_lora_new.py:1054 in       │
│ <module>                                                                                         │
│                                                                                                  │
│   1051                                                                                           │
│   1052                                                                                           │
│   1053 if __name__ == "__main__":                                                                │
│ ❱ 1054 │   main()                                                                                │
│   1055                                                                                           │
│                                                                                                  │
│ /home/ubuntu/Desktop/cartoon-name-generation/tools/train_text_to_image_lora_new.py:961 in main   │
│                                                                                                  │
│    958 │   │   │   │   images = []                                                               │
│    959 │   │   │   │   for _ in range(args.num_validation_images):                               │
│    960 │   │   │   │   │   #with torch.cuda.amp.autocast(enabled=False): #dtype=torch.float32):  │
│ ❱  961 │   │   │   │   │   image = pipeline(args.validation_prompt, num_inference_steps=30, gen  │
│    962 │   │   │   │   │   images.append(image)                                                  │
│    963 │   │   │   │                                                                             │
│    964 │   │   │   │   for tracker in accelerator.trackers:                                      │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27 in                │
│ decorate_context                                                                                 │
│                                                                                                  │
│    24 │   │   @functools.wraps(func)                                                             │
│    25 │   │   def decorate_context(*args, **kwargs):                                             │
│    26 │   │   │   with self.clone():                                                             │
│ ❱  27 │   │   │   │   return func(*args, **kwargs)                                               │
│    28 │   │   return cast(F, decorate_context)                                                   │
│    29 │                                                                                          │
│    30 │   def _wrap_generator(self, func):                                                       │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_st │
│ able_diffusion.py:728 in __call__                                                                │
│                                                                                                  │
│   725 │   │   │   │   latent_model_input = self.scheduler.scale_model_input(latent_model_input   │
│   726 │   │   │   │                                                                              │
│   727 │   │   │   │   # predict the noise residual                                               │
│ ❱ 728 │   │   │   │   noise_pred = self.unet(                                                    │
│   729 │   │   │   │   │   latent_model_input,                                                    │
│   730 │   │   │   │   │   t,                                                                     │
│   731 │   │   │   │   │   encoder_hidden_states=prompt_embeds,                                   │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl    │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/accelerate/utils/operations.py:521 in forward    │
│                                                                                                  │
│   518 │   model_forward = ConvertOutputsToFp32(model_forward)                                    │
│   519 │                                                                                          │
│   520 │   def forward(*args, **kwargs):                                                          │
│ ❱ 521 │   │   return model_forward(*args, **kwargs)                                              │
│   522 │                                                                                          │
│   523 │   # To act like a decorator so that it can be popped when doing `extract_model_from_pa   │
│   524 │   forward.__wrapped__ = model_forward                                                    │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/accelerate/utils/operations.py:509 in __call__   │
│                                                                                                  │
│   506 │   │   update_wrapper(self, model_forward)                                                │
│   507 │                                                                                          │
│   508 │   def __call__(self, *args, **kwargs):                                                   │
│ ❱ 509 │   │   return convert_to_fp32(self.model_forward(*args, **kwargs))                        │
│   510 │                                                                                          │
│   511 │   def __getstate__(self):                                                                │
│   512 │   │   raise pickle.PicklingError(                                                        │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/amp/autocast_mode.py:14 in                 │
│ decorate_autocast                                                                                │
│                                                                                                  │
│    11 │   @functools.wraps(func)                                                                 │
│    12 │   def decorate_autocast(*args, **kwargs):                                                │
│    13 │   │   with autocast_instance:                                                            │
│ ❱  14 │   │   │   return func(*args, **kwargs)                                                   │
│    15 │   decorate_autocast.__script_unsupported = '@autocast() decorator is not supported in    │
│    16 │   return decorate_autocast                                                               │
│    17                                                                                            │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py:905 in     │
│ forward                                                                                          │
│                                                                                                  │
│   902 │   │   down_block_res_samples = (sample,)                                                 │
│   903 │   │   for downsample_block in self.down_blocks:                                          │
│   904 │   │   │   if hasattr(downsample_block, "has_cross_attention") and downsample_block.has   │
│ ❱ 905 │   │   │   │   sample, res_samples = downsample_block(                                    │
│   906 │   │   │   │   │   hidden_states=sample,                                                  │
│   907 │   │   │   │   │   temb=emb,                                                              │
│   908 │   │   │   │   │   encoder_hidden_states=encoder_hidden_states,                           │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl    │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py:993 in        │
│ forward                                                                                          │
│                                                                                                  │
│    990 │   │   │   │   )[0]                                                                      │
│    991 │   │   │   else:                                                                         │
│    992 │   │   │   │   hidden_states = resnet(hidden_states, temb)                               │
│ ❱  993 │   │   │   │   hidden_states = attn(                                                     │
│    994 │   │   │   │   │   hidden_states,                                                        │
│    995 │   │   │   │   │   encoder_hidden_states=encoder_hidden_states,                          │
│    996 │   │   │   │   │   cross_attention_kwargs=cross_attention_kwargs,                        │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl    │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/transformer_2d.py:291 in        │
│ forward                                                                                          │
│                                                                                                  │
│   288 │   │                                                                                      │
│   289 │   │   # 2. Blocks                                                                        │
│   290 │   │   for block in self.transformer_blocks:                                              │
│ ❱ 291 │   │   │   hidden_states = block(                                                         │
│   292 │   │   │   │   hidden_states,                                                             │
│   293 │   │   │   │   attention_mask=attention_mask,                                             │
│   294 │   │   │   │   encoder_hidden_states=encoder_hidden_states,                               │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl    │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/attention.py:170 in forward     │
│                                                                                                  │
│   167 │   │   │   │   self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self   │
│   168 │   │   │   )                                                                              │
│   169 │   │   │                                                                                  │
│ ❱ 170 │   │   │   attn_output = self.attn2(                                                      │
│   171 │   │   │   │   norm_hidden_states,                                                        │
│   172 │   │   │   │   encoder_hidden_states=encoder_hidden_states,                               │
│   173 │   │   │   │   attention_mask=encoder_attention_mask,                                     │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl    │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/attention_processor.py:321 in   │
│ forward                                                                                          │
│                                                                                                  │
│    318 │   │   # The `Attention` class can call different attention processors / attention func  │
│    319 │   │   # here we simply pass along all tensors to the selected processor class           │
│    320 │   │   # For standard processors that are defined here, `**cross_attention_kwargs` is e  │
│ ❱  321 │   │   return self.processor(                                                            │
│    322 │   │   │   self,                                                                         │
│    323 │   │   │   hidden_states,                                                                │
│    324 │   │   │   encoder_hidden_states=encoder_hidden_states,                                  │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/attention_processor.py:1224 in  │
│ __call__                                                                                         │
│                                                                                                  │
│   1221 │   │   key = attn.head_to_batch_dim(key).contiguous()                                    │
│   1222 │   │   value = attn.head_to_batch_dim(value).contiguous()                                │
│   1223 │   │                                                                                     │
│ ❱ 1224 │   │   hidden_states = xformers.ops.memory_efficient_attention(                          │
│   1225 │   │   │   query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=att  │
│   1226 │   │   )                                                                                 │
│   1227 │   │   hidden_states = attn.batch_to_head_dim(hidden_states)                             │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py:192 in             │
│ memory_efficient_attention                                                                       │
│                                                                                                  │
│   189 │   │   and options.                                                                       │
│   190 │   :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``                     │
│   191 │   """                                                                                    │
│ ❱ 192 │   return _memory_efficient_attention(                                                    │
│   193 │   │   Inputs(                                                                            │
│   194 │   │   │   query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale       │
│   195 │   │   ),                                                                                 │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py:290 in             │
│ _memory_efficient_attention                                                                      │
│                                                                                                  │
│   287 ) -> torch.Tensor:                                                                         │
│   288 │   # fast-path that doesn't require computing the logsumexp for backward computation      │
│   289 │   if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):             │
│ ❱ 290 │   │   return _memory_efficient_attention_forward(                                        │
│   291 │   │   │   inp, op=op[0] if op is not None else None                                      │
│   292 │   │   )                                                                                  │
│   293                                                                                            │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py:303 in             │
│ _memory_efficient_attention_forward                                                              │
│                                                                                                  │
│   300 def _memory_efficient_attention_forward(                                                   │
│   301 │   inp: Inputs, op: Optional[Type[AttentionFwOpBase]]                                     │
│   302 ) -> torch.Tensor:                                                                         │
│ ❱ 303 │   inp.validate_inputs()                                                                  │
│   304 │   output_shape = inp.normalize_bmhk()                                                    │
│   305 │   if op is None:                                                                         │
│   306 │   │   op = _dispatch_fw(inp)                                                             │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/xformers/ops/fmha/common.py:73 in                │
│ validate_inputs                                                                                  │
│                                                                                                  │
│    70 │   │   if any(x.device != self.query.device for x in qkv):                                │
│    71 │   │   │   raise ValueError("Query/Key/Value should all be on the same device")           │
│    72 │   │   if any(x.dtype != self.query.dtype for x in qkv):                                  │
│ ❱  73 │   │   │   raise ValueError(                                                              │
│    74 │   │   │   │   "Query/Key/Value should all have the same dtype\n"                         │
│    75 │   │   │   │   f"  query.dtype: {self.query.dtype}\n"                                     │
│    76 │   │   │   │   f"  key.dtype  : {self.key.dtype}\n"                                       │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Query/Key/Value should all have the same dtype
  query.dtype: torch.float32
  key.dtype  : torch.float16
  value.dtype: torch.float16
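For context: xformers' memory_efficient_attention requires the query, key and value tensors to share a dtype. In cross-attention the query comes from the UNet hidden states while key/value come from the text-encoder embeddings, so under mixed precision they can end up as float32 and float16 respectively. A quick diagnostic (hypothetical, not part of the PR) that could be run before the validation loop:

# Hypothetical diagnostic: check which pipeline components ended up in which dtype.
print("unet dtype:", pipeline.unet.dtype)
print("text_encoder dtype:", pipeline.text_encoder.dtype)
print("vae dtype:", pipeline.vae.dtype)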

@sayakpaul (Member)

Hi,

Could you also try out the solutions provided in that thread to see if the errors persist?

Also, what happens if we do this on PyTorch 2.0, taking advantage of SDPA and disabling xformers?

@okotaku (Contributor, Author) commented Jul 15, 2023

When disabling xformers, I get the following error:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/ubuntu/Desktop/cartoon-name-generation/tools/train_text_to_image_lora_new.py:1054 in       │
│ <module>                                                                                         │
│                                                                                                  │
│   1051                                                                                           │
│   1052                                                                                           │
│   1053 if __name__ == "__main__":                                                                │
│ ❱ 1054 │   main()                                                                                │
│   1055                                                                                           │
│                                                                                                  │
│ /home/ubuntu/Desktop/cartoon-name-generation/tools/train_text_to_image_lora_new.py:961 in main   │
│                                                                                                  │
│    958 │   │   │   │   images = []                                                               │
│    959 │   │   │   │   for _ in range(args.num_validation_images):                               │
│    960 │   │   │   │   │   #with torch.cuda.amp.autocast(enabled=False): #dtype=torch.float32):  │
│ ❱  961 │   │   │   │   │   image = pipeline(args.validation_prompt, num_inference_steps=30, gen  │
│    962 │   │   │   │   │   images.append(image)                                                  │
│    963 │   │   │   │                                                                             │
│    964 │   │   │   │   for tracker in accelerator.trackers:                                      │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27 in                │
│ decorate_context                                                                                 │
│                                                                                                  │
│    24 │   │   @functools.wraps(func)                                                             │
│    25 │   │   def decorate_context(*args, **kwargs):                                             │
│    26 │   │   │   with self.clone():                                                             │
│ ❱  27 │   │   │   │   return func(*args, **kwargs)                                               │
│    28 │   │   return cast(F, decorate_context)                                                   │
│    29 │                                                                                          │
│    30 │   def _wrap_generator(self, func):                                                       │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_st │
│ able_diffusion.py:755 in __call__                                                                │
│                                                                                                  │
│   752 │   │   │   │   │   │   callback(i, t, latents)                                            │
│   753 │   │                                                                                      │
│   754 │   │   if not output_type == "latent":                                                    │
│ ❱ 755 │   │   │   image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dic   │
│   756 │   │   │   image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embe   │
│   757 │   │   else:                                                                              │
│   758 │   │   │   image = latents                                                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/utils/accelerate_utils.py:46 in        │
│ wrapper                                                                                          │
│                                                                                                  │
│   43 │   def wrapper(self, *args, **kwargs):                                                     │
│   44 │   │   if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):             │
│   45 │   │   │   self._hf_hook.pre_forward(self)                                                 │
│ ❱ 46 │   │   return method(self, *args, **kwargs)                                                │
│   47 │                                                                                           │
│   48 │   return wrapper                                                                          │
│   49                                                                                             │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/autoencoder_kl.py:264 in decode │
│                                                                                                  │
│   261 │   │   │   decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]      │
│   262 │   │   │   decoded = torch.cat(decoded_slices)                                            │
│   263 │   │   else:                                                                              │
│ ❱ 264 │   │   │   decoded = self._decode(z).sample                                               │
│   265 │   │                                                                                      │
│   266 │   │   if not return_dict:                                                                │
│   267 │   │   │   return (decoded,)                                                              │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/diffusers/models/autoencoder_kl.py:250 in        │
│ _decode                                                                                          │
│                                                                                                  │
│   247 │   │   if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] >   │
│   248 │   │   │   return self.tiled_decode(z, return_dict=return_dict)                           │
│   249 │   │                                                                                      │
│ ❱ 250 │   │   z = self.post_quant_conv(z)                                                        │
│   251 │   │   dec = self.decoder(z)                                                              │
│   252 │   │                                                                                      │
│   253 │   │   if not return_dict:                                                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1194 in _call_impl    │
│                                                                                                  │
│   1191 │   │   # this function, and just call forward.                                           │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1195 │   │   # Do not call functions when jit is used                                          │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:463 in forward          │
│                                                                                                  │
│    460 │   │   │   │   │   │   self.padding, self.dilation, self.groups)                         │
│    461 │                                                                                         │
│    462 │   def forward(self, input: Tensor) -> Tensor:                                           │
│ ❱  463 │   │   return self._conv_forward(input, self.weight, self.bias)                          │
│    464                                                                                           │
│    465 class Conv3d(_ConvNd):                                                                    │
│    466 │   __doc__ = r"""Applies a 3D convolution over an input signal composed of several inpu  │
│                                                                                                  │
│ /home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:459 in _conv_forward    │
│                                                                                                  │
│    456 │   │   │   return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=sel  │
│    457 │   │   │   │   │   │   │   weight, bias, self.stride,                                    │
│    458 │   │   │   │   │   │   │   _pair(0), self.dilation, self.groups)                         │
│ ❱  459 │   │   return F.conv2d(input, weight, bias, self.stride,                                 │
│    460 │   │   │   │   │   │   self.padding, self.dilation, self.groups)                         │
│    461 │                                                                                         │
│    462 │   def forward(self, input: Tensor) -> Tensor:                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Input type (float) and bias type (c10::Half) should be the same
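Here the latents reaching vae.decode() are float32 while the VAE weights are float16. One common workaround (a hypothetical sketch, not what this PR ended up doing) is to keep the VAE in float32 for the validation pass:

import torch

# Hypothetical workaround: upcast the VAE so its weights match the float32 latents.
pipeline.vae.to(dtype=torch.float32)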

@okotaku (Contributor, Author) commented Jul 15, 2023

I tried the following patterns, but they all failed.

images = []
for _ in range(args.num_validation_images):
    with torch.cuda.amp.autocast():
        image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
        images.append(image)

images = []
for _ in range(args.num_validation_images):
    with torch.autocast('cuda'):
        image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
        images.append(image)

images = []
for _ in range(args.num_validation_images):
    with torch.cuda.amp.autocast(dtype=weight_dtype):
        image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
        images.append(image)

images = []
for _ in range(args.num_validation_images):
    image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
    images.append(image)

@okotaku (Contributor, Author) commented Jul 15, 2023

@sayakpaul PyTorch 2.0 and LoRAAttnProcessor2_0 work well.
This PR is ready for review.
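For reference, diffusers training scripts of this era typically pick the processor based on whether PyTorch 2.0's scaled_dot_product_attention is available; a sketch (assuming diffusers >= 0.18, the exact code in this PR may differ):

import torch.nn.functional as F
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

# Prefer the SDPA-based LoRA processor on PyTorch 2.0, fall back to the classic one otherwise.
lora_attn_processor_class = (
    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)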

@sayakpaul (Member)

Thanks for letting me know and for your efforts.

Could we maybe dive a bit deeper to see what the issue is? For example, does the issue persist when not training the text encoder?

@okotaku (Contributor, Author) commented Jul 15, 2023

For example does the issue persist when not training the text encoder?

It also persists when training only the UNet.

@sayakpaul (Member)

Okay. Can we verify if the version in the main branch works with the following settings?

  • PyTorch 1.13.1 and xformers
  • PyTorch 2.0 and SDPA

Just trying to narrow down the issue. If you have other ideas to try out, please let me know.

@okotaku (Contributor, Author) commented Jul 15, 2023

PyTorch 2.0 and SDPA work well.
PyTorch 1.13.1 and xformers fail.

@sayakpaul (Member)

Thanks for reporting.

This is indeed weird as the main version should support both. Did you try it out on the main version of the script or on your version?

@okotaku (Contributor, Author) commented Jul 16, 2023

The error is raised when I use my branch of the script. The script from the main branch works well.

@sayakpaul (Member)

Oh okay. Let's try to investigate the differences then :-)

@okotaku (Contributor, Author) commented Jul 16, 2023

The error was resolved by changing where xformers is enabled.
I don't know the reason :)

@okotaku (Contributor, Author) commented Jul 16, 2023

@sayakpaul This PR is ready.

Comment on lines +474 to +479
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
    logger.warn(
        "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
    )
unet.enable_xformers_memory_efficient_attention()
A Member replied: ❤️

@sayakpaul (Member) left a comment:
Excellent work here. Thanks so much for all your experiments and iterations!

Let's also update the README of the example so that readers are aware that text encoder training is supported. We also need to add a test case for the text encoder training to test_examples.py. Then this PR is good to be shipped 🚀

Also, could we do a full run with the changes to ensure the results you're getting are as expected?
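For the test case, something along these lines could work (a rough sketch only; test_examples.py has its own launcher helpers, and the model and dataset ids below are placeholders):

import subprocess
import sys
import tempfile


def test_text_to_image_lora_with_text_encoder():
    with tempfile.TemporaryDirectory() as tmpdir:
        # Run the example script for a couple of steps with the new flag enabled.
        subprocess.run(
            [
                sys.executable,
                "examples/text_to_image/train_text_to_image_lora.py",
                "--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe",
                "--dataset_name=hf-internal-testing/dummy_image_text_data",
                "--resolution=64",
                "--train_batch_size=1",
                "--max_train_steps=2",
                "--train_text_encoder",
                f"--output_dir={tmpdir}",
            ],
            check=True,
        )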

@okotaku (Contributor, Author) commented Jul 18, 2023

@sayakpaul

Also, could we do a full run with the changes to ensure the results you're getting are as expected?

https://civitai.com/models/25613/classic-anime-expressions

You can download the images from the page's Training Images tab.
The CSV file is here.

metadata.csv

The params are as follows:

accelerate launch tools/train_text_to_image_lora.py \
  --pretrained_model_name_or_path=models/anythingv5  \
  --train_data_dir=data/ExpressionTraining \
  --image_column=image \
  --center_crop --random_flip \
  --output_dir=work_dirs/ExpressionTraining \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --num_train_epochs=100 \
  --learning_rate=1e-4 \
  --lr_scheduler="cosine_with_restarts" --lr_warmup_steps=150 \
  --validation_prompt="1girl, >_<" \
  --checkpointing_steps=500 \
  --rank=128 \
  --snr_gamma=5 \
  --train_text_encoder \
  --seed="0"

Comment on lines 185 to 196
from huggingface_hub.repocard import RepoCard
from diffusers import StableDiffusionPipeline
import torch

lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.load_lora_weights(lora_model_id)
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
A Member commented:
Let's use a LoRA you trained with this script? We'd also need to update the example prompt :)
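An updated snippet could look roughly like this (the repo id below is a placeholder for a LoRA trained with this script; the rest mirrors the existing example):

from huggingface_hub.repocard import RepoCard
from diffusers import StableDiffusionPipeline
import torch

lora_model_id = "your-username/anime-expressions-lora"  # placeholder repo id
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights(lora_model_id)
image = pipe("1girl, >_<", num_inference_steps=25).images[0]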

@sayakpaul (Member) left a comment:
Looks amazing. Thanks for iterating.

Additionally, I'd also include a note about the --train_text_encoder flag in the README here.

@patrickvonplaten (Contributor)

@williamberman could you also take a look here?

Comment on lines +470 to +482
if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        import xformers

        xformers_version = version.parse(xformers.__version__)
        if xformers_version == version.parse("0.0.16"):
            logger.warn(
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
            )
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")

A Contributor commented:
Unless I'm missing something, in the future it'd be helpful not to move a code block unrelated to the PR, as it makes the diff harder to read :)

@okotaku (Contributor, Author) replied:
#3912 (comment)
#3912 (comment)

By moving the code block, I avoid this error.
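A minimal sketch of the idea behind the move (an assumption about the intent; the actual placement is in the diff above):

# Hypothetical ordering sketch, not the exact diff: enable xformers on the UNet
# before the LoRA attention processors are attached and before accelerator.prepare(),
# so the processors set up later are not replaced or left in a mismatched dtype.
if args.enable_xformers_memory_efficient_attention:
    unet.enable_xformers_memory_efficient_attention()

# ... add the LoRA layers to unet / text_encoder and call accelerator.prepare(...) afterwards.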

images.append(
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
)
with torch.cuda.amp.autocast():
A Contributor commented:
Did something change that required autocast to be added?

@okotaku (Contributor, Author) replied:
#3912 (comment)

It's also related to this error. You can check this thread.

@williamberman (Contributor)

Looks basically good, a few small questions :)

@sayakpaul requested a review from williamberman, July 21, 2023 02:54
@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions bot added the "stale" (Issues that haven't received updates) label on Aug 14, 2023
@github-actions bot closed this on Aug 22, 2023
Labels: stale (Issues that haven't received updates)
Linked issue: add text-encoder part in train_text_to_image_lora.py
5 participants