From 43fd80ecb82bfef82344f6e8dbb79ca57931f068 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 6 Apr 2023 10:52:05 +0800 Subject: [PATCH 1/7] Modified altdiffusion pipline to support altdiffusion-m18 --- .../alt_diffusion/modeling_roberta_series.py | 16 ++++++- .../alt_diffusion/pipeline_alt_diffusion.py | 46 ++++++++++--------- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py index 637d6dd18698..45b1945ba4dd 100644 --- a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +++ b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py @@ -31,6 +31,8 @@ class TransformationModelOutput(ModelOutput): """ projection_state: Optional[torch.FloatTensor] = None + # NEW add for m18 + penultimate_hidden_state: torch.FloatTensor = None last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -56,7 +58,7 @@ def __init__( class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] base_model_prefix = "roberta" config_class = RobertaSeriesConfig @@ -65,6 +67,8 @@ def __init__(self, config): super().__init__(config) self.roberta = XLMRobertaModel(config) self.transformation = nn.Linear(config.hidden_size, config.project_dim) + self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) + self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_init() def forward( @@ -95,14 +99,22 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + # output_hidden_states=output_hidden_states, + output_hidden_states=True, return_dict=return_dict, ) projection_state = self.transformation(outputs.last_hidden_state) + + + # New add for m18 + sequence_output2 = outputs['hidden_states'][-2] + sequence_output2 = self.pre_LN(sequence_output2) + projection_state2 = self.transformation_pre(sequence_output2) return TransformationModelOutput( projection_state=projection_state, + penultimate_hidden_state=projection_state2, last_hidden_state=outputs.last_hidden_state, hidden_states=outputs.hidden_states, attentions=outputs.attentions, diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index c5bb8f9ac7b1..8c94a27078b6 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -17,12 +17,11 @@ import torch from packaging import version -from transformers import CLIPImageProcessor, XLMRobertaTokenizer +from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict -from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, randn_tensor, replace_example_docstring @@ -50,7 +49,7 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker -class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class AltDiffusionPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using Alt Diffusion. @@ -74,7 +73,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPImageProcessor`]): + feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -87,10 +86,13 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, + feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, + use_penultimate: bool = False, ): super().__init__() + # New add for m18 + self.use_penultimate = use_penultimate if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( @@ -235,7 +237,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") @@ -295,8 +297,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -313,10 +315,6 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - text_inputs = self.tokenizer( prompt, padding="max_length", @@ -347,7 +345,11 @@ def _encode_prompt( text_input_ids.to(device), attention_mask=attention_mask, ) - prompt_embeds = prompt_embeds[0] + # m18 take the embedding from penultimate layer + if self.use_penultimate: + prompt_embeds = prompt_embeds["penultimate_hidden_state"] + else: + prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) @@ -377,10 +379,6 @@ def _encode_prompt( else: uncond_tokens = negative_prompt - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -399,8 +397,12 @@ def _encode_prompt( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - negative_prompt_embeds = negative_prompt_embeds[0] - + # m18 take the embedding from penultimate layer + if self.use_penultimate: + negative_prompt_embeds = negative_prompt_embeds['penultimate_hidden_state'] + else: + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] @@ -560,8 +562,8 @@ def __call__( usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -594,7 +596,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). From 920b035fc9ccac8409aa01c6f87161e7fb6146e7 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 6 Apr 2023 11:46:50 +0800 Subject: [PATCH 2/7] Modified altdiffusion pipline to support altdiffusion-m18 --- .../pipelines/alt_diffusion/modeling_roberta_series.py | 7 +++---- .../pipelines/alt_diffusion/pipeline_alt_diffusion.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py index 45b1945ba4dd..e94e6716e1c7 100644 --- a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +++ b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py @@ -105,10 +105,9 @@ def forward( ) projection_state = self.transformation(outputs.last_hidden_state) - - - # New add for m18 - sequence_output2 = outputs['hidden_states'][-2] + + # New add for m18 + sequence_output2 = outputs["hidden_states"][-2] sequence_output2 = self.pre_LN(sequence_output2) projection_state2 = self.transformation_pre(sequence_output2) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 8c94a27078b6..697e8275680e 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -399,10 +399,10 @@ def _encode_prompt( ) # m18 take the embedding from penultimate layer if self.use_penultimate: - negative_prompt_embeds = negative_prompt_embeds['penultimate_hidden_state'] + negative_prompt_embeds = negative_prompt_embeds["penultimate_hidden_state"] else: negative_prompt_embeds = negative_prompt_embeds[0] - + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] From fd288a97614fd554e5a778a55e44ba9ae60430a8 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 7 Apr 2023 13:50:03 +0800 Subject: [PATCH 3/7] Modified altdiffusion pipline to support altdiffusion-m18 --- .../alt_diffusion/modeling_roberta_series.py | 85 ++++++++++++------- .../alt_diffusion/pipeline_alt_diffusion.py | 23 +++-- 2 files changed, 64 insertions(+), 44 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py index e94e6716e1c7..170294c58bfa 100644 --- a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +++ b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py @@ -31,8 +31,8 @@ class TransformationModelOutput(ModelOutput): """ projection_state: Optional[torch.FloatTensor] = None - # NEW add for m18 - penultimate_hidden_state: torch.FloatTensor = None + # m18新加了一个返回参数 + # penultimate_hidden_state: torch.FloatTensor = None last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -66,9 +66,11 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel): def __init__(self, config): super().__init__(config) self.roberta = XLMRobertaModel(config) - self.transformation = nn.Linear(config.hidden_size, config.project_dim) - self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) - self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.has_pre_transformation = config.has_pre_transformation + if self.has_pre_transformation: + self.transformation = nn.Linear(config.hidden_size, config.project_dim) + self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) + self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_init() def forward( @@ -89,32 +91,53 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.base_model( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - # output_hidden_states=output_hidden_states, - output_hidden_states=True, - return_dict=return_dict, - ) + if self.has_pre_transformation: + outputs = self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + else: + outputs = self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) projection_state = self.transformation(outputs.last_hidden_state) - # New add for m18 - sequence_output2 = outputs["hidden_states"][-2] - sequence_output2 = self.pre_LN(sequence_output2) - projection_state2 = self.transformation_pre(sequence_output2) - - return TransformationModelOutput( - projection_state=projection_state, - penultimate_hidden_state=projection_state2, - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + ### 这里m18新加的,需要拿出来倒二层的向量 ### + if self.has_pre_transformation: + sequence_output2 = outputs["hidden_states"][-2] + sequence_output2 = self.pre_LN(sequence_output2) + projection_state2 = self.transformation_pre(sequence_output2) + + return TransformationModelOutput( + projection_state=projection_state2, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + return TransformationModelOutput( + projection_state=projection_state, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 697e8275680e..f6d597e83eca 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -88,11 +88,8 @@ def __init__( safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, - use_penultimate: bool = False, ): super().__init__() - # New add for m18 - self.use_penultimate = use_penultimate if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( @@ -345,11 +342,11 @@ def _encode_prompt( text_input_ids.to(device), attention_mask=attention_mask, ) - # m18 take the embedding from penultimate layer - if self.use_penultimate: - prompt_embeds = prompt_embeds["penultimate_hidden_state"] - else: - prompt_embeds = prompt_embeds[0] + # m18取得是倒二层的向量 + # if self.use_penultimate: + # prompt_embeds = prompt_embeds["penultimate_hidden_state"] + # else: + prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) @@ -397,11 +394,11 @@ def _encode_prompt( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - # m18 take the embedding from penultimate layer - if self.use_penultimate: - negative_prompt_embeds = negative_prompt_embeds["penultimate_hidden_state"] - else: - negative_prompt_embeds = negative_prompt_embeds[0] + # m18取得是倒二层的向量 + # if self.use_penultimate: + # negative_prompt_embeds = negative_prompt_embeds['penultimate_hidden_state'] + # else: + negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method From 3326063265ffc473895abf0b4510f843f84fa2d4 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 7 Apr 2023 13:57:14 +0800 Subject: [PATCH 4/7] Modified altdiffusion pipline to support altdiffusion-m18 --- .../alt_diffusion/modeling_roberta_series.py | 5 +-- .../alt_diffusion/pipeline_alt_diffusion.py | 39 ++++++++++--------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py index 170294c58bfa..8908ec4f1551 100644 --- a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +++ b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py @@ -31,8 +31,6 @@ class TransformationModelOutput(ModelOutput): """ projection_state: Optional[torch.FloatTensor] = None - # m18新加了一个返回参数 - # penultimate_hidden_state: torch.FloatTensor = None last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -66,9 +64,9 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel): def __init__(self, config): super().__init__(config) self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size, config.project_dim) self.has_pre_transformation = config.has_pre_transformation if self.has_pre_transformation: - self.transformation = nn.Linear(config.hidden_size, config.project_dim) self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_init() @@ -122,7 +120,6 @@ def forward( projection_state = self.transformation(outputs.last_hidden_state) - ### 这里m18新加的,需要拿出来倒二层的向量 ### if self.has_pre_transformation: sequence_output2 = outputs["hidden_states"][-2] sequence_output2 = self.pre_LN(sequence_output2) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index f6d597e83eca..8d2a78dfa9e8 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -17,11 +17,12 @@ import torch from packaging import version -from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer +from transformers import CLIPImageProcessor, XLMRobertaTokenizer from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, randn_tensor, replace_example_docstring @@ -49,7 +50,7 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker -class AltDiffusionPipeline(DiffusionPipeline): +class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Alt Diffusion. @@ -73,7 +74,7 @@ class AltDiffusionPipeline(DiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -86,7 +87,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -234,7 +235,7 @@ def enable_model_cpu_offload(self, gpu_id=0): if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") @@ -294,8 +295,8 @@ def _encode_prompt( whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -312,6 +313,10 @@ def _encode_prompt( batch_size = prompt_embeds.shape[0] if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -342,10 +347,6 @@ def _encode_prompt( text_input_ids.to(device), attention_mask=attention_mask, ) - # m18取得是倒二层的向量 - # if self.use_penultimate: - # prompt_embeds = prompt_embeds["penultimate_hidden_state"] - # else: prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) @@ -376,6 +377,10 @@ def _encode_prompt( else: uncond_tokens = negative_prompt + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, @@ -394,10 +399,6 @@ def _encode_prompt( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - # m18取得是倒二层的向量 - # if self.use_penultimate: - # negative_prompt_embeds = negative_prompt_embeds['penultimate_hidden_state'] - # else: negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: @@ -559,8 +560,8 @@ def __call__( usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -593,7 +594,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). @@ -716,4 +717,4 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file From 472b3d8d57cd9164680603a2f146cfbb4420aa38 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 7 Apr 2023 14:26:18 +0800 Subject: [PATCH 5/7] Modified altdiffusion pipline to support altdiffusion-m18 --- src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 8d2a78dfa9e8..c5bb8f9ac7b1 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -717,4 +717,4 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From 7a6314b4cf89ba7405f8640fe716bcf0b1168a40 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 7 Apr 2023 14:38:16 +0800 Subject: [PATCH 6/7] Modified altdiffusion pipline to support altdiffusion-m18 --- .../pipelines/alt_diffusion/modeling_roberta_series.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py index 8908ec4f1551..d613b9d901eb 100644 --- a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +++ b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py @@ -118,8 +118,6 @@ def forward( return_dict=return_dict, ) - projection_state = self.transformation(outputs.last_hidden_state) - if self.has_pre_transformation: sequence_output2 = outputs["hidden_states"][-2] sequence_output2 = self.pre_LN(sequence_output2) @@ -132,6 +130,7 @@ def forward( attentions=outputs.attentions, ) else: + projection_state = self.transformation(outputs.last_hidden_state) return TransformationModelOutput( projection_state=projection_state, last_hidden_state=outputs.last_hidden_state, From d4b09dc5f71b709c29dfcee603315346bfb4a3a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8F=B6=E7=A6=8F=E9=BE=99?= Date: Thu, 13 Apr 2023 14:57:45 +0800 Subject: [PATCH 7/7] Modified altdiffusion pipline to support altdiffusion-m18 --- .../alt_diffusion/modeling_roberta_series.py | 43 ++++++------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py index d613b9d901eb..f73ef15d7de7 100644 --- a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +++ b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py @@ -65,7 +65,7 @@ def __init__(self, config): super().__init__(config) self.roberta = XLMRobertaModel(config) self.transformation = nn.Linear(config.hidden_size, config.project_dim) - self.has_pre_transformation = config.has_pre_transformation + self.has_pre_transformation = getattr(config, "has_pre_transformation", False) if self.has_pre_transformation: self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -89,34 +89,19 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.has_pre_transformation: - outputs = self.base_model( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=True, - return_dict=return_dict, - ) - else: - outputs = self.base_model( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) + outputs = self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True if self.has_pre_transformation else output_hidden_states, + return_dict=return_dict, + ) if self.has_pre_transformation: sequence_output2 = outputs["hidden_states"][-2]