
Commit 2ea1da8

Fix regression introduced in #2448 (#2551)
* Fix regression introduced in #2448
* Style.
1 parent fa6d52d commit 2ea1da8

File tree

2 files changed: +38 −1 lines changed

  src/diffusers/loaders.py
  tests/models/test_models_unet_2d_condition.py


src/diffusers/loaders.py

Lines changed: 1 addition & 1 deletion

@@ -150,7 +150,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
 
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            if is_safetensors_available():
+            if (is_safetensors_available() and weight_name is None) or weight_name.endswith(".safetensors"):
                 if weight_name is None:
                     weight_name = LORA_WEIGHT_NAME_SAFE
                 try:
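
For context on the one-line change above: the old condition took the safetensors branch whenever safetensors was installed, so an explicit torch filename such as weight_name="pytorch_lora_weights.bin" was still routed through the safetensors loader. The patched condition takes that branch only when no weight_name is given or the given name ends in ".safetensors". Below is a minimal standalone sketch of that selection rule; it is illustrative, not the library's actual code. The two filename constants mirror diffusers' loaders.py, and an explicit "weight_name is not None" guard is added here only to keep the demo self-contained.

    # Sketch of the weight-file selection encoded by the patched condition.
    # The default filenames mirror diffusers' loaders.py; the function itself
    # is illustrative, not the library's implementation.

    LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
    LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"


    def pick_lora_weight_file(weight_name, safetensors_available):
        # Take the safetensors path only when no name was given (and safetensors
        # is installed) or the caller explicitly asked for a .safetensors file.
        use_safetensors = (safetensors_available and weight_name is None) or (
            weight_name is not None and weight_name.endswith(".safetensors")
        )
        if use_safetensors:
            return weight_name or LORA_WEIGHT_NAME_SAFE, "safetensors"
        return weight_name or LORA_WEIGHT_NAME, "torch"


    # The case the regression broke: an explicit .bin file must use the torch
    # loader even when safetensors is installed.
    assert pick_lora_weight_file("pytorch_lora_weights.bin", True) == ("pytorch_lora_weights.bin", "torch")
    assert pick_lora_weight_file(None, True) == ("pytorch_lora_weights.safetensors", "safetensors")
    assert pick_lora_weight_file(None, False) == ("pytorch_lora_weights.bin", "torch")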

tests/models/test_models_unet_2d_condition.py

Lines changed: 37 additions & 0 deletions

@@ -445,6 +445,43 @@ def test_lora_save_load_safetensors(self):
         # LoRA and no LoRA should NOT be the same
         assert (sample - old_sample).abs().max() > 1e-4
 
+    def test_lora_save_load_safetensors_load_torch(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        lora_attn_procs = {}
+        for name in model.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = model.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = model.config.block_out_channels[block_id]
+
+            lora_attn_procs[name] = LoRACrossAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            )
+            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+
+        model.set_attn_processor(lora_attn_procs)
+        # Saving as torch, properly reloads with directly filename
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")
+
     def test_lora_on_off(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
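
The new test mirrors the user-facing scenario that regressed: LoRA attention processors saved as a torch .bin file and reloaded by passing the filename explicitly. A short usage sketch of that scenario on a full pipeline follows, assuming a diffusers release from around this commit; the model id and checkpoint directory are placeholders, not part of the commit.

    import torch
    from diffusers import StableDiffusionPipeline

    # Placeholder model id; any pipeline whose UNet exposes load_attn_procs works.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    # Before this fix, passing an explicit .bin weight_name while safetensors was
    # installed could still send loading through the safetensors branch; with the
    # fix, the torch loader handles it.
    pipe.unet.load_attn_procs(
        "path/to/lora_checkpoint",  # placeholder directory containing the .bin file
        weight_name="pytorch_lora_weights.bin",
    )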
