Skip to content

AutoencoderKL: clamp indices of blend_h and blend_v to input size #2660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/diffusers/models/autoencoder_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,12 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode
return DecoderOutput(sample=decoded)

def blend_v(self, a, b, blend_extent):
for y in range(blend_extent):
for y in range(min(a.shape[2], b.shape[2], blend_extent)):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b

def blend_h(self, a, b, blend_extent):
for x in range(blend_extent):
for x in range(min(a.shape[3], b.shape[3], blend_extent)):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b

Expand Down
6 changes: 6 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,12 @@ def test_stable_diffusion_vae_tiling(self):

assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 5e-1

# test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome addition to the test!

for shape in shapes:
zeros = torch.zeros(shape).to(device)
sd_pipe.vae.decode(zeros)

def test_stable_diffusion_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
Expand Down