[LoRA] feat: add lora attention processor for pt 2.0. #3594

Merged · 22 commits · Jun 6, 2023
3 changes: 3 additions & 0 deletions docs/source/en/api/attnprocessor.mdx
@@ -11,6 +11,9 @@ An attention processor is a class for applying different types of attention mechanisms.
## LoRAAttnProcessor
[[autodoc]] models.attention_processor.LoRAAttnProcessor

## LoRAAttnProcessor2_0
[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0

## CustomDiffusionAttnProcessor
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor

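
As a quick orientation for the new docs entry, here is a minimal, hypothetical sketch of instantiating the processor directly; the hidden size, cross-attention dimension, and rank below are illustrative values, not numbers taken from this PR:

```python
import torch.nn.functional as F

from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

# Fall back to the plain LoRA processor when PyTorch < 2.0 (no scaled_dot_product_attention).
lora_attn_processor_class = (
    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)

# Illustrative values: hidden_size/cross_attention_dim depend on the attention layer,
# rank and network_alpha are user choices.
processor = lora_attn_processor_class(hidden_size=320, cross_attention_dim=768, rank=4)
```
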
6 changes: 4 additions & 2 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -55,6 +55,7 @@
AttnAddedKVProcessor2_0,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
@@ -844,8 +845,9 @@ def main(args):
if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
lora_attn_processor_class = LoRAAttnAddedKVProcessor
else:
lora_attn_processor_class = LoRAAttnProcessor

lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
unet_lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
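
To show where this class selection slots into the training script, here is a condensed, hedged sketch of how the per-layer LoRA processors are typically assembled and wrapped; it assumes a loaded `UNet2DConditionModel` named `unet` and omits the added-KV branch handled in the full script:

```python
import torch.nn.functional as F

from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

lora_attn_processor_class = (
    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)

unet_lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention ("attn1") layers have no cross-attention conditioning.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    unet_lora_attn_procs[name] = lora_attn_processor_class(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )

unet.set_attn_processor(unet_lora_attn_procs)
unet_lora_layers = AttnProcsLayers(unet_lora_attn_procs)  # the parameters that get trained and saved
```
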
8 changes: 6 additions & 2 deletions src/diffusers/loaders.py
@@ -18,6 +18,7 @@
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from .models.attention_processor import (
@@ -27,6 +28,7 @@
CustomDiffusionXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
SlicedAttnAddedKVProcessor,
XFormersAttnProcessor,
@@ -287,7 +289,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
attn_processor_class = LoRAXFormersAttnProcessor
else:
attn_processor_class = LoRAAttnProcessor
attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)

attn_processors[key] = attn_processor_class(
hidden_size=hidden_size,
@@ -927,11 +931,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di

# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
logger.info(f"Loading {self.text_encoder_name}.")
text_encoder_lora_state_dict = {
k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {self.text_encoder_name}.")
attn_procs_text_encoder = self._load_text_encoder_attn_procs(
text_encoder_lora_state_dict, network_alpha=network_alpha
)
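
From the user's side the dispatch is transparent. A hypothetical loading example (the model ID and LoRA path are placeholders) might look like:

```python
import torch

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# load_lora_weights() loads the UNet LoRA layers via load_attn_procs(); on PyTorch 2.0
# the attention processors resolve to LoRAAttnProcessor2_0, otherwise LoRAAttnProcessor.
pipe.load_lora_weights("path/to/lora_weights")

image = pipe(
    "a photo of a sks dog", num_inference_steps=25, cross_attention_kwargs={"scale": 0.8}
).images[0]
```
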
117 changes: 106 additions & 11 deletions src/diffusers/models/attention_processor.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Callable, Optional, Union

import torch
@@ -166,7 +165,8 @@ def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
is_lora = hasattr(self, "processor") and isinstance(
self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor)
self.processor,
(LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
)
is_custom_diffusion = hasattr(self, "processor") and isinstance(
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
@@ -200,14 +200,6 @@ def set_use_memory_efficient_attention_xformers(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
warnings.warn(
"You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
"We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) "
"introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
"back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
"native efficient flash attention."
)
Comment on lines -203 to -210

Member (Author): Decided to remove this rather confusing warning message, but let me know if you think otherwise. We still want our users to take advantage of xFormers for LoRA, Custom Diffusion, etc. even when the rest of the attention processors run with SDPA.

Contributor: Agree! In my experiments, xFormers is still also sometimes faster and more memory efficient.

else:
try:
# Make sure we can run the memory efficient attention
@@ -220,6 +212,8 @@
raise e

if is_lora:
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
Contributor: Actually, I think for now let's give the user full freedom over what to use.

processor = LoRAXFormersAttnProcessor(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
@@ -252,7 +246,10 @@ def set_use_memory_efficient_attention_xformers(
processor = XFormersAttnProcessor(attention_op=attention_op)
else:
if is_lora:
processor = LoRAAttnProcessor(
attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
processor = attn_processor_class(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
@@ -548,6 +545,8 @@ class LoRAAttnProcessor(nn.Module):
The number of channels in the `encoder_hidden_states`.
rank (`int`, defaults to 4):
The dimension of the LoRA update matrices.
network_alpha (`int`, *optional*):
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
"""

def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
@@ -843,6 +842,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
The number of channels in the `encoder_hidden_states`.
rank (`int`, defaults to 4):
The dimension of the LoRA update matrices.

"""

def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
@@ -1162,6 +1162,9 @@ class LoRAXFormersAttnProcessor(nn.Module):
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
operator.
network_alpha (`int`, *optional*):
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.

"""

def __init__(
@@ -1236,6 +1239,97 @@ def __call__(
return hidden_states


class LoRAAttnProcessor2_0(nn.Module):
r"""
Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
attention.

Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`, *optional*):
The number of channels in the `encoder_hidden_states`.
rank (`int`, defaults to 4):
The dimension of the LoRA update matrices.
network_alpha (`int`, *optional*):
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
"""

def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank

self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
inner_dim = hidden_states.shape[-1]

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states


class CustomDiffusionXFormersAttnProcessor(nn.Module):
r"""
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
@@ -1520,6 +1614,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
XFormersAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
LoRAAttnAddedKVProcessor,
CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
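
Tying this back to the review discussion above: on PyTorch 2.0 the LoRA layers now default to the SDPA-backed processor, but explicitly enabling xFormers still swaps in `LoRAXFormersAttnProcessor`. A hypothetical check, assuming a LoRA checkpoint that covers every attention layer and that xformers is installed:

```python
import torch.nn.functional as F

from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import (
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
)

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
pipe.load_lora_weights("path/to/lora_weights")  # placeholder path

# Default on PyTorch 2.0: the SDPA-backed LoRA processor.
expected = LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
assert all(isinstance(proc, expected) for proc in pipe.unet.attn_processors.values())

# Users keep full control: opting into xFormers switches the LoRA layers to its processor.
pipe.enable_xformers_memory_efficient_attention()
assert all(
    isinstance(proc, LoRAXFormersAttnProcessor) for proc in pipe.unet.attn_processors.values()
)
```
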
19 changes: 16 additions & 3 deletions tests/models/test_lora_layers.py
@@ -19,6 +19,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
@@ -28,6 +29,7 @@
AttnProcessor,
AttnProcessor2_0,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
@@ -46,16 +48,24 @@ def create_unet_lora_layers(unet: nn.Module):
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
unet_lora_layers = AttnProcsLayers(lora_attn_procs)
return lora_attn_procs, unet_lora_layers


def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
text_lora_attn_procs = {}
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
for name, module in text_encoder.named_modules():
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
text_lora_attn_procs[name] = LoRAAttnProcessor(
text_lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=module.out_proj.out_features, cross_attention_dim=None
)
return text_lora_attn_procs
@@ -368,7 +378,10 @@ def test_lora_unet_attn_processors(self):
# check if lora attention processors are used
for _, module in sd_pipe.unet.named_modules():
if isinstance(module, Attention):
self.assertIsInstance(module.processor, LoRAAttnProcessor)
attn_proc_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
self.assertIsInstance(module.processor, attn_proc_class)

@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
def test_lora_unet_attn_processors_with_xformers(self):
4 changes: 2 additions & 2 deletions tests/models/test_models_unet_3d_condition.py
@@ -261,7 +261,7 @@ def test_lora_save_load(self):
with torch.no_grad():
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

assert (sample - new_sample).abs().max() < 1e-4
assert (sample - new_sample).abs().max() < 5e-4
Member (Author): The relaxed tolerance is because of PyTorch SDPA.

Contributor: OK for me!


# LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 1e-4
@@ -295,7 +295,7 @@ def test_lora_save_load_safetensors(self):
with torch.no_grad():
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

assert (sample - new_sample).abs().max() < 1e-4
assert (sample - new_sample).abs().max() < 3e-4

# LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 1e-4