@@ -165,6 +165,7 @@ def _generate(
165
165
guidance_scale : float = 7.5 ,
166
166
latents : Optional [jnp .array ] = None ,
167
167
debug : bool = False ,
168
+ neg_prompt_ids : jnp .array = None ,
168
169
):
169
170
if height % 8 != 0 or width % 8 != 0 :
170
171
raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } ." )
@@ -177,10 +178,14 @@ def _generate(
177
178
batch_size = prompt_ids .shape [0 ]
178
179
179
180
max_length = prompt_ids .shape [- 1 ]
180
- uncond_input = self .tokenizer (
181
- ["" ] * batch_size , padding = "max_length" , max_length = max_length , return_tensors = "np"
182
- )
183
- uncond_embeddings = self .text_encoder (uncond_input .input_ids , params = params ["text_encoder" ])[0 ]
181
+
182
+ if neg_prompt_ids is None :
183
+ uncond_input = self .tokenizer (
184
+ ["" ] * batch_size , padding = "max_length" , max_length = max_length , return_tensors = "np"
185
+ ).input_ids
186
+ else :
187
+ uncond_input = neg_prompt_ids
188
+ uncond_embeddings = self .text_encoder (uncond_input , params = params ["text_encoder" ])[0 ]
184
189
context = jnp .concatenate ([uncond_embeddings , text_embeddings ])
185
190
186
191
latents_shape = (batch_size , self .unet .in_channels , height // 8 , width // 8 )
@@ -251,6 +256,7 @@ def __call__(
251
256
return_dict : bool = True ,
252
257
jit : bool = False ,
253
258
debug : bool = False ,
259
+ neg_prompt_ids : jnp .array = None ,
254
260
** kwargs ,
255
261
):
256
262
r"""
@@ -298,11 +304,30 @@ def __call__(
298
304
"""
299
305
if jit :
300
306
images = _p_generate (
301
- self , prompt_ids , params , prng_seed , num_inference_steps , height , width , guidance_scale , latents , debug
307
+ self ,
308
+ prompt_ids ,
309
+ params ,
310
+ prng_seed ,
311
+ num_inference_steps ,
312
+ height ,
313
+ width ,
314
+ guidance_scale ,
315
+ latents ,
316
+ debug ,
317
+ neg_prompt_ids ,
302
318
)
303
319
else :
304
320
images = self ._generate (
305
- prompt_ids , params , prng_seed , num_inference_steps , height , width , guidance_scale , latents , debug
321
+ prompt_ids ,
322
+ params ,
323
+ prng_seed ,
324
+ num_inference_steps ,
325
+ height ,
326
+ width ,
327
+ guidance_scale ,
328
+ latents ,
329
+ debug ,
330
+ neg_prompt_ids ,
306
331
)
307
332
308
333
if self .safety_checker is not None :
@@ -333,10 +358,29 @@ def __call__(
333
358
# TODO: maybe use a config dict instead of so many static argnums
334
359
@partial (jax .pmap , static_broadcasted_argnums = (0 , 4 , 5 , 6 , 7 , 9 ))
335
360
def _p_generate (
336
- pipe , prompt_ids , params , prng_seed , num_inference_steps , height , width , guidance_scale , latents , debug
361
+ pipe ,
362
+ prompt_ids ,
363
+ params ,
364
+ prng_seed ,
365
+ num_inference_steps ,
366
+ height ,
367
+ width ,
368
+ guidance_scale ,
369
+ latents ,
370
+ debug ,
371
+ neg_prompt_ids ,
337
372
):
338
373
return pipe ._generate (
339
- prompt_ids , params , prng_seed , num_inference_steps , height , width , guidance_scale , latents , debug
374
+ prompt_ids ,
375
+ params ,
376
+ prng_seed ,
377
+ num_inference_steps ,
378
+ height ,
379
+ width ,
380
+ guidance_scale ,
381
+ latents ,
382
+ debug ,
383
+ neg_prompt_ids ,
340
384
)
341
385
342
386
0 commit comments