
Commit d6ac5ae

Authored by patrickvonplaten, with cloneofsimo, pcuenca, and patil-suraj as co-authors
[LoRA] Add LoRA training script (huggingface#1884)
* [Lora] first upload
* add first lora version
* upload
* more
* first training
* up
* correct
* improve
* finish loaders and inference
* up
* up
* fix more
* up
* finish more
* finish more
* up
* up
* change year
* revert year change
* Change lines
* Add cloneofsimo as co-author.

Co-authored-by: Simo Ryu <[email protected]>

* finish
* fix docs
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>

* upload
* finish

Co-authored-by: Simo Ryu <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
1 parent 465f59f commit d6ac5ae

9 files changed: +467 -107 lines

dependency_versions_check.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2020 The HuggingFace Team. All rights reserved.
+# Copyright 2022 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

loaders.py

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import defaultdict
from typing import Callable, Dict, Union

import torch

from .models.cross_attention import LoRACrossAttnProcessor
from .models.modeling_utils import _get_model_file
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging


logger = logging.get_logger(__name__)


LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"


class AttnProcsLayers(torch.nn.Module):
    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = {k: v for k, v in enumerate(state_dict.keys())}
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # we add a hook to state_dict() and load_state_dict() so that the
        # naming fits with `unet.attn_processors`
        def map_to(module, state_dict, *args, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                num = int(key.split(".")[1])  # 0 is always "layers"
                new_key = key.replace(f"layers.{num}", module.mapping[num])
                new_state_dict[new_key] = value

            return new_state_dict

        def map_from(module, state_dict, *args, **kwargs):
            all_keys = list(state_dict.keys())
            for key in all_keys:
                replace_key = key.split(".processor")[0] + ".processor"
                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)


class UNet2DConditionLoadersMixin:
    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        r"""
        Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be
        defined in
        [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
        and be a `torch.nn.Module` class.

        <Tip warning={true}>

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
                    - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
                      `./my_model_directory/`.
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
                generated when running `diffusers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            subfolder (`str`, *optional*, defaults to `""`):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.

            mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety
                of the source; please refer to the mirror site for more information.

        <Tip>

        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
        models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
        this method in a firewalled environment.

        </Tip>
        """

        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )
            state_dict = torch.load(model_file, map_location="cpu")
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        # fill attn processors
        attn_processors = {}

        is_lora = all("lora" in k for k in state_dict.keys())

        if is_lora:
            lora_grouped_dict = defaultdict(dict)
            for key, value in state_dict.items():
                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
                lora_grouped_dict[attn_processor_key][sub_key] = value

            for key, value_dict in lora_grouped_dict.items():
                rank = value_dict["to_k_lora.down.weight"].shape[0]
                cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

                attn_processors[key] = LoRACrossAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
                )
                attn_processors[key].load_state_dict(value_dict)

        else:
            raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")

        # set correct dtype & device
        attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}

        # set layers
        self.set_attn_processor(attn_processors)

    def save_attn_procs(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        weights_name: str = LORA_WEIGHT_NAME,
        save_function: Callable = None,
    ):
        r"""
        Save an attention processor to a directory, so that it can be re-loaded using the
        [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training (e.g.
                on TPUs), where this function needs to be called on all processes. In this case, set
                `is_main_process=True` only on the main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training (e.g. on TPUs),
                when `torch.save` needs to be replaced by another method. Can be configured with the environment
                variable `DIFFUSERS_SAVE_MODE`.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if save_function is None:
            save_function = torch.save

        os.makedirs(save_directory, exist_ok=True)

        model_to_save = AttnProcsLayers(self.attn_processors)

        # Save the model
        state_dict = model_to_save.state_dict()

        # Clean the folder from a previous save
        for filename in os.listdir(save_directory):
            full_filename = os.path.join(save_directory, filename)
            # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
            # in distributed settings to avoid race conditions.
            weights_no_suffix = weights_name.replace(".bin", "")
            if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
                os.remove(full_filename)

        # Save the model
        save_function(state_dict, os.path.join(save_directory, weights_name))

        logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")

models/cross_attention.py

Lines changed: 97 additions & 0 deletions
@@ -246,6 +246,68 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
        return hidden_states


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        self.scale = 1.0

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        return up_hidden_states.to(orig_dtype)


class LoRACrossAttnProcessor(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class CrossAttnAddedKVProcessor:
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
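
Editorial note (not part of the commit): the zero-initialized `up` weight is what makes LoRA safe to bolt onto a pretrained model. At the start of training the low-rank branch outputs exactly zero, so `to_k(x) + scale * to_k_lora(x)` equals the original projection, and the learned update `W_up @ W_down` can never exceed rank `rank`. A standalone sketch illustrating this (the `TinyLoRA` name is made up for illustration):

    import torch
    import torch.nn as nn

    class TinyLoRA(nn.Module):
        # Mirrors LoRALinearLayer above: a rank-r bottleneck with zero-initialized up-projection.
        def __init__(self, in_features, out_features, rank=4):
            super().__init__()
            self.down = nn.Linear(in_features, rank, bias=False)
            self.up = nn.Linear(rank, out_features, bias=False)
            nn.init.normal_(self.down.weight, std=1 / rank)
            nn.init.zeros_(self.up.weight)

        def forward(self, x):
            return self.up(self.down(x))

    lora = TinyLoRA(320, 320, rank=4)
    x = torch.randn(2, 77, 320)

    # At initialization the LoRA branch is a no-op, so the pretrained output is unchanged.
    assert torch.all(lora(x) == 0)

    # The effective weight update is W_up @ W_down, a 320x320 matrix of rank at most 4
    # (rank 0 here, since W_up is still zero; at most `rank` once training has updated it).
    delta_w = lora.up.weight @ lora.down.weight
    print(delta_w.shape, torch.linalg.matrix_rank(delta_w))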
@@ -312,6 +374,41 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class LoRAXFormersCrossAttnProcessor(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim, rank=4):
        super().__init__()

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
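
Editorial note (not part of the commit): this file only defines the processor classes; wiring them into a model is left to the training script. Below is a hedged sketch of how per-layer LoRA processors could be attached to a diffusers `UNet2DConditionModel`, following the key layout that `load_attn_procs` above expects. The checkpoint id is a placeholder, and the exact helper logic in the training script added by this PR may differ.

    from diffusers import UNet2DConditionModel
    from diffusers.loaders import AttnProcsLayers
    from diffusers.models.cross_attention import LoRACrossAttnProcessor

    # Placeholder checkpoint id; in practice the UNet comes from the pipeline being fine-tuned.
    unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        # attn1 is self-attention (no cross-attention dim); attn2 attends to the text-encoder states.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRACrossAttnProcessor(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=4
        )

    # Replace the default processors; only these small LoRA layers receive gradients, and
    # AttnProcsLayers gives the optimizer (and save_attn_procs) a single module to work with.
    unet.set_attn_processor(lora_attn_procs)
    lora_layers = AttnProcsLayers(unet.attn_processors)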
