Skip to content

Commit 9cde56a

Browse files
patrickvonplatensayakpaul
authored andcommitted
[ONNX] Don't download ONNX model by default (#4338)
* [Download] Don't download ONNX weights by default * [Download] Don't download ONNX weights by default * [Download] Don't download ONNX weights by default * fix more * finish * finish * finish
1 parent c63d7cd commit 9cde56a

7 files changed

+74
-1
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@ class DiffusionPipeline(ConfigMixin):
494494
_optional_components = []
495495
_exclude_from_cpu_offload = []
496496
_load_connected_pipes = False
497+
_is_onnx = False
497498

498499
def register_modules(self, **kwargs):
499500
# import it here to avoid circular import
@@ -839,6 +840,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
839840
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
840841
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
841842
weights. If set to `False`, safetensors weights are not loaded.
843+
use_onnx (`bool`, *optional*, defaults to `None`):
844+
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
845+
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
846+
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
847+
with `.onnx` and `.pb`.
842848
kwargs (remaining dictionary of keyword arguments, *optional*):
843849
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
844850
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -1268,6 +1274,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12681274
variant (`str`, *optional*):
12691275
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
12701276
loading `from_flax`.
1277+
use_safetensors (`bool`, *optional*, defaults to `None`):
1278+
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
1279+
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
1280+
weights. If set to `False`, safetensors weights are not loaded.
1281+
use_onnx (`bool`, *optional*, defaults to `False`):
1282+
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
1283+
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
1284+
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
1285+
with `.onnx` and `.pb`.
12711286
12721287
Returns:
12731288
`os.PathLike`:
@@ -1293,6 +1308,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12931308
custom_revision = kwargs.pop("custom_revision", None)
12941309
variant = kwargs.pop("variant", None)
12951310
use_safetensors = kwargs.pop("use_safetensors", None)
1311+
use_onnx = kwargs.pop("use_onnx", None)
12961312
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
12971313

12981314
if use_safetensors and not is_safetensors_available():
@@ -1364,7 +1380,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13641380
pretrained_model_name, use_auth_token, variant, revision, model_filenames
13651381
)
13661382

1367-
model_folder_names = {os.path.split(f)[0] for f in model_filenames}
1383+
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
13681384

13691385
# all filenames compatible with variant will be added
13701386
allow_patterns = list(model_filenames)
@@ -1411,6 +1427,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14111427
):
14121428
ignore_patterns = ["*.bin", "*.msgpack"]
14131429

1430+
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
1431+
if not use_onnx:
1432+
ignore_patterns += ["*.onnx", "*.pb"]
1433+
14141434
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
14151435
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
14161436
if (
@@ -1423,6 +1443,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14231443
else:
14241444
ignore_patterns = ["*.safetensors", "*.msgpack"]
14251445

1446+
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
1447+
if not use_onnx:
1448+
ignore_patterns += ["*.onnx", "*.pb"]
1449+
14261450
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
14271451
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
14281452
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
4141
feature_extractor: CLIPImageProcessor
4242

4343
_optional_components = ["safety_checker", "feature_extractor"]
44+
_is_onnx = True
4445

4546
def __init__(
4647
self,

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
9898
feature_extractor: CLIPImageProcessor
9999

100100
_optional_components = ["safety_checker", "feature_extractor"]
101+
_is_onnx = True
101102

102103
def __init__(
103104
self,

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
9090
feature_extractor: CLIPImageProcessor
9191

9292
_optional_components = ["safety_checker", "feature_extractor"]
93+
_is_onnx = True
9394

9495
def __init__(
9596
self,

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
6767
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
6868
"""
6969
_optional_components = ["safety_checker", "feature_extractor"]
70+
_is_onnx = True
7071

7172
vae_encoder: OnnxRuntimeModel
7273
vae_decoder: OnnxRuntimeModel

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def preprocess(image):
4646

4747

4848
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
49+
_is_onnx = True
50+
4951
def __init__(
5052
self,
5153
vae: OnnxRuntimeModel,

tests/pipelines/test_pipelines.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,49 @@ def test_download_bin_index(self):
310310
assert len([f for f in files if ".bin" in f]) == 8
311311
assert not any(".safetensors" in f for f in files)
312312

313+
def test_download_no_openvino_by_default(self):
314+
with tempfile.TemporaryDirectory() as tmpdirname:
315+
tmpdirname = DiffusionPipeline.download(
316+
"hf-internal-testing/tiny-stable-diffusion-open-vino",
317+
cache_dir=tmpdirname,
318+
)
319+
320+
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
321+
files = [item for sublist in all_root_files for item in sublist]
322+
323+
# make sure that by default no openvino weights are downloaded
324+
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
325+
assert not any("openvino_" in f for f in files)
326+
327+
def test_download_no_onnx_by_default(self):
328+
with tempfile.TemporaryDirectory() as tmpdirname:
329+
tmpdirname = DiffusionPipeline.download(
330+
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
331+
cache_dir=tmpdirname,
332+
)
333+
334+
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
335+
files = [item for sublist in all_root_files for item in sublist]
336+
337+
# make sure that by default no onnx weights are downloaded
338+
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
339+
assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files)
340+
341+
with tempfile.TemporaryDirectory() as tmpdirname:
342+
tmpdirname = DiffusionPipeline.download(
343+
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
344+
cache_dir=tmpdirname,
345+
use_onnx=True,
346+
)
347+
348+
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
349+
files = [item for sublist in all_root_files for item in sublist]
350+
351+
# if `use_onnx` is specified make sure weights are downloaded
352+
assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
353+
assert any((f.endswith(".onnx")) for f in files)
354+
assert any((f.endswith(".pb")) for f in files)
355+
313356
def test_download_no_safety_checker(self):
314357
prompt = "hello"
315358
pipe = StableDiffusionPipeline.from_pretrained(

0 commit comments

Comments
 (0)