|
10 | 10 | import argparse
|
11 | 11 | from glob import glob
|
12 | 12 | import shutil
|
| 13 | +import requests |
13 | 14 |
|
14 | 15 | model_config_dicts = get_json_file(
|
15 | 16 | os.path.join(
|
|
19 | 20 | )
|
20 | 21 |
|
21 | 22 |
|
| 23 | +def get_inpaint_inputs(): |
| 24 | + os.mkdir("./test_images/inputs") |
| 25 | + img_url = ( |
| 26 | + "https://huggingface.co/datasets/diffusers/test-arrays/resolve" |
| 27 | + "/main/stable_diffusion_inpaint/input_bench_image.png" |
| 28 | + ) |
| 29 | + mask_url = ( |
| 30 | + "https://huggingface.co/datasets/diffusers/test-arrays/resolve" |
| 31 | + "/main/stable_diffusion_inpaint/input_bench_mask.png" |
| 32 | + ) |
| 33 | + img = requests.get(img_url) |
| 34 | + mask = requests.get(mask_url) |
| 35 | + open("./test_images/inputs/image.png", "wb").write(img.content) |
| 36 | + open("./test_images/inputs/mask.png", "wb").write(mask.content) |
| 37 | + |
| 38 | + |
22 | 39 | def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
23 | 40 | # Get golden values from tank
|
24 | 41 | shutil.rmtree("./test_images", ignore_errors=True)
|
25 | 42 | os.mkdir("./test_images")
|
26 | 43 | os.mkdir("./test_images/golden")
|
| 44 | + get_inpaint_inputs() |
27 | 45 | hf_model_names = model_config_dicts[0].values()
|
28 | 46 | tuned_options = ["--no-use_tuned", "--use_tuned"]
|
29 | 47 | import_options = ["--import_mlir", "--no-import_mlir"]
|
30 | 48 | prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
|
| 49 | + inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench" |
31 | 50 | if os.name == "nt":
|
32 | 51 | prompt_text = '--prompt="cyberpunk forest by Salvador Dali"'
|
| 52 | + inpaint_prompt_text = '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"' |
33 | 53 | if beta:
|
34 | 54 | extra_flags.append("--beta_models=True")
|
35 | 55 | for import_opt in import_options:
|
36 | 56 | for model_name in hf_model_names:
|
37 | 57 | for use_tune in tuned_options:
|
38 |
| - command = [ |
39 |
| - executable, # executable is the python from the venv used to run this |
40 |
| - "apps/stable_diffusion/scripts/txt2img.py", |
41 |
| - "--device=" + device, |
42 |
| - prompt_text, |
43 |
| - "--negative_prompts=" + '""', |
44 |
| - "--seed=42", |
45 |
| - import_opt, |
46 |
| - "--output_dir=" |
47 |
| - + os.path.join(os.getcwd(), "test_images", model_name), |
48 |
| - "--hf_model_id=" + model_name, |
49 |
| - use_tune, |
50 |
| - ] |
| 58 | + command = ( |
| 59 | + [ |
| 60 | + executable, # executable is the python from the venv used to run this |
| 61 | + "apps/stable_diffusion/scripts/txt2img.py", |
| 62 | + "--device=" + device, |
| 63 | + prompt_text, |
| 64 | + "--negative_prompts=" + '""', |
| 65 | + "--seed=42", |
| 66 | + import_opt, |
| 67 | + "--output_dir=" |
| 68 | + + os.path.join(os.getcwd(), "test_images", model_name), |
| 69 | + "--hf_model_id=" + model_name, |
| 70 | + use_tune, |
| 71 | + ] |
| 72 | + if "inpainting" not in model_name |
| 73 | + else [ |
| 74 | + "python", |
| 75 | + "apps/stable_diffusion/scripts/inpaint.py", |
| 76 | + "--device=" + device, |
| 77 | + inpaint_prompt_text, |
| 78 | + "--negative_prompts=" + '""', |
| 79 | + "--img_path=./test_images/inputs/image.png", |
| 80 | + "--mask_path=./test_images/inputs/mask.png", |
| 81 | + "--seed=42", |
| 82 | + "--import_mlir", |
| 83 | + "--output_dir=" |
| 84 | + + os.path.join(os.getcwd(), "test_images", model_name), |
| 85 | + "--hf_model_id=" + model_name, |
| 86 | + use_tune, |
| 87 | + ] |
| 88 | + ) |
51 | 89 | command += extra_flags
|
52 | 90 | if os.name == "nt":
|
53 | 91 | command = " ".join(command)
|
|
0 commit comments