diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index 74aed2fec86f..39a266dec972 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -12,7 +12,7 @@
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, logging
+from diffusers.utils import deprecate, is_accelerate_available, logging
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 
@@ -340,13 +340,15 @@ def get_weighted_text_embeddings(
     # assign weights to the prompts and normalize in the sense of mean
     # TODO: should we normalize by chunk or in a whole (current implementation)?
     if (not skip_parsing) and (not skip_weighting):
-        previous_mean = text_embeddings.mean(axis=[-2, -1])
+        previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
         text_embeddings *= prompt_weights.unsqueeze(-1)
-        text_embeddings *= (previous_mean / text_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
+        current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+        text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
         if uncond_prompt is not None:
-            previous_mean = uncond_embeddings.mean(axis=[-2, -1])
+            previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
             uncond_embeddings *= uncond_weights.unsqueeze(-1)
-            uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
+            current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+            uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
 
     if uncond_prompt is not None:
         return text_embeddings, uncond_embeddings
@@ -431,6 +433,19 @@ def __init__(
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -451,6 +466,24 @@ def __init__(
             feature_extractor=feature_extractor,
         )
 
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(False)
+
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
@@ -478,6 +511,23 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = self.device
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     @torch.no_grad()
     def __call__(
         self,
diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py
index 69b942f9ef1a..12e306a612e1 100644
--- a/examples/community/lpw_stable_diffusion_onnx.py
+++ b/examples/community/lpw_stable_diffusion_onnx.py
@@ -701,7 +701,7 @@ def __call__(
                     clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
                 )
                 images.append(image_i)
-                has_nsfw_concept.append(has_nsfw_concept_i)
+                has_nsfw_concept.append(has_nsfw_concept_i[0])
             image = np.concatenate(images)
         else:
             has_nsfw_concept = None
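
Below is a minimal usage sketch (not part of the diff) exercising what this change adds to the `lpw_stable_diffusion` community pipeline: the new `enable_xformers_memory_efficient_attention` and `enable_sequential_cpu_offload` methods, and fp16 prompt weighting, whose mean normalization now runs in float32 to avoid overflow. The model id, prompt, and output filename are illustrative choices; diffusers with community-pipeline support, xformers, and accelerate are assumed to be installed.

```python
# Hypothetical usage sketch for the methods added in this diff; model id,
# prompt, and filename are illustrative, not part of the change.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,  # fp16 exercises the float32 mean-normalization fix above
)
# This pipeline's enable_sequential_cpu_offload reads self.device,
# so move to the GPU before enabling offload.
pipe = pipe.to("cuda")

# New methods added by this diff:
pipe.enable_xformers_memory_efficient_attention()  # xformers attention in the UNet
pipe.enable_sequential_cpu_offload()  # cpu_offload for unet/text_encoder/vae/safety_checker

# Weighted (token:weight) prompt syntax is handled by get_weighted_text_embeddings
# and no longer overflows in fp16.
image = pipe("a (sunlit:1.2) forest path, ((best quality))", num_inference_steps=25).images[0]
image.save("out.png")
```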