
Commit b28ab30

[Unclip] Make sure text_embeddings & image_embeddings can directly be passed to enable interpolation tasks. (#1858)
* [Unclip] Make sure latents can be reused
* allow one to directly pass embeddings
* up
* make unclip for text work
* finish allowing to pass embeddings
* correct more
* make style
1 parent 29b2c93 commit b28ab30
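
The first bullet above, latent reuse, means the noise tensors consumed by the prior, decoder, and super-resolution stages can be drawn once and passed explicitly on every call. A minimal sketch of that usage, assuming the kakaobrain/karlo-v1-alpha checkpoint (an assumption, not named in this commit), with latent shapes read from the pipeline configs the same way the prepare_latents calls in the diff below read them:

import torch
from diffusers import UnCLIPPipeline

# Sketch only: checkpoint name and fp16/CUDA setup are assumptions.
pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16).to("cuda")

batch_size = 1
dtype = torch.float16
generator = torch.Generator(device="cuda").manual_seed(0)

# Shapes mirror the prepare_latents calls in pipeline_unclip.py shown below.
prior_latents = torch.randn(
    (batch_size, pipe.prior.config.embedding_dim), generator=generator, device="cuda", dtype=dtype
)
decoder_latents = torch.randn(
    (batch_size, pipe.decoder.in_channels, pipe.decoder.sample_size, pipe.decoder.sample_size),
    generator=generator, device="cuda", dtype=dtype,
)
super_res_latents = torch.randn(
    (batch_size, pipe.super_res_first.in_channels // 2,
     pipe.super_res_first.sample_size, pipe.super_res_first.sample_size),
    generator=generator, device="cuda", dtype=dtype,
)

# The same starting noise can now be reused across calls, so only the conditioning changes.
image_a = pipe("a photo of a red fox", prior_latents=prior_latents,
               decoder_latents=decoder_latents, super_res_latents=super_res_latents).images[0]
image_b = pipe("a photo of a grey wolf", prior_latents=prior_latents,
               decoder_latents=decoder_latents, super_res_latents=super_res_latents).images[0]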

File tree

4 files changed: +270 -79 lines changed


src/diffusers/pipelines/unclip/pipeline_unclip.py

Lines changed: 63 additions & 33 deletions
@@ -13,12 +13,13 @@
 # limitations under the License.
 
 import inspect
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch.nn import functional as F
 
 from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+from transformers.models.clip.modeling_clip import CLIPTextModelOutput
 
 from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
@@ -117,31 +118,44 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         latents = latents * scheduler.init_noise_sigma
         return latents
 
-    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-
-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        text_mask = text_inputs.attention_mask.bool().to(device)
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+        text_attention_mask: Optional[torch.Tensor] = None,
+    ):
+        if text_model_output is None:
+            batch_size = len(prompt) if isinstance(prompt, list) else 1
+            # get prompt text embeddings
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                return_tensors="pt",
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+            text_input_ids = text_inputs.input_ids
+            text_mask = text_inputs.attention_mask.bool().to(device)
+
+            if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+                removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
 
-        text_encoder_output = self.text_encoder(text_input_ids.to(device))
+            text_encoder_output = self.text_encoder(text_input_ids.to(device))
 
-        text_embeddings = text_encoder_output.text_embeds
-        text_encoder_hidden_states = text_encoder_output.last_hidden_state
+            text_embeddings = text_encoder_output.text_embeds
+            text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+        else:
+            batch_size = text_model_output[0].shape[0]
+            text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+            text_mask = text_attention_mask
 
         text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
         text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
@@ -150,11 +164,10 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
         if do_classifier_free_guidance:
             uncond_tokens = [""] * batch_size
 
-            max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
                 padding="max_length",
-                max_length=max_length,
+                max_length=self.tokenizer.model_max_length,
                 truncation=True,
                 return_tensors="pt",
             )
@@ -235,7 +248,7 @@ def _execution_device(self):
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: int = 1,
         prior_num_inference_steps: int = 25,
         decoder_num_inference_steps: int = 25,
@@ -244,6 +257,8 @@ def __call__(
         prior_latents: Optional[torch.FloatTensor] = None,
         decoder_latents: Optional[torch.FloatTensor] = None,
         super_res_latents: Optional[torch.FloatTensor] = None,
+        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+        text_attention_mask: Optional[torch.Tensor] = None,
         prior_guidance_scale: float = 4.0,
         decoder_guidance_scale: float = 8.0,
         output_type: Optional[str] = "pil",
@@ -254,7 +269,8 @@
 
         Args:
             prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation.
+                The prompt or prompts to guide the image generation. This can only be left undefined if
+                `text_model_output` and `text_attention_mask` is passed.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             prior_num_inference_steps (`int`, *optional*, defaults to 25):
@@ -287,26 +303,37 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            text_model_output (`CLIPTextModelOutput`, *optional*):
+                Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs
+                can be passed for tasks like text embedding interpolations. Make sure to also pass
+                `text_attention_mask` in this case. `prompt` can the be left to `None`.
+            text_attention_mask (`torch.Tensor`, *optional*):
+                Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
+                masks are necessary when passing `text_model_output`.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
         """
-        if isinstance(prompt, str):
-            batch_size = 1
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
+        if prompt is not None:
+            if isinstance(prompt, str):
+                batch_size = 1
+            elif isinstance(prompt, list):
+                batch_size = len(prompt)
+            else:
+                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
         else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+            batch_size = text_model_output[0].shape[0]
+
         device = self._execution_device
 
         batch_size = batch_size * num_images_per_prompt
 
         do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
 
         text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
-            prompt, device, num_images_per_prompt, do_classifier_free_guidance
+            prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
         )
 
         # prior
@@ -315,6 +342,7 @@ def __call__(
         prior_timesteps_tensor = self.prior_scheduler.timesteps
 
         embedding_dim = self.prior.config.embedding_dim
+
         prior_latents = self.prepare_latents(
             (batch_size, embedding_dim),
             text_embeddings.dtype,
@@ -378,6 +406,7 @@ def __call__(
         num_channels_latents = self.decoder.in_channels
         height = self.decoder.sample_size
         width = self.decoder.sample_size
+
         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
             text_encoder_hidden_states.dtype,
@@ -430,6 +459,7 @@ def __call__(
         channels = self.super_res_first.in_channels // 2
         height = self.super_res_first.sample_size
         width = self.super_res_first.sample_size
+
         super_res_latents = self.prepare_latents(
            (batch_size, channels, height, width),
            image_small.dtype,

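With the `_encode_prompt` and `__call__` changes above, pre-computed CLIP text outputs can be fed straight into `UnCLIPPipeline` via `text_model_output` and `text_attention_mask`. A hedged sketch of the text-embedding interpolation this enables; the checkpoint name, the `encode` helper, and the plain linear blend are illustrative assumptions rather than code from this commit:

import torch
from diffusers import UnCLIPPipeline

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16).to("cuda")

def encode(prompt):
    # Hypothetical helper mirroring the tokenization done inside _encode_prompt above.
    inputs = pipe.tokenizer(
        prompt, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt"
    )
    mask = inputs.attention_mask.bool().to("cuda")
    outputs = pipe.text_encoder(inputs.input_ids.to("cuda"))
    return outputs.text_embeds, outputs.last_hidden_state, mask

embeds_a, hidden_a, mask_a = encode("a photo of a cat")
embeds_b, hidden_b, mask_b = encode("a photo of a dog")

for alpha in (0.0, 0.25, 0.5, 0.75, 1.0):
    # Plain linear blend of the projected embeddings and the hidden states.
    text_embeds = (1 - alpha) * embeds_a + alpha * embeds_b
    hidden_states = (1 - alpha) * hidden_a + alpha * hidden_b
    image = pipe(
        prompt=None,
        text_model_output=(text_embeds, hidden_states),
        text_attention_mask=mask_a,  # both prompts are padded to the tokenizer's max length
    ).images[0]
    image.save(f"text_interp_{alpha:.2f}.png")

Spherical interpolation (slerp) of the embeddings is a common alternative to the linear blend used here.
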
src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py

Lines changed: 41 additions & 40 deletions
@@ -126,7 +126,6 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         latents = latents * scheduler.init_noise_sigma
         return latents
 
-    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt
     def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
         batch_size = len(prompt) if isinstance(prompt, list) else 1
 
@@ -139,15 +138,6 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
         )
         text_input_ids = text_inputs.input_ids
         text_mask = text_inputs.attention_mask.bool().to(device)
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-
         text_encoder_output = self.text_encoder(text_input_ids.to(device))
 
         text_embeddings = text_encoder_output.text_embeds
@@ -199,14 +189,15 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
 
         return text_embeddings, text_encoder_hidden_states, text_mask
 
-    def _encode_image(self, image, device, num_images_per_prompt):
+    def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
         dtype = next(self.image_encoder.parameters()).dtype
 
-        if not isinstance(image, torch.Tensor):
-            image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+        if image_embeddings is None:
+            if not isinstance(image, torch.Tensor):
+                image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
 
-        image = image.to(device=device, dtype=dtype)
-        image_embeddings = self.image_encoder(image).image_embeds
+            image = image.to(device=device, dtype=dtype)
+            image_embeddings = self.image_encoder(image).image_embeds
 
         image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
 
@@ -258,13 +249,14 @@ def _execution_device(self):
     @torch.no_grad()
     def __call__(
         self,
-        image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
+        image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
         num_images_per_prompt: int = 1,
         decoder_num_inference_steps: int = 25,
         super_res_num_inference_steps: int = 7,
         generator: Optional[torch.Generator] = None,
         decoder_latents: Optional[torch.FloatTensor] = None,
         super_res_latents: Optional[torch.FloatTensor] = None,
+        image_embeddings: Optional[torch.Tensor] = None,
         decoder_guidance_scale: float = 8.0,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -277,7 +269,7 @@ def __call__(
                 The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
                 configuration of
                 [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
-                `CLIPFeatureExtractor`.
+                `CLIPFeatureExtractor`. Can be left to `None` only when `image_embeddings` are passed.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             decoder_num_inference_steps (`int`, *optional*, defaults to 25):
@@ -299,18 +291,24 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            image_embeddings (`torch.Tensor`, *optional*):
+                Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
+                can be passed for tasks like image interpolations. `image` can the be left to `None`.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
         """
-        if isinstance(image, PIL.Image.Image):
-            batch_size = 1
-        elif isinstance(image, list):
-            batch_size = len(image)
+        if image is not None:
+            if isinstance(image, PIL.Image.Image):
+                batch_size = 1
+            elif isinstance(image, list):
+                batch_size = len(image)
+            else:
+                batch_size = image.shape[0]
         else:
-            batch_size = image.shape[0]
+            batch_size = image_embeddings.shape[0]
 
         prompt = [""] * batch_size
 
@@ -324,10 +322,9 @@ def __call__(
             prompt, device, num_images_per_prompt, do_classifier_free_guidance
        )
 
-        image_embeddings = self._encode_image(image, device, num_images_per_prompt)
+        image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings)
 
         # decoder
-
         text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
             image_embeddings=image_embeddings,
             text_embeddings=text_embeddings,
@@ -343,14 +340,16 @@ def __call__(
         num_channels_latents = self.decoder.in_channels
         height = self.decoder.sample_size
         width = self.decoder.sample_size
-        decoder_latents = self.prepare_latents(
-            (batch_size, num_channels_latents, height, width),
-            text_encoder_hidden_states.dtype,
-            device,
-            generator,
-            decoder_latents,
-            self.decoder_scheduler,
-        )
+
+        if decoder_latents is None:
+            decoder_latents = self.prepare_latents(
+                (batch_size, num_channels_latents, height, width),
+                text_encoder_hidden_states.dtype,
+                device,
+                generator,
+                decoder_latents,
+                self.decoder_scheduler,
+            )
 
         for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
@@ -395,14 +394,16 @@ def __call__(
         channels = self.super_res_first.in_channels // 2
         height = self.super_res_first.sample_size
         width = self.super_res_first.sample_size
-        super_res_latents = self.prepare_latents(
-            (batch_size, channels, height, width),
-            image_small.dtype,
-            device,
-            generator,
-            super_res_latents,
-            self.super_res_scheduler,
-        )
+
+        if super_res_latents is None:
+            super_res_latents = self.prepare_latents(
+                (batch_size, channels, height, width),
+                image_small.dtype,
+                device,
+                generator,
+                super_res_latents,
+                self.super_res_scheduler,
+            )
 
         interpolate_antialias = {}
         if "antialias" in inspect.signature(F.interpolate).parameters:
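The image-variation changes enable the same workflow for images: CLIP image embeddings computed once with the pipeline's own feature extractor and image encoder (exactly as `_encode_image` does above) can be blended and handed back through the new `image_embeddings` argument. A hedged sketch, assuming the fusing/karlo-image-variations-diffusers checkpoint referenced in the docstring above; the `embed` helper, the file names, and the linear blend are illustrative:

import torch
from PIL import Image
from diffusers import UnCLIPImageVariationPipeline

pipe = UnCLIPImageVariationPipeline.from_pretrained(
    "fusing/karlo-image-variations-diffusers", torch_dtype=torch.float16
).to("cuda")

def embed(path):
    # Hypothetical helper mirroring what _encode_image does when no embeddings are passed.
    pixel_values = pipe.feature_extractor(images=Image.open(path).convert("RGB"), return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device="cuda", dtype=torch.float16)
    return pipe.image_encoder(pixel_values).image_embeds

emb_a = embed("cat.png")
emb_b = embed("dog.png")

for alpha in (0.0, 0.5, 1.0):
    blended = (1 - alpha) * emb_a + alpha * emb_b
    image = pipe(image=None, image_embeddings=blended).images[0]
    image.save(f"image_interp_{alpha:.1f}.png")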