11
11
# See the License for the specific language governing permissions and
12
12
# limitations under the License.
13
13
14
+ import copy
14
15
import inspect
15
16
import warnings
16
17
from typing import Any , Callable , Dict , List , Optional , Union
21
22
from ...image_processor import VaeImageProcessor
22
23
from ...loaders import TextualInversionLoaderMixin
23
24
from ...models import AutoencoderKL , UNet2DConditionModel
24
- from ...schedulers import DDIMScheduler , PNDMScheduler
25
+ from ...schedulers import DDIMScheduler
25
26
from ...utils import is_accelerate_available , is_accelerate_version , logging , randn_tensor , replace_example_docstring
26
27
from ..pipeline_utils import DiffusionPipeline
27
28
from . import StableDiffusionPipelineOutput
@@ -96,9 +97,6 @@ def __init__(
96
97
):
97
98
super ().__init__ ()
98
99
99
- if isinstance (scheduler , PNDMScheduler ):
100
- logger .error ("PNDMScheduler for this pipeline is currently not supported." )
101
-
102
100
if safety_checker is None and requires_safety_checker :
103
101
logger .warning (
104
102
f"You have disabled the safety checker for { self .__class__ } by passing `safety_checker=None`. Ensure"
@@ -612,7 +610,7 @@ def __call__(
612
610
613
611
# 6. Define panorama grid and initialize views for synthesis.
614
612
views = self .get_views (height , width )
615
- blocks_model_outputs = [None ] * len (views )
613
+ views_scheduler_status = [copy . deepcopy ( self . scheduler . __dict__ ) ] * len (views )
616
614
count = torch .zeros_like (latents )
617
615
value = torch .zeros_like (latents )
618
616
@@ -637,6 +635,9 @@ def __call__(
637
635
# get the latents corresponding to the current view coordinates
638
636
latents_for_view = latents [:, :, h_start :h_end , w_start :w_end ]
639
637
638
+ # restore the scheduler state previously saved for this view
639
+ self .scheduler .__dict__ .update (views_scheduler_status [j ])
640
+
640
641
# expand the latents if we are doing classifier free guidance
641
642
latent_model_input = (
642
643
torch .cat ([latents_for_view ] * 2 ) if do_classifier_free_guidance else latents_for_view
@@ -657,21 +658,13 @@ def __call__(
657
658
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
658
659
659
660
# compute the previous noisy sample x_t -> x_t-1
660
- if hasattr (self .scheduler , "model_outputs" ):
661
- # rematch model_outputs in each block
662
- if i >= 1 :
663
- self .scheduler .model_outputs = blocks_model_outputs [j ]
664
- latents_view_denoised = self .scheduler .step (
665
- noise_pred , t , latents_for_view , ** extra_step_kwargs
666
- ).prev_sample
667
- # collect model_outputs
668
- blocks_model_outputs [j ] = [
669
- output if output is not None else None for output in self .scheduler .model_outputs
670
- ]
671
- else :
672
- latents_view_denoised = self .scheduler .step (
673
- noise_pred , t , latents_for_view , ** extra_step_kwargs
674
- ).prev_sample
661
+ latents_view_denoised = self .scheduler .step (
662
+ noise_pred , t , latents_for_view , ** extra_step_kwargs
663
+ ).prev_sample
664
+
665
+ # save this view's scheduler state after the scheduler step
666
+ views_scheduler_status [j ] = copy .deepcopy (self .scheduler .__dict__ )
667
+
675
668
value [:, :, h_start :h_end , w_start :w_end ] += latents_view_denoised
676
669
count [:, :, h_start :h_end , w_start :w_end ] += 1
677
670
0 commit comments