# See the License for the specific language governing permissions and
# limitations under the License.

+import copy
 import inspect
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Union

 from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...schedulers import DDIMScheduler, PNDMScheduler
+from ...schedulers import DDIMScheduler
 from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -96,9 +97,6 @@ def __init__(
     ):
         super().__init__()

-        if isinstance(scheduler, PNDMScheduler):
-            logger.error("PNDMScheduler for this pipeline is currently not supported.")
-
         if safety_checker is None and requires_safety_checker:
             logger.warning(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -612,7 +610,7 @@ def __call__(

         # 6. Define panorama grid and initialize views for synthesis.
         views = self.get_views(height, width)
-        blocks_model_outputs = [None] * len(views)
+        views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views)
         count = torch.zeros_like(latents)
         value = torch.zeros_like(latents)

@@ -637,6 +635,9 @@ def __call__(
                     # get the latents corresponding to the current view coordinates
                     latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]

+                    # restore this view's scheduler state before taking a step
+                    self.scheduler.__dict__.update(views_scheduler_status[j])
+
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = (
                         torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
@@ -657,21 +658,13 @@ def __call__(
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                     # compute the previous noisy sample x_t -> x_t-1
-                    if hasattr(self.scheduler, "model_outputs"):
-                        # rematch model_outputs in each block
-                        if i >= 1:
-                            self.scheduler.model_outputs = blocks_model_outputs[j]
-                        latents_view_denoised = self.scheduler.step(
-                            noise_pred, t, latents_for_view, **extra_step_kwargs
-                        ).prev_sample
-                        # collect model_outputs
-                        blocks_model_outputs[j] = [
-                            output if output is not None else None for output in self.scheduler.model_outputs
-                        ]
-                    else:
-                        latents_view_denoised = self.scheduler.step(
-                            noise_pred, t, latents_for_view, **extra_step_kwargs
-                        ).prev_sample
+                    latents_view_denoised = self.scheduler.step(
+                        noise_pred, t, latents_for_view, **extra_step_kwargs
+                    ).prev_sample
+
+                    # save this view's scheduler state after the step
+                    views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)
+
                     value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                     count[:, :, h_start:h_end, w_start:w_end] += 1

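Why the change works, as a minimal illustrative sketch: stateful schedulers keep history across step() calls (for example, previously predicted model outputs), so sharing a single scheduler instance across panorama views would mix the views' histories. The diff keeps one snapshot of scheduler.__dict__ per view, restores it before stepping that view, and re-snapshots it afterwards. The toy scheduler and view names below are made up for illustration and are not the diffusers API; only the copy.deepcopy / __dict__.update pattern mirrors the diff.

import copy


class ToyMultistepScheduler:
    # Stand-in for a scheduler that keeps internal history across step() calls.
    def __init__(self):
        self.model_outputs = []  # state that must not leak between views

    def step(self, noise_pred, sample):
        # a real multistep scheduler would combine previous outputs here
        self.model_outputs.append(noise_pred)
        return sample - 0.1 * noise_pred  # placeholder update


scheduler = ToyMultistepScheduler()
views = range(3)  # placeholder for the panorama view windows

# one snapshot of the scheduler state per view, taken before any stepping
views_scheduler_status = [copy.deepcopy(scheduler.__dict__) for _ in views]

for t in range(4):  # denoising timesteps
    for j in views:
        # restore this view's private scheduler state ...
        scheduler.__dict__.update(views_scheduler_status[j])

        denoised = scheduler.step(noise_pred=0.5, sample=1.0)

        # ... and snapshot it again so the next timestep of this view
        # continues from its own history, not another view's
        views_scheduler_status[j] = copy.deepcopy(scheduler.__dict__)

Note that the [copy.deepcopy(...)] * len(views) form in the diff makes every slot alias the same initial dict; this is harmless only because each slot is overwritten with a fresh deepcopy as soon as that view takes its first step.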