@@ -494,6 +494,7 @@ class DiffusionPipeline(ConfigMixin):
494
494
_optional_components = []
495
495
_exclude_from_cpu_offload = []
496
496
_load_connected_pipes = False
497
+ _is_onnx = False
497
498
498
499
def register_modules (self , ** kwargs ):
499
500
# import it here to avoid circular import
@@ -839,6 +840,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
839
840
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
840
841
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
841
842
weights. If set to `False`, safetensors weights are not loaded.
843
+ use_onnx (`bool`, *optional*, defaults to `None`):
844
+ If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
845
+ will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
846
+ `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
847
+ with `.onnx` and `.pb`.
842
848
kwargs (remaining dictionary of keyword arguments, *optional*):
843
849
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
844
850
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -1268,6 +1274,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
1268
1274
variant (`str`, *optional*):
1269
1275
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
1270
1276
loading `from_flax`.
1277
+ use_safetensors (`bool`, *optional*, defaults to `None`):
1278
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
1279
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
1280
+ weights. If set to `False`, safetensors weights are not loaded.
1281
+ use_onnx (`bool`, *optional*, defaults to `False`):
1282
+ If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
1283
+ will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
1284
+ `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
1285
+ with `.onnx` and `.pb`.
1271
1286
1272
1287
Returns:
1273
1288
`os.PathLike`:
@@ -1293,6 +1308,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
1293
1308
custom_revision = kwargs .pop ("custom_revision" , None )
1294
1309
variant = kwargs .pop ("variant" , None )
1295
1310
use_safetensors = kwargs .pop ("use_safetensors" , None )
1311
+ use_onnx = kwargs .pop ("use_onnx" , None )
1296
1312
load_connected_pipeline = kwargs .pop ("load_connected_pipeline" , False )
1297
1313
1298
1314
if use_safetensors and not is_safetensors_available ():
@@ -1364,7 +1380,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
1364
1380
pretrained_model_name , use_auth_token , variant , revision , model_filenames
1365
1381
)
1366
1382
1367
- model_folder_names = {os .path .split (f )[0 ] for f in model_filenames }
1383
+ model_folder_names = {os .path .split (f )[0 ] for f in model_filenames if os . path . split ( f )[ 0 ] in folder_names }
1368
1384
1369
1385
# all filenames compatible with variant will be added
1370
1386
allow_patterns = list (model_filenames )
@@ -1411,6 +1427,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
1411
1427
):
1412
1428
ignore_patterns = ["*.bin" , "*.msgpack" ]
1413
1429
1430
+ use_onnx = use_onnx if use_onnx is not None else pipeline_class ._is_onnx
1431
+ if not use_onnx :
1432
+ ignore_patterns += ["*.onnx" , "*.pb" ]
1433
+
1414
1434
safetensors_variant_filenames = {f for f in variant_filenames if f .endswith (".safetensors" )}
1415
1435
safetensors_model_filenames = {f for f in model_filenames if f .endswith (".safetensors" )}
1416
1436
if (
@@ -1423,6 +1443,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
1423
1443
else :
1424
1444
ignore_patterns = ["*.safetensors" , "*.msgpack" ]
1425
1445
1446
+ use_onnx = use_onnx if use_onnx is not None else pipeline_class ._is_onnx
1447
+ if not use_onnx :
1448
+ ignore_patterns += ["*.onnx" , "*.pb" ]
1449
+
1426
1450
bin_variant_filenames = {f for f in variant_filenames if f .endswith (".bin" )}
1427
1451
bin_model_filenames = {f for f in model_filenames if f .endswith (".bin" )}
1428
1452
if len (bin_variant_filenames ) > 0 and bin_model_filenames != bin_variant_filenames :
0 commit comments