28
28
import jax .numpy as jnp
29
29
from flax .jax_utils import replicate
30
30
from flax .training .common_utils import shard
31
- from jax import pmap
32
31
33
32
from diffusers import FlaxDDIMScheduler , FlaxDiffusionPipeline , FlaxStableDiffusionPipeline
34
33
@@ -70,14 +69,12 @@ def test_dummy_all_tpus(self):
70
69
prompt = num_samples * [prompt ]
71
70
prompt_ids = pipeline .prepare_inputs (prompt )
72
71
73
- p_sample = pmap (pipeline .__call__ , static_broadcasted_argnums = (3 ,))
74
-
75
72
# shard inputs and rng
76
73
params = replicate (params )
77
74
prng_seed = jax .random .split (prng_seed , num_samples )
78
75
prompt_ids = shard (prompt_ids )
79
76
80
- images = p_sample (prompt_ids , params , prng_seed , num_inference_steps ).images
77
+ images = pipeline (prompt_ids , params , prng_seed , num_inference_steps , jit = True ).images
81
78
82
79
assert images .shape == (num_samples , 1 , 64 , 64 , 3 )
83
80
if jax .device_count () == 8 :
@@ -105,14 +102,12 @@ def test_stable_diffusion_v1_4(self):
105
102
prompt = num_samples * [prompt ]
106
103
prompt_ids = pipeline .prepare_inputs (prompt )
107
104
108
- p_sample = pmap (pipeline .__call__ , static_broadcasted_argnums = (3 ,))
109
-
110
105
# shard inputs and rng
111
106
params = replicate (params )
112
107
prng_seed = jax .random .split (prng_seed , num_samples )
113
108
prompt_ids = shard (prompt_ids )
114
109
115
- images = p_sample (prompt_ids , params , prng_seed , num_inference_steps ).images
110
+ images = pipeline (prompt_ids , params , prng_seed , num_inference_steps , jit = True ).images
116
111
117
112
assert images .shape == (num_samples , 1 , 512 , 512 , 3 )
118
113
if jax .device_count () == 8 :
@@ -136,14 +131,12 @@ def test_stable_diffusion_v1_4_bfloat_16(self):
136
131
prompt = num_samples * [prompt ]
137
132
prompt_ids = pipeline .prepare_inputs (prompt )
138
133
139
- p_sample = pmap (pipeline .__call__ , static_broadcasted_argnums = (3 ,))
140
-
141
134
# shard inputs and rng
142
135
params = replicate (params )
143
136
prng_seed = jax .random .split (prng_seed , num_samples )
144
137
prompt_ids = shard (prompt_ids )
145
138
146
- images = p_sample (prompt_ids , params , prng_seed , num_inference_steps ).images
139
+ images = pipeline (prompt_ids , params , prng_seed , num_inference_steps , jit = True ).images
147
140
148
141
assert images .shape == (num_samples , 1 , 512 , 512 , 3 )
149
142
if jax .device_count () == 8 :
@@ -211,14 +204,12 @@ def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
211
204
prompt = num_samples * [prompt ]
212
205
prompt_ids = pipeline .prepare_inputs (prompt )
213
206
214
- p_sample = pmap (pipeline .__call__ , static_broadcasted_argnums = (3 ,))
215
-
216
207
# shard inputs and rng
217
208
params = replicate (params )
218
209
prng_seed = jax .random .split (prng_seed , num_samples )
219
210
prompt_ids = shard (prompt_ids )
220
211
221
- images = p_sample (prompt_ids , params , prng_seed , num_inference_steps ).images
212
+ images = pipeline (prompt_ids , params , prng_seed , num_inference_steps , jit = True ).images
222
213
223
214
assert images .shape == (num_samples , 1 , 512 , 512 , 3 )
224
215
if jax .device_count () == 8 :
0 commit comments