 # limitations under the License.
 
 import inspect
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch.nn import functional as F
 
 from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+from transformers.models.clip.modeling_clip import CLIPTextModelOutput
 
 from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
@@ -117,31 +118,44 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         latents = latents * scheduler.init_noise_sigma
         return latents
 
-    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-
-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        text_mask = text_inputs.attention_mask.bool().to(device)
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+        text_attention_mask: Optional[torch.Tensor] = None,
+    ):
+        if text_model_output is None:
+            batch_size = len(prompt) if isinstance(prompt, list) else 1
+            # get prompt text embeddings
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                return_tensors="pt",
             )
-        text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+            text_input_ids = text_inputs.input_ids
+            text_mask = text_inputs.attention_mask.bool().to(device)
+
+            if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+                removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
 
-        text_encoder_output = self.text_encoder(text_input_ids.to(device))
+            text_encoder_output = self.text_encoder(text_input_ids.to(device))
 
-        text_embeddings = text_encoder_output.text_embeds
-        text_encoder_hidden_states = text_encoder_output.last_hidden_state
+            text_embeddings = text_encoder_output.text_embeds
+            text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+        else:
+            batch_size = text_model_output[0].shape[0]
+            text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+            text_mask = text_attention_mask
 
         text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
         text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
@@ -150,11 +164,10 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
         if do_classifier_free_guidance:
             uncond_tokens = [""] * batch_size
 
-            max_length = text_input_ids.shape[-1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
                 padding="max_length",
-                max_length=max_length,
+                max_length=self.tokenizer.model_max_length,
                 truncation=True,
                 return_tensors="pt",
             )
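
With this change, `_encode_prompt` can skip tokenization and encoding entirely when precomputed text outputs are supplied. As a rough illustration (not part of the diff), the sketch below shows how such inputs could be produced with the same `tokenizer` and `text_encoder` the method uses; `pipe` and the example prompt are assumed placeholders for an already-loaded pipeline instance.

import torch

# Hypothetical sketch: build the `text_model_output` / `text_attention_mask` pair
# that the new branch of `_encode_prompt` consumes. `pipe` stands in for a loaded
# pipeline exposing the same CLIPTokenizer / CLIPTextModelWithProjection as above.
text_inputs = pipe.tokenizer(
    ["a photo of an astronaut riding a horse"],
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    # CLIPTextModelWithProjection returns a CLIPTextModelOutput whose first two
    # fields are `text_embeds` and `last_hidden_state`, which is why the diff can
    # index `text_model_output[0]` and `text_model_output[1]`.
    text_model_output = pipe.text_encoder(text_inputs.input_ids.to(pipe.device))

# The mask is passed through unchanged by the new branch, so cast it to bool up
# front, mirroring what the prompt branch does with `attention_mask.bool()`.
text_attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
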
@@ -235,7 +248,7 @@ def _execution_device(self):
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: int = 1,
         prior_num_inference_steps: int = 25,
         decoder_num_inference_steps: int = 25,
@@ -244,6 +257,8 @@ def __call__(
         prior_latents: Optional[torch.FloatTensor] = None,
         decoder_latents: Optional[torch.FloatTensor] = None,
         super_res_latents: Optional[torch.FloatTensor] = None,
+        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+        text_attention_mask: Optional[torch.Tensor] = None,
         prior_guidance_scale: float = 4.0,
         decoder_guidance_scale: float = 8.0,
         output_type: Optional[str] = "pil",
@@ -254,7 +269,8 @@ def __call__(
254
269
255
270
Args:
256
271
prompt (`str` or `List[str]`):
257
- The prompt or prompts to guide the image generation.
272
+ The prompt or prompts to guide the image generation. This can only be left undefined if
273
+ `text_model_output` and `text_attention_mask` is passed.
258
274
num_images_per_prompt (`int`, *optional*, defaults to 1):
259
275
The number of images to generate per prompt.
260
276
prior_num_inference_steps (`int`, *optional*, defaults to 25):
@@ -287,26 +303,37 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            text_model_output (`CLIPTextModelOutput`, *optional*):
+                Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs
+                can be passed for tasks like text embedding interpolations. Make sure to also pass
+                `text_attention_mask` in this case. `prompt` can then be left as `None`.
+            text_attention_mask (`torch.Tensor`, *optional*):
+                Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
+                masks are necessary when passing `text_model_output`.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
         """
-        if isinstance(prompt, str):
-            batch_size = 1
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
+        if prompt is not None:
+            if isinstance(prompt, str):
+                batch_size = 1
+            elif isinstance(prompt, list):
+                batch_size = len(prompt)
+            else:
+                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
         else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+            batch_size = text_model_output[0].shape[0]
+
         device = self._execution_device
 
         batch_size = batch_size * num_images_per_prompt
 
         do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
 
         text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
-            prompt, device, num_images_per_prompt, do_classifier_free_guidance
+            prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
         )
 
         # prior
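
The docstring above names text embedding interpolation as the motivating use case for `text_model_output`. The following is a minimal, illustrative sketch of that workflow, not code from the diff: `pipe`, the two prompts, the blend weight `t`, and the `encode` helper are all assumptions, and a plain tuple is used for `text_model_output`, which the `[0]` / `[1]` indexing in `_encode_prompt` permits.

import torch

def encode(prompt):
    # Hypothetical helper: encode one prompt with the pipeline's own tokenizer and
    # text encoder, returning (text_embeds, last_hidden_state, bool attention mask).
    inputs = pipe.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        out = pipe.text_encoder(inputs.input_ids.to(pipe.device))
    return out.text_embeds, out.last_hidden_state, inputs.attention_mask.bool().to(pipe.device)

embeds_a, hidden_a, mask_a = encode("a photo of a cat")
embeds_b, hidden_b, mask_b = encode("a photo of a dog")

t = 0.5  # blend weight between the two prompts
text_model_output = (
    torch.lerp(embeds_a, embeds_b, t),  # pooled / projected text embeddings
    torch.lerp(hidden_a, hidden_b, t),  # per-token hidden states
)

# For simplicity this sketch reuses one prompt's attention mask for the blend.
images = pipe(
    prompt=None,
    text_model_output=text_model_output,
    text_attention_mask=mask_a,
).images
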
@@ -315,6 +342,7 @@ def __call__(
         prior_timesteps_tensor = self.prior_scheduler.timesteps
 
         embedding_dim = self.prior.config.embedding_dim
+
         prior_latents = self.prepare_latents(
             (batch_size, embedding_dim),
             text_embeddings.dtype,
@@ -378,6 +406,7 @@ def __call__(
         num_channels_latents = self.decoder.in_channels
         height = self.decoder.sample_size
         width = self.decoder.sample_size
+
         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
             text_encoder_hidden_states.dtype,
@@ -430,6 +459,7 @@ def __call__(
         channels = self.super_res_first.in_channels // 2
         height = self.super_res_first.sample_size
         width = self.super_res_first.sample_size
+
         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
             image_small.dtype,