Commit a8d0977

[StableDiffusionInpaintPipeline] fix batch_size for mask and masked latents (#1279)
Fix batch_size for mask and masked latents
1 parent c9b3463 commit a8d0977

File tree

2 files changed: 42 additions, 1 deletion


src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 1 addition & 1 deletion
@@ -536,7 +536,7 @@ def __call__(
         mask, masked_image_latents = self.prepare_mask_latents(
             mask,
             masked_image,
-            batch_size,
+            batch_size * num_images_per_prompt,
             height,
             width,
             text_embeddings.dtype,
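
The change matters because the noise latents that drive denoising are allocated for batch_size * num_images_per_prompt samples, and the mask and masked-image latents are concatenated with them at every step, so all three tensors must share that effective batch size. A minimal sketch of the expansion under those assumptions (the helper name expand_mask_latents is illustrative; the pipeline's own prepare_mask_latents additionally resizes the mask and encodes the masked image through the VAE):

import torch

def expand_mask_latents(mask, masked_image_latents, batch_size, num_images_per_prompt, do_classifier_free_guidance):
    # Repeat each per-prompt mask so its batch dimension matches the noise latents,
    # which are allocated for batch_size * num_images_per_prompt samples.
    effective_batch_size = batch_size * num_images_per_prompt
    mask = mask.repeat(effective_batch_size // mask.shape[0], 1, 1, 1)
    masked_image_latents = masked_image_latents.repeat(
        effective_batch_size // masked_image_latents.shape[0], 1, 1, 1
    )
    if do_classifier_free_guidance:
        # Duplicate for the unconditional and text-conditioned forward passes.
        mask = torch.cat([mask] * 2)
        masked_image_latents = torch.cat([masked_image_latents] * 2)
    return mask, masked_image_latents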

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 41 additions & 0 deletions
@@ -215,6 +215,47 @@ def test_stable_diffusion_inpaint(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

+    def test_stable_diffusion_inpaint_with_num_images_per_prompt(self):
+        device = "cpu"
+        unet = self.dummy_cond_unet_inpaint
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128))
+        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionInpaintPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=None,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=device).manual_seed(0)
+        images = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            image=init_image,
+            mask_image=mask_image,
+            num_images_per_prompt=2,
+        ).images
+
+        # check if the output is a list of 2 images
+        assert len(images) == 2
+
     @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
     def test_stable_diffusion_inpaint_fp16(self):
         """Test that stable diffusion inpaint_legacy works with fp16"""
