Commit a480229

[Community Pipeline] lpw_stable_diffusion: add xformers_memory_efficient_attention and sequential_cpu_offload (#1130)
lpw_stable_diffusion: xformers and cpu_offload
1 parent 5b20d3b commit a480229

2 files changed: +56 -6

examples/community/lpw_stable_diffusion.py

Lines changed: 55 additions & 5 deletions
@@ -12,7 +12,7 @@
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, logging
+from diffusers.utils import deprecate, is_accelerate_available, logging
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -340,13 +340,15 @@ def get_weighted_text_embeddings(
     # assign weights to the prompts and normalize in the sense of mean
     # TODO: should we normalize by chunk or in a whole (current implementation)?
     if (not skip_parsing) and (not skip_weighting):
-        previous_mean = text_embeddings.mean(axis=[-2, -1])
+        previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
         text_embeddings *= prompt_weights.unsqueeze(-1)
-        text_embeddings *= (previous_mean / text_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
+        current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+        text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
         if uncond_prompt is not None:
-            previous_mean = uncond_embeddings.mean(axis=[-2, -1])
+            previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
             uncond_embeddings *= uncond_weights.unsqueeze(-1)
-            uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
+            current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+            uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
 
     if uncond_prompt is not None:
         return text_embeddings, uncond_embeddings
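
The hunk above computes the pre- and post-weighting means in float32 and casts back, so the mean-preserving rescale stays accurate when the pipeline runs in fp16. A minimal standalone sketch of the same rescale; the tensor names, shapes, and values here are illustrative assumptions, not taken from the pipeline:

import torch

# Illustrative shapes: batch=1, seq_len=4, dim=8, in fp16 as on GPU.
text_embeddings = torch.randn(1, 4, 8, dtype=torch.float16)
prompt_weights = torch.tensor([[1.0, 1.1, 1.0, 0.9]], dtype=torch.float16)

# Reduce in float32 so the mean of many small fp16 values stays accurate.
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
# Rescale so the weighted embeddings keep their original mean.
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)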
@@ -431,6 +433,19 @@ def __init__(
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -451,6 +466,24 @@ def __init__(
             feature_extractor=feature_extractor,
         )
 
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(False)
+
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
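
Both new toggles simply delegate to the UNet's set_use_memory_efficient_attention_xformers switch. A hedged usage sketch follows; the model id and dtype are illustrative assumptions, and xformers must be installed:

import torch
from diffusers import DiffusionPipeline

# Load the community pipeline (model id is an assumption for illustration).
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

pipe.enable_xformers_memory_efficient_attention()   # lower memory at inference
image = pipe("a red ((apple)) on a table").images[0]  # lpw weighted-prompt syntax
pipe.disable_xformers_memory_efficient_attention()  # restore default attention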
@@ -478,6 +511,23 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = self.device
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     @torch.no_grad()
     def __call__(
         self,
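
Note that enable_sequential_cpu_offload reads the execution device from self.device, so in this version the pipeline's device placement determines where submodules are loaded when their forward runs. A hedged sketch, reusing the same illustrative model id as above:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")  # assumption: place the pipeline on the GPU before enabling offload

# Submodule weights stream to the GPU only when their forward() is called,
# trading inference speed for a much smaller peak GPU memory footprint.
pipe.enable_sequential_cpu_offload()
image = pipe("a photo of an astronaut riding a horse").images[0]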

examples/community/lpw_stable_diffusion_onnx.py

Lines changed: 1 addition & 1 deletion
@@ -701,7 +701,7 @@ def __call__(
                     clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
                 )
                 images.append(image_i)
-                has_nsfw_concept.append(has_nsfw_concept_i)
+                has_nsfw_concept.append(has_nsfw_concept_i[0])
             image = np.concatenate(images)
         else:
             has_nsfw_concept = None
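
The one-line ONNX fix flattens the safety-checker output: each size-1 batch slice yields a list containing a single flag, so appending the list itself produced a nested list. A minimal sketch of the difference, with a hypothetical flag value:

has_nsfw_concept_i = [False]  # one flag for the single image in the slice

old_style, fixed = [], []
old_style.append(has_nsfw_concept_i)   # before: [[False]] (nested list)
fixed.append(has_nsfw_concept_i[0])    # after:  [False]  (flat list of booleans)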
