
Commit a72a057

move test num_images_per_prompt to pipeline mixin (#2488)
* attend and excite batch test causing timeouts
* move test num_images_per_prompt to pipeline mixin
* style
* prompt_key -> self.batch_params
1 parent 2f48957 commit a72a057

13 files changed: +55 -329 lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

Lines changed: 27 additions & 3 deletions
@@ -517,8 +517,30 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )

-        if (indices is None) or (indices is not None and not isinstance(indices, List)):
-            raise ValueError(f"`indices` has to be a list but is {type(indices)}")
+        indices_is_list_ints = isinstance(indices, list) and isinstance(indices[0], int)
+        indices_is_list_list_ints = (
+            isinstance(indices, list) and isinstance(indices[0], list) and isinstance(indices[0][0], int)
+        )
+
+        if not indices_is_list_ints and not indices_is_list_list_ints:
+            raise TypeError("`indices` must be a list of ints or a list of a list of ints")
+
+        if indices_is_list_ints:
+            indices_batch_size = 1
+        elif indices_is_list_list_ints:
+            indices_batch_size = len(indices)
+
+        if prompt is not None and isinstance(prompt, str):
+            prompt_batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            prompt_batch_size = len(prompt)
+        elif prompt_embeds is not None:
+            prompt_batch_size = prompt_embeds.shape[0]
+
+        if indices_batch_size != prompt_batch_size:
+            raise ValueError(
+                f"indices batch size must be same as prompt batch size. indices batch size: {indices_batch_size}, prompt batch size: {prompt_batch_size}"
+            )

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):

@@ -675,7 +697,7 @@ def get_indices(self, prompt: str) -> Dict[str, int]:
     def __call__(
         self,
         prompt: Union[str, List[str]],
-        token_indices: List[int],
+        token_indices: Union[List[int], List[List[int]]],
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,

@@ -851,7 +873,9 @@ def __call__(

         if isinstance(token_indices[0], int):
             token_indices = [token_indices]
+
         indices = []
+
         for ind in token_indices:
             indices = indices + [ind] * num_images_per_prompt
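For context, a minimal usage sketch of the updated signature (not part of the diff; the model id, prompts, and token indices below are illustrative placeholders): `token_indices` now accepts either a flat list of ints for a single prompt or one sub-list per prompt for batched prompts, and `check_inputs` raises when the number of index lists does not match the prompt batch size.

```python
# Illustrative sketch only; model id, prompts, and token indices are placeholders.
import torch
from diffusers import StableDiffusionAttendAndExcitePipeline

pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Single prompt: a flat list of ints is still accepted (indices_batch_size == 1).
image = pipe("a cat and a frog", token_indices=[2, 5]).images[0]

# Batched prompts: one sub-list of indices per prompt. A mismatch between the
# number of sub-lists and the number of prompts now raises a ValueError.
images = pipe(
    ["a cat and a frog", "a dog and a bird"],
    token_indices=[[2, 5], [2, 5]],
).images
```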

tests/pipelines/paint_by_example/test_paint_by_example.py

Lines changed: 0 additions & 13 deletions
@@ -160,19 +160,6 @@ def test_paint_by_example_image_tensor(self):
         assert out_1.shape == (1, 64, 64, 3)
         assert np.abs(out_1.flatten() - out_2.flatten()).max() < 5e-2

-    def test_paint_by_example_inpaint_with_num_images_per_prompt(self):
-        device = "cpu"
-        pipe = PaintByExamplePipeline(**self.get_dummy_components())
-        pipe = pipe.to(device)
-        pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_dummy_inputs()
-
-        images = pipe(**inputs, num_images_per_prompt=2).images
-
-        # check if the output is a list of 2 images
-        assert len(images) == 2
-

 @slow
 @require_torch_gpu

tests/pipelines/stable_diffusion/test_cycle_diffusion.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         "negative_prompt_embeds",
     }
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
-    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"})

     def get_dummy_components(self):
         torch.manual_seed(0)
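The `source_prompt` addition matters because the shared mixin tests batch every call argument named in `batch_params`, and CycleDiffusion takes two text inputs. A rough sketch of the idea (the helper name and exact logic are assumptions, not the actual mixin code):

```python
# Hypothetical helper showing how `batch_params` might drive input batching in
# the shared pipeline tests; the real mixin implementation may differ.
def batchify_inputs(inputs: dict, batch_params: set, batch_size: int) -> dict:
    batched = dict(inputs)
    for key, value in inputs.items():
        # Replicate prompt-like inputs (e.g. "prompt", "source_prompt", "image").
        if key in batch_params:
            batched[key] = batch_size * [value]
    return batched
```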

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 0 additions & 37 deletions
@@ -477,43 +477,6 @@ def test_stable_diffusion_negative_prompt(self):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

-    def test_stable_diffusion_num_images_per_prompt(self):
-        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
-        components = self.get_dummy_components()
-        components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
-        sd_pipe = StableDiffusionPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-
-        # test num_images_per_prompt=1 (default)
-        images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images
-
-        assert images.shape == (1, 64, 64, 3)
-
-        # test num_images_per_prompt=1 (default) for batch of prompts
-        batch_size = 2
-        images = sd_pipe([prompt] * batch_size, num_inference_steps=2, output_type="np").images
-
-        assert images.shape == (batch_size, 64, 64, 3)
-
-        # test num_images_per_prompt for single prompt
-        num_images_per_prompt = 2
-        images = sd_pipe(
-            prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
-        ).images
-
-        assert images.shape == (num_images_per_prompt, 64, 64, 3)
-
-        # test num_images_per_prompt for batch of prompts
-        batch_size = 2
-        images = sd_pipe(
-            [prompt] * batch_size, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
-        ).images
-
-        assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
-
     def test_stable_diffusion_long_prompt(self):
         components = self.get_dummy_components()
         components["scheduler"] = LMSDiscreteScheduler.from_config(components["scheduler"].config)
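The per-pipeline copies of this test, here and in the files below, are replaced by a single shared test in `PipelineTesterMixin`. The mixin file itself is not part of this excerpt, so the following is only a rough sketch of what the consolidated test might look like, with `self.pipeline_class`, `self.batch_params`, and the shape handling assumed rather than taken from the diff:

```python
# Hypothetical sketch of the consolidated mixin test; the actual implementation
# lives in the shared pipeline test utilities and may differ in detail.
def test_num_images_per_prompt(self):
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components).to("cpu")
    pipe.set_progress_bar_config(disable=None)

    for batch_size in [1, 2]:
        for num_images_per_prompt in [1, 2]:
            inputs = self.get_dummy_inputs("cpu")

            # Batch every prompt-like input declared in `self.batch_params`.
            for key in list(inputs.keys()):
                if key in self.batch_params:
                    inputs[key] = batch_size * [inputs[key]]

            images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images

            assert images.shape[0] == batch_size * num_images_per_prompt
```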

tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py

Lines changed: 0 additions & 36 deletions
@@ -143,42 +143,6 @@ def test_stable_diffusion_img_variation_multiple_images(self):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

-    def test_stable_diffusion_img_variation_num_images_per_prompt(self):
-        device = "cpu"
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionImageVariationPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        # test num_images_per_prompt=1 (default)
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (1, 64, 64, 3)
-
-        # test num_images_per_prompt=1 (default) for batch of images
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["image"] = batch_size * [inputs["image"]]
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (batch_size, 64, 64, 3)
-
-        # test num_images_per_prompt for single prompt
-        num_images_per_prompt = 2
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (num_images_per_prompt, 64, 64, 3)
-
-        # test num_images_per_prompt for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["image"] = batch_size * [inputs["image"]]
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
-

 @slow
 @require_torch_gpu

tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

Lines changed: 0 additions & 36 deletions
@@ -181,42 +181,6 @@ def test_stable_diffusion_img2img_k_lms(self):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

-    def test_stable_diffusion_img2img_num_images_per_prompt(self):
-        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        # test num_images_per_prompt=1 (default)
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (1, 32, 32, 3)
-
-        # test num_images_per_prompt=1 (default) for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (batch_size, 32, 32, 3)
-
-        # test num_images_per_prompt for single prompt
-        num_images_per_prompt = 2
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (num_images_per_prompt, 32, 32, 3)
-
-        # test num_images_per_prompt for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
-
     @skip_mps
     def test_save_load_local(self):
         return super().test_save_load_local()

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 0 additions & 13 deletions
@@ -151,19 +151,6 @@ def test_stable_diffusion_inpaint_image_tensor(self):
         assert out_pil.shape == (1, 64, 64, 3)
         assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2

-    def test_stable_diffusion_inpaint_with_num_images_per_prompt(self):
-        device = "cpu"
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionInpaintPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs, num_images_per_prompt=2).images
-
-        # check if the output is a list of 2 images
-        assert len(images) == 2
-

 @slow
 @require_torch_gpu

tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py

Lines changed: 0 additions & 36 deletions
@@ -191,42 +191,6 @@ def test_stable_diffusion_pix2pix_euler(self):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

-    def test_stable_diffusion_pix2pix_num_images_per_prompt(self):
-        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionInstructPix2PixPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        # test num_images_per_prompt=1 (default)
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (1, 32, 32, 3)
-
-        # test num_images_per_prompt=1 (default) for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (batch_size, 32, 32, 3)
-
-        # test num_images_per_prompt for single prompt
-        num_images_per_prompt = 2
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (num_images_per_prompt, 32, 32, 3)
-
-        # test num_images_per_prompt for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
-

 @slow
 @require_torch_gpu

tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py

Lines changed: 0 additions & 36 deletions
@@ -177,42 +177,6 @@ def test_stable_diffusion_panorama_pndm(self):
         with self.assertRaises(ValueError):
             _ = sd_pipe(**inputs).images

-    def test_stable_diffusion_panorama_num_images_per_prompt(self):
-        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionPanoramaPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        # test num_images_per_prompt=1 (default)
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (1, 64, 64, 3)
-
-        # test num_images_per_prompt=1 (default) for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (batch_size, 64, 64, 3)
-
-        # test num_images_per_prompt for single prompt
-        num_images_per_prompt = 2
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (num_images_per_prompt, 64, 64, 3)
-
-        # test num_images_per_prompt for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
-

 @slow
 @require_torch_gpu

tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py

Lines changed: 0 additions & 28 deletions
@@ -191,34 +191,6 @@ def test_stable_diffusion_pix2pix_zero_ddpm(self):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

-    def test_stable_diffusion_pix2pix_zero_num_images_per_prompt(self):
-        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        # test num_images_per_prompt=1 (default)
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (1, 64, 64, 3)
-
-        # test num_images_per_prompt=2 for a single prompt
-        num_images_per_prompt = 2
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (num_images_per_prompt, 64, 64, 3)
-
-        # test num_images_per_prompt for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
-
     # Non-determinism caused by the scheduler optimizing the latent inputs during inference
     @unittest.skip("non-deterministic pipeline")
     def test_inference_batch_single_identical(self):
