Skip to content

StableDiffusionInpaintingPipeline - resize image w.r.t height and width #3322

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def prepare_mask_and_masked_image(image, mask):
def prepare_mask_and_masked_image(image, mask, height, width):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
Expand Down Expand Up @@ -62,6 +62,13 @@ def prepare_mask_and_masked_image(image, mask):
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""

if image is None:
raise ValueError("`image` input cannot be undefined.")

if mask is None:
raise ValueError("`mask_image` input cannot be undefined.")

if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor):
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
Expand Down Expand Up @@ -109,8 +116,9 @@ def prepare_mask_and_masked_image(image, mask):
# preprocess image
if isinstance(image, (PIL.Image.Image, np.ndarray)):
image = [image]

if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
# resize all images w.r.t passed height and width
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
Copy link
Member

@sayakpaul sayakpaul May 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A natural extension to the above conditional branches, IMO, is what happens when:

elif isinstance(image, list) and isinstance(image[0], np.ndarray):

We could convert to PIL from the ndarrays and then unify the logic here.

elif isinstance(image, list) and isinstance(image[0], np.ndarray):
Expand All @@ -124,6 +132,7 @@ def prepare_mask_and_masked_image(image, mask):
mask = [mask]

if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
Expand Down Expand Up @@ -787,12 +796,6 @@ def __call__(
negative_prompt_embeds,
)

if image is None:
raise ValueError("`image` input cannot be undefined.")

if mask_image is None:
raise ValueError("`mask_image` input cannot be undefined.")

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
Expand All @@ -818,8 +821,8 @@ def __call__(
negative_prompt_embeds=negative_prompt_embeds,
)

# 4. Preprocess mask and image
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width)

# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
Expand Down
138 changes: 91 additions & 47 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,25 @@ def test_inpaint_compile(self):
assert np.abs(expected_slice - image_slice).max() < 1e-4
assert np.abs(expected_slice - image_slice).max() < 1e-3

def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting", safety_checker=None
)
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()

inputs = self.get_inputs(torch_device)
# change input image to a random size (one that would cause a tensor mismatch error)
inputs['image'] = inputs['image'].resize((127,127))
inputs['mask_image'] = inputs['mask_image'].resize((127,127))
Comment on lines +314 to +315
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 127?

inputs['height'] = 128
inputs['width'] = 128
image = pipe(**inputs).images
# verify that the returned image has the same height and width as the input height and width
assert image.shape == (1, inputs['height'], inputs['width'], 3)


@nightly
@require_torch_gpu
Expand Down Expand Up @@ -397,21 +416,22 @@ def test_inpaint_dpm(self):

class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
def test_pil_inputs(self):
im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
height, width = 32, 32
im = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
im = Image.fromarray(im)
mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
mask = Image.fromarray((mask * 255).astype(np.uint8))

t_mask, t_masked = prepare_mask_and_masked_image(im, mask)
t_mask, t_masked = prepare_mask_and_masked_image(im, mask, height, width)

self.assertTrue(isinstance(t_mask, torch.Tensor))
self.assertTrue(isinstance(t_masked, torch.Tensor))

self.assertEqual(t_mask.ndim, 4)
self.assertEqual(t_masked.ndim, 4)

self.assertEqual(t_mask.shape, (1, 1, 32, 32))
self.assertEqual(t_masked.shape, (1, 3, 32, 32))
self.assertEqual(t_mask.shape, (1, 1, height, width))
self.assertEqual(t_masked.shape, (1, 3, height, width))

self.assertTrue(t_mask.dtype == torch.float32)
self.assertTrue(t_masked.dtype == torch.float32)
Expand All @@ -424,141 +444,165 @@ def test_pil_inputs(self):
self.assertTrue(t_mask.sum() > 0.0)

def test_np_inputs(self):
im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
height, width = 32, 32

im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
im_pil = Image.fromarray(im_np)
mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
mask_np = np.random.randint(0, 255, (height, width,), dtype=np.uint8) > 127.5
mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))

t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width)

self.assertTrue((t_mask_np == t_mask_pil).all())
self.assertTrue((t_masked_np == t_masked_pil).all())

def test_torch_3D_2D_inputs(self):
im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
height, width = 32, 32

im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy().transpose(1, 2, 0)
mask_np = mask_tensor.numpy()

t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())

def test_torch_3D_3D_inputs(self):
im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
height, width = 32, 32

im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy().transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0]

t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())

def test_torch_4D_2D_inputs(self):
im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
height, width = 32, 32

im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()

t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())

def test_torch_4D_3D_inputs(self):
im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
height, width = 32, 32

im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0]

t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())

def test_torch_4D_4D_inputs(self):
im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5
height, width = 32, 32

im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, 1, height, width,), dtype=torch.uint8) > 127.5
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0][0]

t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)

self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())

def test_torch_batch_4D_3D(self):
im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5
height, width = 32, 32

im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (2, height, width,), dtype=torch.uint8) > 127.5

im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
mask_nps = [mask.numpy() for mask in mask_tensor]

t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
t_mask_np = torch.cat([n[0] for n in nps])
t_masked_np = torch.cat([n[1] for n in nps])

self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())

def test_torch_batch_4D_4D(self):
im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5
height, width = 32, 32

im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (2, 1, height, width,), dtype=torch.uint8) > 127.5

im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
mask_nps = [mask.numpy()[0] for mask in mask_tensor]

t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
t_mask_np = torch.cat([n[0] for n in nps])
t_masked_np = torch.cat([n[1] for n in nps])

self.assertTrue((t_mask_tensor == t_mask_np).all())
self.assertTrue((t_masked_tensor == t_masked_np).all())

def test_shape_mismatch(self):
height, width = 32, 32

# test height and width
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64))
prepare_mask_and_masked_image(torch.randn(3, height, width,), torch.randn(64, 64), height, width)
# test batch dim
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 64, 64))
prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 64, 64), height, width)
# test batch dim
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 64, 64))
prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 1, 64, 64), height, width)

def test_type_mismatch(self):
height, width = 32, 32

# test tensors-only
with self.assertRaises(TypeError):
prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy())
prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.rand(3, height, width,).numpy(), height, width)
# test tensors-only
with self.assertRaises(TypeError):
prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32))
prepare_mask_and_masked_image(torch.rand(3, height, width,).numpy(), torch.rand(3, height, width,), height, width)

def test_channels_first(self):
height, width = 32, 32

# test channels first for 3D tensors
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32))
prepare_mask_and_masked_image(torch.rand(height, width, 3), torch.rand(3, height, width,), height, width)

def test_tensor_range(self):
height, width = 32, 32

# test im <= 1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32))
prepare_mask_and_masked_image(torch.ones(3, height, width,) * 2, torch.rand(height, width,), height, width)
# test im >= -1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32))
prepare_mask_and_masked_image(torch.ones(3, height, width,) * (-2), torch.rand(height, width,), height, width)
# test mask <= 1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2)
prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * 2, height, width)
# test mask >= 0
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1)
prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * -1, height, width)