@@ -183,6 +183,44 @@ def test_t0_eq_t1(subs):
183
183
assert tree_allclose (sol .ys , compare )
184
184
185
185
186
+ @pytest .mark .parametrize ("subs" , [True , False ])
187
+ def test_vmap_t0_eq_t1 (subs ):
188
+ ntsave = 4
189
+ y0 = jnp .array ([2.0 ])
190
+ term = diffrax .ODETerm (lambda t , y , args : y )
191
+
192
+ def _solve (tf ):
193
+ ts = jnp .linspace (0.0 , tf , ntsave )
194
+ get0 = diffrax .SubSaveAt (
195
+ ts = ts ,
196
+ t1 = True ,
197
+ )
198
+ get1 = diffrax .SubSaveAt (
199
+ t0 = True ,
200
+ ts = ts ,
201
+ )
202
+ subs = (get0 , get1 )
203
+ saveat = diffrax .SaveAt (subs = subs )
204
+ return diffrax .diffeqsolve (
205
+ term ,
206
+ t0 = ts [0 ],
207
+ t1 = ts [- 1 ],
208
+ y0 = y0 ,
209
+ dt0 = 0.1 ,
210
+ solver = diffrax .Dopri5 (),
211
+ saveat = saveat ,
212
+ )
213
+
214
+ compare = jnp .full ((ntsave + 1 , * y0 .shape ), y0 )
215
+ sol = jax .vmap (_solve )(jnp .array ([0.0 , 1.0 ]))
216
+ assert tree_allclose (sol .ys [0 ][0 ], compare ) # pyright: ignore
217
+ assert tree_allclose (sol .ys [1 ][0 ], compare ) # pyright: ignore
218
+
219
+ regular_solve = _solve (1.0 )
220
+ assert tree_allclose (sol .ys [0 ][1 ], regular_solve .ys [0 ]) # pyright: ignore
221
+ assert tree_allclose (sol .ys [1 ][1 ], regular_solve .ys [1 ]) # pyright: ignore
222
+
223
+
186
224
def test_trivial_dense ():
187
225
term = diffrax .ODETerm (lambda t , y , args : - 0.5 * y )
188
226
y0 = jnp .array ([2.1 ])
0 commit comments