19
19
20
20
from .models .cross_attention import LoRACrossAttnProcessor
21
21
from .models .modeling_utils import _get_model_file
22
- from .utils import DIFFUSERS_CACHE , HF_HUB_OFFLINE , logging
22
+ from .utils import DIFFUSERS_CACHE , HF_HUB_OFFLINE , is_safetensors_available , logging
23
+
24
+
25
+ if is_safetensors_available ():
26
+ import safetensors
23
27
24
28
25
29
logger = logging .get_logger (__name__ )
26
30
27
31
28
32
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
33
+ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
29
34
30
35
31
36
class AttnProcsLayers (torch .nn .Module ):
@@ -136,28 +141,53 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
136
141
use_auth_token = kwargs .pop ("use_auth_token" , None )
137
142
revision = kwargs .pop ("revision" , None )
138
143
subfolder = kwargs .pop ("subfolder" , None )
139
- weight_name = kwargs .pop ("weight_name" , LORA_WEIGHT_NAME )
144
+ weight_name = kwargs .pop ("weight_name" , None )
140
145
141
146
user_agent = {
142
147
"file_type" : "attn_procs_weights" ,
143
148
"framework" : "pytorch" ,
144
149
}
145
150
151
+ model_file = None
146
152
if not isinstance (pretrained_model_name_or_path_or_dict , dict ):
147
- model_file = _get_model_file (
148
- pretrained_model_name_or_path_or_dict ,
149
- weights_name = weight_name ,
150
- cache_dir = cache_dir ,
151
- force_download = force_download ,
152
- resume_download = resume_download ,
153
- proxies = proxies ,
154
- local_files_only = local_files_only ,
155
- use_auth_token = use_auth_token ,
156
- revision = revision ,
157
- subfolder = subfolder ,
158
- user_agent = user_agent ,
159
- )
160
- state_dict = torch .load (model_file , map_location = "cpu" )
153
+ if is_safetensors_available ():
154
+ if weight_name is None :
155
+ weight_name = LORA_WEIGHT_NAME_SAFE
156
+ try :
157
+ model_file = _get_model_file (
158
+ pretrained_model_name_or_path_or_dict ,
159
+ weights_name = weight_name ,
160
+ cache_dir = cache_dir ,
161
+ force_download = force_download ,
162
+ resume_download = resume_download ,
163
+ proxies = proxies ,
164
+ local_files_only = local_files_only ,
165
+ use_auth_token = use_auth_token ,
166
+ revision = revision ,
167
+ subfolder = subfolder ,
168
+ user_agent = user_agent ,
169
+ )
170
+ state_dict = safetensors .torch .load_file (model_file , device = "cpu" )
171
+ except EnvironmentError :
172
+ if weight_name == LORA_WEIGHT_NAME_SAFE :
173
+ weight_name = None
174
+ if model_file is None :
175
+ if weight_name is None :
176
+ weight_name = LORA_WEIGHT_NAME
177
+ model_file = _get_model_file (
178
+ pretrained_model_name_or_path_or_dict ,
179
+ weights_name = weight_name ,
180
+ cache_dir = cache_dir ,
181
+ force_download = force_download ,
182
+ resume_download = resume_download ,
183
+ proxies = proxies ,
184
+ local_files_only = local_files_only ,
185
+ use_auth_token = use_auth_token ,
186
+ revision = revision ,
187
+ subfolder = subfolder ,
188
+ user_agent = user_agent ,
189
+ )
190
+ state_dict = torch .load (model_file , map_location = "cpu" )
161
191
else :
162
192
state_dict = pretrained_model_name_or_path_or_dict
163
193
@@ -195,8 +225,9 @@ def save_attn_procs(
195
225
self ,
196
226
save_directory : Union [str , os .PathLike ],
197
227
is_main_process : bool = True ,
198
- weights_name : str = LORA_WEIGHT_NAME ,
228
+ weights_name : str = None ,
199
229
save_function : Callable = None ,
230
+ safe_serialization : bool = False ,
200
231
):
201
232
r"""
202
233
Save an attention processor to a directory, so that it can be re-loaded using the
@@ -219,7 +250,13 @@ def save_attn_procs(
219
250
return
220
251
221
252
if save_function is None :
222
- save_function = torch .save
253
+ if safe_serialization :
254
+
255
+ def save_function (weights , filename ):
256
+ return safetensors .torch .save_file (weights , filename , metadata = {"format" : "pt" })
257
+
258
+ else :
259
+ save_function = torch .save
223
260
224
261
os .makedirs (save_directory , exist_ok = True )
225
262
@@ -237,6 +274,12 @@ def save_attn_procs(
237
274
if filename .startswith (weights_no_suffix ) and os .path .isfile (full_filename ) and is_main_process :
238
275
os .remove (full_filename )
239
276
277
+ if weights_name is None :
278
+ if safe_serialization :
279
+ weights_name = LORA_WEIGHT_NAME_SAFE
280
+ else :
281
+ weights_name = LORA_WEIGHT_NAME
282
+
240
283
# Save the model
241
284
save_function (state_dict , os .path .join (save_directory , weights_name ))
242
285
0 commit comments