 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, logging
+from diffusers.utils import deprecate, is_accelerate_available, logging
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


@@ -340,13 +340,15 @@ def get_weighted_text_embeddings(
     # assign weights to the prompts and normalize in the sense of mean
     # TODO: should we normalize by chunk or in a whole (current implementation)?
     if (not skip_parsing) and (not skip_weighting):
-        previous_mean = text_embeddings.mean(axis=[-2, -1])
+        previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
         text_embeddings *= prompt_weights.unsqueeze(-1)
-        text_embeddings *= (previous_mean / text_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
+        current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+        text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
         if uncond_prompt is not None:
-            previous_mean = uncond_embeddings.mean(axis=[-2, -1])
+            previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
             uncond_embeddings *= uncond_weights.unsqueeze(-1)
-            uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
+            current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+            uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

     if uncond_prompt is not None:
         return text_embeddings, uncond_embeddings
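
Note on the hunk above: the per-prompt means are now reduced in float32 and cast back to the embedding dtype, since reducing a large fp16 tensor directly can lose precision or overflow and skew the re-normalization. A minimal sketch of the two paths (tensor shape and values are illustrative assumptions, not taken from the pipeline):

import torch

# Illustrative fp16 embeddings; (batch, tokens, dim) mirrors CLIP-sized
# outputs, but the exact shape is an assumption.
emb = torch.randn(2, 77, 768, dtype=torch.float16)

fp16_mean = emb.mean(axis=[-2, -1])                        # reduction carried out in half precision
fp32_mean = emb.float().mean(axis=[-2, -1]).to(emb.dtype)  # reduce in fp32, cast back

# On benign data the two agree; with large magnitudes or long reductions the
# fp16 path can drift, which is what the patched normalization avoids.
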
@@ -431,6 +433,19 @@ def __init__(
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -451,6 +466,24 @@ def __init__(
             feature_extractor=feature_extractor,
         )

+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(False)
+
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
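
For reference, a hedged usage sketch of the two new toggles. It assumes xformers is installed, a CUDA device is available, and that this file loads as the `lpw_stable_diffusion` custom pipeline (the checkpoint id and pipeline name are assumptions):

import torch
from diffusers import DiffusionPipeline

# Assumed checkpoint and custom pipeline name; adjust to your setup.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

pipe.enable_xformers_memory_efficient_attention()   # lower memory use, often faster
image = pipe("a photo of an astronaut riding a horse").images[0]
pipe.disable_xformers_memory_efficient_attention()  # back to the default attention
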
@@ -478,6 +511,23 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)

+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and are then moved to
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = self.device
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     @torch.no_grad()
     def __call__(
         self,
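
Finally, a hedged usage sketch for the new offload hook. Because this implementation reads the execution device from `self.device`, the pipeline should already point at the GPU when `enable_sequential_cpu_offload` is called; the checkpoint id and pipeline name are again assumptions:

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", custom_pipeline="lpw_stable_diffusion"
).to("cuda")

# State dicts move to CPU / the meta device; each submodule is paged onto
# the GPU only for the duration of its forward pass.
pipe.enable_sequential_cpu_offload()
image = pipe("a fantasy landscape, trending on artstation").images[0]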