Commit 7bbbfbf

Jax infer support negative prompt (#1337)
* support negative prompts in sd jax pipeline
* pass batched neg_prompt
* only encode when negative prompt is None

Co-authored-by: Juan Acevedo <[email protected]>
1 parent 3022090 commit 7bbbfbf

1 file changed

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 52 additions & 8 deletions
@@ -165,6 +165,7 @@ def _generate(
         guidance_scale: float = 7.5,
         latents: Optional[jnp.array] = None,
         debug: bool = False,
+        neg_prompt_ids: jnp.array = None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -177,10 +178,14 @@ def _generate(
         batch_size = prompt_ids.shape[0]

         max_length = prompt_ids.shape[-1]
-        uncond_input = self.tokenizer(
-            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
-        )
-        uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
+
+        if neg_prompt_ids is None:
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+            ).input_ids
+        else:
+            uncond_input = neg_prompt_ids
+        uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
         context = jnp.concatenate([uncond_embeddings, text_embeddings])

         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
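
With this change, `_generate` only tokenizes an empty string when `neg_prompt_ids` is None; otherwise the caller-supplied ids are encoded directly. Since the unconditional and text embeddings are concatenated for classifier-free guidance, the negative prompt has to be tokenized to the same padded length as the positive prompt. A minimal caller-side sketch, not part of this commit, assuming an already-instantiated FlaxStableDiffusionPipeline bound to `pipe` and made-up example prompt text:

# Hypothetical preparation of neg_prompt_ids on the caller side.
prompts = ["a photo of an astronaut riding a horse"]  # assumed example text
negative_prompts = ["blurry, low quality"]            # assumed example text

prompt_ids = pipe.tokenizer(
    prompts, padding="max_length", max_length=pipe.tokenizer.model_max_length, truncation=True, return_tensors="np"
).input_ids
neg_prompt_ids = pipe.tokenizer(
    negative_prompts,
    padding="max_length",
    max_length=prompt_ids.shape[-1],  # match the positive prompt's padded length
    truncation=True,
    return_tensors="np",
).input_ids
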
@@ -251,6 +256,7 @@ def __call__(
         return_dict: bool = True,
         jit: bool = False,
         debug: bool = False,
+        neg_prompt_ids: jnp.array = None,
         **kwargs,
     ):
         r"""
@@ -298,11 +304,30 @@ def __call__(
         """
         if jit:
             images = _p_generate(
-                self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+                self,
+                prompt_ids,
+                params,
+                prng_seed,
+                num_inference_steps,
+                height,
+                width,
+                guidance_scale,
+                latents,
+                debug,
+                neg_prompt_ids,
             )
         else:
             images = self._generate(
-                prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+                prompt_ids,
+                params,
+                prng_seed,
+                num_inference_steps,
+                height,
+                width,
+                guidance_scale,
+                latents,
+                debug,
+                neg_prompt_ids,
             )

         if self.safety_checker is not None:
@@ -333,10 +358,29 @@ def __call__(
 # TODO: maybe use a config dict instead of so many static argnums
 @partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
 def _p_generate(
-    pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+    pipe,
+    prompt_ids,
+    params,
+    prng_seed,
+    num_inference_steps,
+    height,
+    width,
+    guidance_scale,
+    latents,
+    debug,
+    neg_prompt_ids,
 ):
     return pipe._generate(
-        prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+        prompt_ids,
+        params,
+        prng_seed,
+        num_inference_steps,
+        height,
+        width,
+        guidance_scale,
+        latents,
+        debug,
+        neg_prompt_ids,
     )