Skip to content

Commit 856dad5

Browse files
is_safetensors_compatible refactor (#2499)
* is_safetensors_compatible refactor * files list comma
1 parent a75ac3f commit 856dad5

File tree

2 files changed

+176
-14
lines changed

2 files changed

+176
-14
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -129,21 +129,49 @@ class AudioPipelineOutput(BaseOutput):
129129

130130

131131
def is_safetensors_compatible(filenames, variant=None) -> bool:
132-
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
133-
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
134-
135-
for pt_filename in pt_filenames:
136-
_variant = f".{variant}" if (variant is not None and variant in pt_filename) else ""
137-
prefix, raw = os.path.split(pt_filename)
138-
if raw == f"pytorch_model{_variant}.bin":
139-
# transformers specific
140-
sf_filename = os.path.join(prefix, f"model{_variant}.safetensors")
132+
"""
133+
Checking for safetensors compatibility:
134+
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
135+
files to know which safetensors files are needed.
136+
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
137+
138+
Converting default pytorch serialized filenames to safetensors serialized filenames:
139+
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
140+
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
141+
extension is replaced with ".safetensors"
142+
"""
143+
pt_filenames = []
144+
145+
sf_filenames = set()
146+
147+
for filename in filenames:
148+
_, extension = os.path.splitext(filename)
149+
150+
if extension == ".bin":
151+
pt_filenames.append(filename)
152+
elif extension == ".safetensors":
153+
sf_filenames.add(filename)
154+
155+
for filename in pt_filenames:
156+
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam'
157+
path, filename = os.path.split(filename)
158+
filename, extension = os.path.splitext(filename)
159+
160+
if filename == "pytorch_model":
161+
filename = "model"
162+
elif filename == f"pytorch_model.{variant}":
163+
filename = f"model.{variant}"
141164
else:
142-
sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
143-
if is_safetensors_compatible and sf_filename not in filenames:
144-
logger.warning(f"{sf_filename} not found")
145-
is_safetensors_compatible = False
146-
return is_safetensors_compatible
165+
filename = filename
166+
167+
expected_sf_filename = os.path.join(path, filename)
168+
expected_sf_filename = f"{expected_sf_filename}.safetensors"
169+
170+
if expected_sf_filename not in sf_filenames:
171+
logger.warning(f"{expected_sf_filename} not found")
172+
return False
173+
174+
return True
147175

148176

149177
def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import unittest
2+
3+
from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
4+
5+
6+
class IsSafetensorsCompatibleTests(unittest.TestCase):
7+
def test_all_is_compatible(self):
8+
filenames = [
9+
"safety_checker/pytorch_model.bin",
10+
"safety_checker/model.safetensors",
11+
"vae/diffusion_pytorch_model.bin",
12+
"vae/diffusion_pytorch_model.safetensors",
13+
"text_encoder/pytorch_model.bin",
14+
"text_encoder/model.safetensors",
15+
"unet/diffusion_pytorch_model.bin",
16+
"unet/diffusion_pytorch_model.safetensors",
17+
]
18+
self.assertTrue(is_safetensors_compatible(filenames))
19+
20+
def test_diffusers_model_is_compatible(self):
21+
filenames = [
22+
"unet/diffusion_pytorch_model.bin",
23+
"unet/diffusion_pytorch_model.safetensors",
24+
]
25+
self.assertTrue(is_safetensors_compatible(filenames))
26+
27+
def test_diffusers_model_is_not_compatible(self):
28+
filenames = [
29+
"safety_checker/pytorch_model.bin",
30+
"safety_checker/model.safetensors",
31+
"vae/diffusion_pytorch_model.bin",
32+
"vae/diffusion_pytorch_model.safetensors",
33+
"text_encoder/pytorch_model.bin",
34+
"text_encoder/model.safetensors",
35+
"unet/diffusion_pytorch_model.bin",
36+
# Removed: 'unet/diffusion_pytorch_model.safetensors',
37+
]
38+
self.assertFalse(is_safetensors_compatible(filenames))
39+
40+
def test_transformer_model_is_compatible(self):
41+
filenames = [
42+
"text_encoder/pytorch_model.bin",
43+
"text_encoder/model.safetensors",
44+
]
45+
self.assertTrue(is_safetensors_compatible(filenames))
46+
47+
def test_transformer_model_is_not_compatible(self):
48+
filenames = [
49+
"safety_checker/pytorch_model.bin",
50+
"safety_checker/model.safetensors",
51+
"vae/diffusion_pytorch_model.bin",
52+
"vae/diffusion_pytorch_model.safetensors",
53+
"text_encoder/pytorch_model.bin",
54+
# Removed: 'text_encoder/model.safetensors',
55+
"unet/diffusion_pytorch_model.bin",
56+
"unet/diffusion_pytorch_model.safetensors",
57+
]
58+
self.assertFalse(is_safetensors_compatible(filenames))
59+
60+
def test_all_is_compatible_variant(self):
61+
filenames = [
62+
"safety_checker/pytorch_model.fp16.bin",
63+
"safety_checker/model.fp16.safetensors",
64+
"vae/diffusion_pytorch_model.fp16.bin",
65+
"vae/diffusion_pytorch_model.fp16.safetensors",
66+
"text_encoder/pytorch_model.fp16.bin",
67+
"text_encoder/model.fp16.safetensors",
68+
"unet/diffusion_pytorch_model.fp16.bin",
69+
"unet/diffusion_pytorch_model.fp16.safetensors",
70+
]
71+
variant = "fp16"
72+
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
73+
74+
def test_diffusers_model_is_compatible_variant(self):
75+
filenames = [
76+
"unet/diffusion_pytorch_model.fp16.bin",
77+
"unet/diffusion_pytorch_model.fp16.safetensors",
78+
]
79+
variant = "fp16"
80+
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
81+
82+
def test_diffusers_model_is_compatible_variant_partial(self):
83+
# pass variant but use the non-variant filenames
84+
filenames = [
85+
"unet/diffusion_pytorch_model.bin",
86+
"unet/diffusion_pytorch_model.safetensors",
87+
]
88+
variant = "fp16"
89+
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
90+
91+
def test_diffusers_model_is_not_compatible_variant(self):
92+
filenames = [
93+
"safety_checker/pytorch_model.fp16.bin",
94+
"safety_checker/model.fp16.safetensors",
95+
"vae/diffusion_pytorch_model.fp16.bin",
96+
"vae/diffusion_pytorch_model.fp16.safetensors",
97+
"text_encoder/pytorch_model.fp16.bin",
98+
"text_encoder/model.fp16.safetensors",
99+
"unet/diffusion_pytorch_model.fp16.bin",
100+
# Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
101+
]
102+
variant = "fp16"
103+
self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
104+
105+
def test_transformer_model_is_compatible_variant(self):
106+
filenames = [
107+
"text_encoder/pytorch_model.fp16.bin",
108+
"text_encoder/model.fp16.safetensors",
109+
]
110+
variant = "fp16"
111+
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
112+
113+
def test_transformer_model_is_compatible_variant_partial(self):
114+
# pass variant but use the non-variant filenames
115+
filenames = [
116+
"text_encoder/pytorch_model.bin",
117+
"text_encoder/model.safetensors",
118+
]
119+
variant = "fp16"
120+
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
121+
122+
def test_transformer_model_is_not_compatible_variant(self):
123+
filenames = [
124+
"safety_checker/pytorch_model.fp16.bin",
125+
"safety_checker/model.fp16.safetensors",
126+
"vae/diffusion_pytorch_model.fp16.bin",
127+
"vae/diffusion_pytorch_model.fp16.safetensors",
128+
"text_encoder/pytorch_model.fp16.bin",
129+
# 'text_encoder/model.fp16.safetensors',
130+
"unet/diffusion_pytorch_model.fp16.bin",
131+
"unet/diffusion_pytorch_model.fp16.safetensors",
132+
]
133+
variant = "fp16"
134+
self.assertFalse(is_safetensors_compatible(filenames, variant=variant))

0 commit comments

Comments
 (0)