@@ -445,6 +445,43 @@ def test_lora_save_load_safetensors(self):
         # LoRA and no LoRA should NOT be the same
         assert (sample - old_sample).abs().max() > 1e-4

+    def test_lora_save_load_safetensors_load_torch(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        lora_attn_procs = {}
+        for name in model.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
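+            # the LoRA layer width must match the channel width of the UNet block that owns this attention processor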
+            if name.startswith("mid_block"):
+                hidden_size = model.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = model.config.block_out_channels[block_id]
+
+            lora_attn_procs[name] = LoRACrossAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            )
+            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+
+        model.set_attn_processor(lora_attn_procs)
+        # Saved in torch (.bin) format; should reload correctly when the file name is passed explicitly
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
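+            # loading the torch-format weights back by explicit file name should succeed without raising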
+            new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")
+
     def test_lora_on_off(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()