Commit 1f4deb6

Adding support for `safetensors` and LoRA (#2448)

* Adding support for `safetensors` and LoRA.
* Adding metadata.
1 parent f20c8f5 commit 1f4deb6
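
For context, a minimal usage sketch of what this commit enables, based on the new `save_attn_procs(..., safe_serialization=True)` / `load_attn_procs` code paths shown in the diff below. The model id, directory name, and the untrained LoRA setup are illustrative, not part of the commit.

from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import LoRACrossAttnProcessor

# Illustrative checkpoint; any UNet2DConditionModel works the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Attach (untrained) LoRA processors to every attention layer, mirroring the new test below.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )
unet.set_attn_processor(lora_attn_procs)

# New in this commit: write pytorch_lora_weights.safetensors instead of the .bin file.
unet.save_attn_procs("sd-lora-example", safe_serialization=True)

# Loading now prefers the .safetensors file when safetensors is installed
# and falls back to pytorch_lora_weights.bin otherwise.
unet.load_attn_procs("sd-lora-example")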

File tree

src/diffusers/loaders.py
tests/models/test_models_unet_2d_condition.py

2 files changed: +121 −18


src/diffusers/loaders.py

Lines changed: 61 additions & 18 deletions
@@ -19,13 +19,18 @@
 
 from .models.cross_attention import LoRACrossAttnProcessor
 from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging
+from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging
+
+
+if is_safetensors_available():
+    import safetensors
 
 
 logger = logging.get_logger(__name__)
 
 
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
 
 
 class AttnProcsLayers(torch.nn.Module):
@@ -136,28 +141,53 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
-        weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
+        weight_name = kwargs.pop("weight_name", None)
 
         user_agent = {
             "file_type": "attn_procs_weights",
             "framework": "pytorch",
         }
 
+        model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            model_file = _get_model_file(
-                pretrained_model_name_or_path_or_dict,
-                weights_name=weight_name,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                subfolder=subfolder,
-                user_agent=user_agent,
-            )
-            state_dict = torch.load(model_file, map_location="cpu")
+            if is_safetensors_available():
+                if weight_name is None:
+                    weight_name = LORA_WEIGHT_NAME_SAFE
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path_or_dict,
+                        weights_name=weight_name,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except EnvironmentError:
+                    if weight_name == LORA_WEIGHT_NAME_SAFE:
+                        weight_name = None
+            if model_file is None:
+                if weight_name is None:
+                    weight_name = LORA_WEIGHT_NAME
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = torch.load(model_file, map_location="cpu")
         else:
             state_dict = pretrained_model_name_or_path_or_dict
 
@@ -195,8 +225,9 @@ def save_attn_procs(
         self,
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
-        weights_name: str = LORA_WEIGHT_NAME,
+        weights_name: str = None,
         save_function: Callable = None,
+        safe_serialization: bool = False,
     ):
         r"""
         Save an attention processor to a directory, so that it can be re-loaded using the
@@ -219,7 +250,13 @@ def save_attn_procs(
             return
 
         if save_function is None:
-            save_function = torch.save
+            if safe_serialization:
+
+                def save_function(weights, filename):
+                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+            else:
+                save_function = torch.save
 
         os.makedirs(save_directory, exist_ok=True)
 
@@ -237,6 +274,12 @@
             if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
                 os.remove(full_filename)
 
+        if weights_name is None:
+            if safe_serialization:
+                weights_name = LORA_WEIGHT_NAME_SAFE
+            else:
+                weights_name = LORA_WEIGHT_NAME
+
         # Save the model
         save_function(state_dict, os.path.join(save_directory, weights_name))
 
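Since the commit message also mentions metadata, here is a small illustrative sketch (the file path is assumed, not from the commit) of how the {"format": "pt"} metadata written by the safe_serialization=True path above can be read back with the safetensors library.

from safetensors import safe_open

# Illustrative path to a file produced by save_attn_procs(..., safe_serialization=True).
lora_file = "sd-lora-example/pytorch_lora_weights.safetensors"

with safe_open(lora_file, framework="pt", device="cpu") as f:
    print(f.metadata())    # expected to include {"format": "pt"}
    print(list(f.keys()))  # names of the LoRA attention processor tensors
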
tests/models/test_models_unet_2d_condition.py

Lines changed: 60 additions & 0 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import gc
+import os
 import tempfile
 import unittest
 
@@ -372,6 +373,65 @@ def test_lora_save_load(self):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname)
+
+            with torch.no_grad():
+                new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+
+        # LoRA and no LoRA should NOT be the same
+        assert (sample - old_sample).abs().max() > 1e-4
+
+    def test_lora_save_load_safetensors(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = {}
+        for name in model.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = model.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = model.config.block_out_channels[block_id]
+
+            lora_attn_procs[name] = LoRACrossAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            )
+            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+
+            # add 1 to weights to mock trained weights
+            with torch.no_grad():
+                lora_attn_procs[name].to_q_lora.up.weight += 1
+                lora_attn_procs[name].to_k_lora.up.weight += 1
+                lora_attn_procs[name].to_v_lora.up.weight += 1
+                lora_attn_procs[name].to_out_lora.up.weight += 1
+
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname, safe_serialization=True)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
             torch.manual_seed(0)
             new_model = self.model_class(**init_dict)
             new_model.to(torch_device)
