[Unclip] Make sure text_embeddings & image_embeddings can directly be passed to enable interpolation tasks. #1858


Merged
merged 8 commits into from Dec 30, 2022
96 changes: 63 additions & 33 deletions src/diffusers/pipelines/unclip/pipeline_unclip.py
@@ -13,12 +13,13 @@
# limitations under the License.

import inspect
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import functional as F

from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput

from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
@@ -117,31 +118,44 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
latents = latents * scheduler.init_noise_sigma
return latents

def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
batch_size = len(prompt) if isinstance(prompt, list) else 1

# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
text_mask = text_inputs.attention_mask.bool().to(device)

if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
text_attention_mask: Optional[torch.Tensor] = None,
):
if text_model_output is None:
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_input_ids = text_inputs.input_ids
text_mask = text_inputs.attention_mask.bool().to(device)

if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

text_encoder_output = self.text_encoder(text_input_ids.to(device))
text_encoder_output = self.text_encoder(text_input_ids.to(device))

text_embeddings = text_encoder_output.text_embeds
text_encoder_hidden_states = text_encoder_output.last_hidden_state
text_embeddings = text_encoder_output.text_embeds
text_encoder_hidden_states = text_encoder_output.last_hidden_state

else:
batch_size = text_model_output[0].shape[0]
text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
text_mask = text_attention_mask

text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
@@ -150,11 +164,10 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
if do_classifier_free_guidance:
uncond_tokens = [""] * batch_size

max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
@@ -235,7 +248,7 @@ def _execution_device(self):
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
prior_num_inference_steps: int = 25,
decoder_num_inference_steps: int = 25,
@@ -244,6 +257,8 @@ def __call__(
prior_latents: Optional[torch.FloatTensor] = None,
decoder_latents: Optional[torch.FloatTensor] = None,
super_res_latents: Optional[torch.FloatTensor] = None,
text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
text_attention_mask: Optional[torch.Tensor] = None,
prior_guidance_scale: float = 4.0,
decoder_guidance_scale: float = 8.0,
output_type: Optional[str] = "pil",
@@ -254,7 +269,8 @@

Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
The prompt or prompts to guide the image generation. This can only be left undefined if
`text_model_output` and `text_attention_mask` are passed.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prior_num_inference_steps (`int`, *optional*, defaults to 25):
@@ -287,26 +303,37 @@ def __call__(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
text_model_output (`CLIPTextModelOutput`, *optional*):
Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs
can be passed for tasks like text embedding interpolations. Make sure to also pass
`text_attention_mask` in this case. `prompt` can then be left as `None`.
text_attention_mask (`torch.Tensor`, *optional*):
Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
masks are necessary when passing `text_model_output`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
"""
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
if prompt is not None:
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
batch_size = text_model_output[0].shape[0]

device = self._execution_device

batch_size = batch_size * num_images_per_prompt

do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0

text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance
prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
)

# prior
@@ -315,6 +342,7 @@ def __call__(
prior_timesteps_tensor = self.prior_scheduler.timesteps

embedding_dim = self.prior.config.embedding_dim

prior_latents = self.prepare_latents(
(batch_size, embedding_dim),
text_embeddings.dtype,
@@ -378,6 +406,7 @@ def __call__(
num_channels_latents = self.decoder.in_channels
height = self.decoder.sample_size
width = self.decoder.sample_size

decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
text_encoder_hidden_states.dtype,
@@ -430,6 +459,7 @@ def __call__(
channels = self.super_res_first.in_channels // 2
height = self.super_res_first.sample_size
width = self.super_res_first.sample_size

super_res_latents = self.prepare_latents(
(batch_size, channels, height, width),
image_small.dtype,
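To illustrate the new `text_model_output` / `text_attention_mask` arguments documented above, here is a minimal sketch of a text-embedding interpolation with `UnCLIPPipeline`. The checkpoint name (`kakaobrain/karlo-v1-alpha`), the prompts, the CUDA device, and the plain `torch.lerp` interpolation are assumptions for illustration and are not part of this PR:

```python
import torch
from diffusers import UnCLIPPipeline

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompts = ["a photo of a red fox", "a photo of a grey wolf"]

# Encode both prompts with the pipeline's own tokenizer and text encoder, padded to
# `model_max_length` so the hidden states line up with the unconditional embeddings
# the pipeline builds internally for classifier-free guidance.
inputs = pipe.tokenizer(
    prompts,
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
text_model_output = pipe.text_encoder(inputs.input_ids.to(pipe.device))

# Interpolate between the two prompt encodings. The attention mask is kept as a bool
# tensor on the same device to match the unconditional mask; for simplicity the first
# prompt's mask is reused for the interpolated embedding.
alpha = 0.5
interp_embeds = torch.lerp(text_model_output.text_embeds[0], text_model_output.text_embeds[1], alpha)[None]
interp_hidden = torch.lerp(text_model_output.last_hidden_state[0], text_model_output.last_hidden_state[1], alpha)[None]
interp_mask = inputs.attention_mask[:1].bool().to(pipe.device)

# `prompt` can now be left as None because `text_model_output` and `text_attention_mask` are passed.
images = pipe(
    text_model_output=(interp_embeds, interp_hidden),
    text_attention_mask=interp_mask,
).images
images[0].save("fox_wolf_interpolation.png")
```

Spherical interpolation (slerp) is often preferred over a plain lerp for CLIP embeddings, but the linear version keeps the sketch short.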
81 changes: 41 additions & 40 deletions src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
@@ -126,7 +126,6 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
latents = latents * scheduler.init_noise_sigma
return latents

# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
batch_size = len(prompt) if isinstance(prompt, list) else 1

@@ -139,15 +138,6 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
)
text_input_ids = text_inputs.input_ids
text_mask = text_inputs.attention_mask.bool().to(device)

if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

text_encoder_output = self.text_encoder(text_input_ids.to(device))

text_embeddings = text_encoder_output.text_embeds
@@ -199,14 +189,15 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):

return text_embeddings, text_encoder_hidden_states, text_mask

def _encode_image(self, image, device, num_images_per_prompt):
def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
dtype = next(self.image_encoder.parameters()).dtype

if not isinstance(image, torch.Tensor):
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
if image_embeddings is None:
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

image = image.to(device=device, dtype=dtype)
image_embeddings = self.image_encoder(image).image_embeds
image = image.to(device=device, dtype=dtype)
image_embeddings = self.image_encoder(image).image_embeds

image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

@@ -258,13 +249,14 @@ def _execution_device(self):
@torch.no_grad()
def __call__(
self,
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
num_images_per_prompt: int = 1,
decoder_num_inference_steps: int = 25,
super_res_num_inference_steps: int = 7,
generator: Optional[torch.Generator] = None,
decoder_latents: Optional[torch.FloatTensor] = None,
super_res_latents: Optional[torch.FloatTensor] = None,
image_embeddings: Optional[torch.Tensor] = None,
decoder_guidance_scale: float = 8.0,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -277,7 +269,7 @@ def __call__(
The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
configuration of
[this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
`CLIPFeatureExtractor`.
`CLIPFeatureExtractor`. Can be left as `None` only when `image_embeddings` are passed.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
decoder_num_inference_steps (`int`, *optional*, defaults to 25):
@@ -299,18 +291,24 @@ def __call__(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
image_embeddings (`torch.Tensor`, *optional*):
Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
can be passed for tasks like image interpolations. `image` can then be left as `None`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
"""
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(image, list):
batch_size = len(image)
if image is not None:
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(image, list):
batch_size = len(image)
else:
batch_size = image.shape[0]
else:
batch_size = image.shape[0]
batch_size = image_embeddings.shape[0]

prompt = [""] * batch_size

@@ -324,10 +322,9 @@ def __call__(
prompt, device, num_images_per_prompt, do_classifier_free_guidance
)

image_embeddings = self._encode_image(image, device, num_images_per_prompt)
image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings)

# decoder

text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
image_embeddings=image_embeddings,
text_embeddings=text_embeddings,
@@ -343,14 +340,16 @@ def __call__(
num_channels_latents = self.decoder.in_channels
height = self.decoder.sample_size
width = self.decoder.sample_size
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
text_encoder_hidden_states.dtype,
device,
generator,
decoder_latents,
self.decoder_scheduler,
)

if decoder_latents is None:
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
text_encoder_hidden_states.dtype,
device,
generator,
decoder_latents,
self.decoder_scheduler,
)

for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
@@ -395,14 +394,16 @@ def __call__(
channels = self.super_res_first.in_channels // 2
height = self.super_res_first.sample_size
width = self.super_res_first.sample_size
super_res_latents = self.prepare_latents(
(batch_size, channels, height, width),
image_small.dtype,
device,
generator,
super_res_latents,
self.super_res_scheduler,
)

if super_res_latents is None:
super_res_latents = self.prepare_latents(
(batch_size, channels, height, width),
image_small.dtype,
device,
generator,
super_res_latents,
self.super_res_scheduler,
)

interpolate_antialias = {}
if "antialias" in inspect.signature(F.interpolate).parameters:
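Similarly, the new `image_embeddings` argument of `UnCLIPImageVariationPipeline.__call__` allows precomputed image embeddings to be interpolated and decoded directly. A minimal sketch, assuming the `fusing/karlo-image-variations-diffusers` checkpoint referenced in the docstring above; the image file names and the plain `torch.lerp` interpolation are placeholders:

```python
import torch
from PIL import Image
from diffusers import UnCLIPImageVariationPipeline

pipe = UnCLIPImageVariationPipeline.from_pretrained(
    "fusing/karlo-image-variations-diffusers", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

image_a = Image.open("cat.png").convert("RGB")
image_b = Image.open("dog.png").convert("RGB")

# Encode both images with the pipeline's own feature extractor and image encoder.
pixel_values = pipe.feature_extractor(images=[image_a, image_b], return_tensors="pt").pixel_values
pixel_values = pixel_values.to(device=pipe.device, dtype=next(pipe.image_encoder.parameters()).dtype)
image_embeds = pipe.image_encoder(pixel_values).image_embeds  # shape (2, embed_dim)

# Interpolate between the two image embeddings and decode the midpoint.
alpha = 0.5
interp = torch.lerp(image_embeds[0], image_embeds[1], alpha)[None]

# `image` can now be left as None because `image_embeddings` is passed.
out = pipe(image_embeddings=interp)
out.images[0].save("cat_dog_interpolation.png")
```

Sweeping `alpha` from 0 to 1 over several calls yields a sequence of frames that morphs one input image into the other, which is the interpolation use case this PR is meant to enable.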