Skip to content

Commit a94977b

Browse files
authored
Fix panorama to support all schedulers (#3546)
* refactor blocks init
* refactor blocks loop
* remove unused function and warnings
* fix scheduler update location
* reformat code
* reformat code again
* fix PNDM test case
* reformat pndm test case
1 parent 8e69708 commit a94977b

File tree

2 files changed

+24
-24
lines changed

2 files changed

+24
-24
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

Lines changed: 13 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
import copy
1415
import inspect
1516
import warnings
1617
from typing import Any, Callable, Dict, List, Optional, Union
@@ -21,7 +22,7 @@
2122
from ...image_processor import VaeImageProcessor
2223
from ...loaders import TextualInversionLoaderMixin
2324
from ...models import AutoencoderKL, UNet2DConditionModel
24-
from ...schedulers import DDIMScheduler, PNDMScheduler
25+
from ...schedulers import DDIMScheduler
2526
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
2627
from ..pipeline_utils import DiffusionPipeline
2728
from . import StableDiffusionPipelineOutput
@@ -96,9 +97,6 @@ def __init__(
9697
):
9798
super().__init__()
9899

99-
if isinstance(scheduler, PNDMScheduler):
100-
logger.error("PNDMScheduler for this pipeline is currently not supported.")
101-
102100
if safety_checker is None and requires_safety_checker:
103101
logger.warning(
104102
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -612,7 +610,7 @@ def __call__(
612610

613611
# 6. Define panorama grid and initialize views for synthesis.
614612
views = self.get_views(height, width)
615-
blocks_model_outputs = [None] * len(views)
613+
views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views)
616614
count = torch.zeros_like(latents)
617615
value = torch.zeros_like(latents)
618616

@@ -637,6 +635,9 @@ def __call__(
637635
# get the latents corresponding to the current view coordinates
638636
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
639637

638+
# rematch block's scheduler status
639+
self.scheduler.__dict__.update(views_scheduler_status[j])
640+
640641
# expand the latents if we are doing classifier free guidance
641642
latent_model_input = (
642643
torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
@@ -657,21 +658,13 @@ def __call__(
657658
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
658659

659660
# compute the previous noisy sample x_t -> x_t-1
660-
if hasattr(self.scheduler, "model_outputs"):
661-
# rematch model_outputs in each block
662-
if i >= 1:
663-
self.scheduler.model_outputs = blocks_model_outputs[j]
664-
latents_view_denoised = self.scheduler.step(
665-
noise_pred, t, latents_for_view, **extra_step_kwargs
666-
).prev_sample
667-
# collect model_outputs
668-
blocks_model_outputs[j] = [
669-
output if output is not None else None for output in self.scheduler.model_outputs
670-
]
671-
else:
672-
latents_view_denoised = self.scheduler.step(
673-
noise_pred, t, latents_for_view, **extra_step_kwargs
674-
).prev_sample
661+
latents_view_denoised = self.scheduler.step(
662+
noise_pred, t, latents_for_view, **extra_step_kwargs
663+
).prev_sample
664+
665+
# save views scheduler status after sample
666+
views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)
667+
675668
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
676669
count[:, :, h_start:h_end, w_start:w_end] += 1
677670

tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -174,15 +174,22 @@ def test_stable_diffusion_panorama_euler(self):
174174
def test_stable_diffusion_panorama_pndm(self):
175175
device = "cpu" # ensure determinism for the device-dependent torch.Generator
176176
components = self.get_dummy_components()
177-
components["scheduler"] = PNDMScheduler()
177+
components["scheduler"] = PNDMScheduler(
178+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
179+
)
178180
sd_pipe = StableDiffusionPanoramaPipeline(**components)
179181
sd_pipe = sd_pipe.to(device)
180182
sd_pipe.set_progress_bar_config(disable=None)
181183

182184
inputs = self.get_dummy_inputs(device)
183-
# the pipeline does not expect pndm so test if it raises error.
184-
with self.assertRaises(ValueError):
185-
_ = sd_pipe(**inputs).images
185+
image = sd_pipe(**inputs).images
186+
image_slice = image[0, -3:, -3:, -1]
187+
188+
assert image.shape == (1, 64, 64, 3)
189+
190+
expected_slice = np.array([0.6391, 0.6291, 0.4861, 0.5134, 0.5552, 0.4578, 0.5032, 0.5023, 0.4539])
191+
192+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
186193

187194

188195
@slow

0 commit comments

Comments
 (0)