Skip to content

Commit ad7330e

Browse files
authored
Add inpainting test (huggingface#1011)
1 parent cf126e4 commit ad7330e

File tree

1 file changed

+51
-13
lines changed

1 file changed

+51
-13
lines changed

build_tools/stable_diffusion_testing.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import argparse
1111
from glob import glob
1212
import shutil
13+
import requests
1314

1415
model_config_dicts = get_json_file(
1516
os.path.join(
@@ -19,35 +20,72 @@
1920
)
2021

2122

23+
def get_inpaint_inputs():
24+
os.mkdir("./test_images/inputs")
25+
img_url = (
26+
"https://huggingface.co/datasets/diffusers/test-arrays/resolve"
27+
"/main/stable_diffusion_inpaint/input_bench_image.png"
28+
)
29+
mask_url = (
30+
"https://huggingface.co/datasets/diffusers/test-arrays/resolve"
31+
"/main/stable_diffusion_inpaint/input_bench_mask.png"
32+
)
33+
img = requests.get(img_url)
34+
mask = requests.get(mask_url)
35+
open("./test_images/inputs/image.png", "wb").write(img.content)
36+
open("./test_images/inputs/mask.png", "wb").write(mask.content)
37+
38+
2239
def test_loop(device="vulkan", beta=False, extra_flags=[]):
2340
# Get golden values from tank
2441
shutil.rmtree("./test_images", ignore_errors=True)
2542
os.mkdir("./test_images")
2643
os.mkdir("./test_images/golden")
44+
get_inpaint_inputs()
2745
hf_model_names = model_config_dicts[0].values()
2846
tuned_options = ["--no-use_tuned", "--use_tuned"]
2947
import_options = ["--import_mlir", "--no-import_mlir"]
3048
prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
49+
inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
3150
if os.name == "nt":
3251
prompt_text = '--prompt="cyberpunk forest by Salvador Dali"'
52+
inpaint_prompt_text = '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"'
3353
if beta:
3454
extra_flags.append("--beta_models=True")
3555
for import_opt in import_options:
3656
for model_name in hf_model_names:
3757
for use_tune in tuned_options:
38-
command = [
39-
executable, # executable is the python from the venv used to run this
40-
"apps/stable_diffusion/scripts/txt2img.py",
41-
"--device=" + device,
42-
prompt_text,
43-
"--negative_prompts=" + '""',
44-
"--seed=42",
45-
import_opt,
46-
"--output_dir="
47-
+ os.path.join(os.getcwd(), "test_images", model_name),
48-
"--hf_model_id=" + model_name,
49-
use_tune,
50-
]
58+
command = (
59+
[
60+
executable, # executable is the python from the venv used to run this
61+
"apps/stable_diffusion/scripts/txt2img.py",
62+
"--device=" + device,
63+
prompt_text,
64+
"--negative_prompts=" + '""',
65+
"--seed=42",
66+
import_opt,
67+
"--output_dir="
68+
+ os.path.join(os.getcwd(), "test_images", model_name),
69+
"--hf_model_id=" + model_name,
70+
use_tune,
71+
]
72+
if "inpainting" not in model_name
73+
else [
74+
"python",
75+
"apps/stable_diffusion/scripts/inpaint.py",
76+
"--device=" + device,
77+
inpaint_prompt_text,
78+
"--negative_prompts=" + '""',
79+
"--img_path=./test_images/inputs/image.png",
80+
"--mask_path=./test_images/inputs/mask.png",
81+
"--seed=42",
82+
"--import_mlir",
83+
"--output_dir="
84+
+ os.path.join(os.getcwd(), "test_images", model_name),
85+
"--hf_model_id=" + model_name,
86+
use_tune,
87+
]
88+
)
5189
command += extra_flags
5290
if os.name == "nt":
5391
command = " ".join(command)

0 commit comments

Comments
 (0)