Skip to content

Commit e3095c5

Browse files
authored
Fix invocation of some slow Flax tests (#3058)
* Fix invocation of some slow tests. We use __call__ rather than pmapping the generation function ourselves because the number of static arguments is different now. * style
1 parent 526827c commit e3095c5

File tree

1 file changed

+4
-13
lines changed

1 file changed

+4
-13
lines changed

tests/test_pipelines_flax.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import jax.numpy as jnp
2929
from flax.jax_utils import replicate
3030
from flax.training.common_utils import shard
31-
from jax import pmap
3231

3332
from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
3433

@@ -70,14 +69,12 @@ def test_dummy_all_tpus(self):
7069
prompt = num_samples * [prompt]
7170
prompt_ids = pipeline.prepare_inputs(prompt)
7271

73-
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
74-
7572
# shard inputs and rng
7673
params = replicate(params)
7774
prng_seed = jax.random.split(prng_seed, num_samples)
7875
prompt_ids = shard(prompt_ids)
7976

80-
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
77+
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
8178

8279
assert images.shape == (num_samples, 1, 64, 64, 3)
8380
if jax.device_count() == 8:
@@ -105,14 +102,12 @@ def test_stable_diffusion_v1_4(self):
105102
prompt = num_samples * [prompt]
106103
prompt_ids = pipeline.prepare_inputs(prompt)
107104

108-
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
109-
110105
# shard inputs and rng
111106
params = replicate(params)
112107
prng_seed = jax.random.split(prng_seed, num_samples)
113108
prompt_ids = shard(prompt_ids)
114109

115-
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
110+
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
116111

117112
assert images.shape == (num_samples, 1, 512, 512, 3)
118113
if jax.device_count() == 8:
@@ -136,14 +131,12 @@ def test_stable_diffusion_v1_4_bfloat_16(self):
136131
prompt = num_samples * [prompt]
137132
prompt_ids = pipeline.prepare_inputs(prompt)
138133

139-
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
140-
141134
# shard inputs and rng
142135
params = replicate(params)
143136
prng_seed = jax.random.split(prng_seed, num_samples)
144137
prompt_ids = shard(prompt_ids)
145138

146-
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
139+
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
147140

148141
assert images.shape == (num_samples, 1, 512, 512, 3)
149142
if jax.device_count() == 8:
@@ -211,14 +204,12 @@ def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
211204
prompt = num_samples * [prompt]
212205
prompt_ids = pipeline.prepare_inputs(prompt)
213206

214-
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
215-
216207
# shard inputs and rng
217208
params = replicate(params)
218209
prng_seed = jax.random.split(prng_seed, num_samples)
219210
prompt_ids = shard(prompt_ids)
220211

221-
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
212+
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
222213

223214
assert images.shape == (num_samples, 1, 512, 512, 3)
224215
if jax.device_count() == 8:

0 commit comments

Comments
 (0)