Skip to content

Commit 78289c9

Browse files
committed
added vmap test
1 parent c994a60 commit 78289c9

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

test/test_saveat_solution.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,44 @@ def test_t0_eq_t1(subs):
183183
assert tree_allclose(sol.ys, compare)
184184

185185

186+
@pytest.mark.parametrize("subs", [True, False])
187+
def test_vmap_t0_eq_t1(subs):
188+
ntsave = 4
189+
y0 = jnp.array([2.0])
190+
term = diffrax.ODETerm(lambda t, y, args: y)
191+
192+
def _solve(tf):
193+
ts = jnp.linspace(0.0, tf, ntsave)
194+
get0 = diffrax.SubSaveAt(
195+
ts=ts,
196+
t1=True,
197+
)
198+
get1 = diffrax.SubSaveAt(
199+
t0=True,
200+
ts=ts,
201+
)
202+
subs = (get0, get1)
203+
saveat = diffrax.SaveAt(subs=subs)
204+
return diffrax.diffeqsolve(
205+
term,
206+
t0=ts[0],
207+
t1=ts[-1],
208+
y0=y0,
209+
dt0=0.1,
210+
solver=diffrax.Dopri5(),
211+
saveat=saveat,
212+
)
213+
214+
compare = jnp.full((ntsave + 1, *y0.shape), y0)
215+
sol = jax.vmap(_solve)(jnp.array([0.0, 1.0]))
216+
assert tree_allclose(sol.ys[0][0], compare) # pyright: ignore
217+
assert tree_allclose(sol.ys[1][0], compare) # pyright: ignore
218+
219+
regular_solve = _solve(1.0)
220+
assert tree_allclose(sol.ys[0][1], regular_solve.ys[0]) # pyright: ignore
221+
assert tree_allclose(sol.ys[1][1], regular_solve.ys[1]) # pyright: ignore
222+
223+
186224
def test_trivial_dense():
187225
term = diffrax.ODETerm(lambda t, y, args: -0.5 * y)
188226
y0 = jnp.array([2.1])

0 commit comments

Comments
 (0)