[LoRA] feat: add lora attention processor for pt 2.0. #3594
Changes from all commits
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import warnings
 from typing import Callable, Optional, Union

 import torch
@@ -166,7 +165,8 @@ def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
         is_lora = hasattr(self, "processor") and isinstance(
-            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor)
+            self.processor,
+            (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
             self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
@@ -200,14 +200,6 @@ def set_use_memory_efficient_attention_xformers(
                     "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                     " only available for GPU "
                 )
-            elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
-                warnings.warn(
-                    "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
-                    "We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) "
-                    "introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
-                    "back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
-                    "native efficient flash attention."
-                )
             else:
                 try:
                     # Make sure we can run the memory efficient attention
@@ -220,6 +212,8 @@ def set_use_memory_efficient_attention_xformers(
                     raise e

             if is_lora:
+                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
+                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
                 processor = LoRAXFormersAttnProcessor(
                     hidden_size=self.processor.hidden_size,
                     cross_attention_dim=self.processor.cross_attention_dim,

Review comment (on the TODO above): Actually, I think for now let's give the user full freedom over what to use.
@@ -252,7 +246,10 @@ def set_use_memory_efficient_attention_xformers(
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
             if is_lora:
-                processor = LoRAAttnProcessor(
+                attn_processor_class = (
+                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+                )
+                processor = attn_processor_class(
                     hidden_size=self.processor.hidden_size,
                     cross_attention_dim=self.processor.cross_attention_dim,
                     rank=self.processor.rank,
@@ -548,6 +545,8 @@ class LoRAAttnProcessor(nn.Module):
             The number of channels in the `encoder_hidden_states`.
         rank (`int`, defaults to 4):
             The dimension of the LoRA update matrices.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
     """

     def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
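For context, Kohya/A1111-style LoRAs apply `alpha` as a `network_alpha / rank` multiplier on the low-rank update. A simplified, illustrative layer (not a verbatim copy of diffusers' `LoRALinearLayer`) showing that scaling:

```python
import torch.nn as nn


class SimplifiedLoRALinearLayer(nn.Module):
    """Illustrative LoRA layer showing how network_alpha rescales the update."""

    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)  # low-rank "A" projection
        self.up = nn.Linear(rank, out_features, bias=False)   # low-rank "B" projection
        self.network_alpha = network_alpha
        self.rank = rank

    def forward(self, hidden_states):
        update = self.up(self.down(hidden_states))
        if self.network_alpha is not None:
            # Kohya (A1111) convention: scale the update by alpha / rank
            update = update * (self.network_alpha / self.rank)
        return update
```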
@@ -843,6 +842,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
             The number of channels in the `encoder_hidden_states`.
         rank (`int`, defaults to 4):
             The dimension of the LoRA update matrices.
+
     """

     def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
@@ -1162,6 +1162,9 @@ class LoRAXFormersAttnProcessor(nn.Module):
             [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
             use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
             operator.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
+
     """

     def __init__(
@@ -1236,6 +1239,97 @@ def __call__(
         return hidden_states


+class LoRAAttnProcessor2_0(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
+    attention.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
+        super().__init__()
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        inner_dim = hidden_states.shape[-1]
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
+
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class CustomDiffusionXFormersAttnProcessor(nn.Module):
     r"""
     Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
@@ -1520,6 +1614,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
     XFormersAttnAddedKVProcessor,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
+    LoRAAttnProcessor2_0,
     LoRAAttnAddedKVProcessor,
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
The second file in the diff relaxes the tolerances in the LoRA save/load tests:
@@ -261,7 +261,7 @@ def test_lora_save_load(self):
         with torch.no_grad():
             new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

-        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - new_sample).abs().max() < 5e-4

         # LoRA and no LoRA should NOT be the same
         assert (sample - old_sample).abs().max() > 1e-4

Review comment (on the loosened tolerance): Because of PyTorch SDPA.
Reply: Ok for me!
@@ -295,7 +295,7 @@ def test_lora_save_load_safetensors(self):
         with torch.no_grad():
             new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

-        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - new_sample).abs().max() < 3e-4

         # LoRA and no LoRA should NOT be the same
         assert (sample - old_sample).abs().max() > 1e-4
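The loosened tolerances reflect small numerical differences between SDPA and the classic attention path. A rough local equivalence check one could run (a sketch, not part of the PR's test suite; it shares weights between the two processors before comparing):

```python
import torch
from diffusers.models.attention_processor import (
    Attention,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
)

torch.manual_seed(0)
attn = Attention(query_dim=64, heads=4, dim_head=16)

proc_classic = LoRAAttnProcessor(hidden_size=64, rank=4)
proc_sdpa = LoRAAttnProcessor2_0(hidden_size=64, rank=4)

# Randomize the zero-initialized up-projections so the LoRA update is non-trivial,
# then copy the weights into the SDPA variant so both compute the same function.
for layer in (proc_classic.to_q_lora, proc_classic.to_k_lora, proc_classic.to_v_lora, proc_classic.to_out_lora):
    torch.nn.init.normal_(layer.up.weight)
proc_sdpa.load_state_dict(proc_classic.state_dict())

hidden_states = torch.randn(1, 16, 64)
with torch.no_grad():
    out_classic = proc_classic(attn, hidden_states, scale=0.5)
    out_sdpa = proc_sdpa(attn, hidden_states, scale=0.5)

# Expect agreement only up to SDPA's numerical tolerance, hence the 5e-4 / 3e-4 above.
print((out_classic - out_sdpa).abs().max())
```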
Review comment (on the removed xFormers warning): Decided to remove this rather confusing warning message, but let me know if you think otherwise. We still want our users to take advantage of xformers for LoRA, Custom Diffusion, etc., even when the rest of the attention processors run with SDPA.

Reply: Agree! In my experiments, xformers is still sometimes faster and more memory efficient.