Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
fa7417f
Langevin PR (#453)
andyElking Sep 1, 2024
0f809d0
tidy-ups
patrick-kidger Sep 14, 2024
d80ff7a
Split SDE tests in half, to try and avoid GitHub runner issues?
patrick-kidger Oct 21, 2024
cbf04ff
Added effects_barrier to fix test issue with JAX 0.4.33+
patrick-kidger Oct 21, 2024
2e7836d
small fix of docs in all three and a return type in quicsort
andyElking Oct 18, 2024
8b34a1c
bump doc building pipeline
patrick-kidger Nov 2, 2024
161f2a6
Compatibility with JAX 0.4.36, which removes ConcreteArray
patrick-kidger Nov 17, 2024
c88305d
using a fori_loop to save states in edge case t0==t1
dkweiss31 Aug 18, 2024
c96ee56
added case for saving t0 data, which was also not getting updated.
dkweiss31 Aug 18, 2024
439887c
using while_loop, ran into issues with reverse-mode diff using the fo…
dkweiss31 Aug 21, 2024
dc0dba4
bug fix for cases when t0=True
dkweiss31 Nov 13, 2024
c994a60
simplified logic for saving, no loop necessary
dkweiss31 Nov 27, 2024
78289c9
added vmap test
dkweiss31 Nov 27, 2024
666948c
using a fori_loop to save states in edge case t0==t1
dkweiss31 Aug 18, 2024
1bd4e08
added case for saving t0 data, which was also not getting updated.
dkweiss31 Aug 18, 2024
f23456e
using while_loop, ran into issues with reverse-mode diff using the fo…
dkweiss31 Aug 21, 2024
0e00411
bug fix for cases when t0=True
dkweiss31 Nov 13, 2024
065fe11
simplified logic for saving, no loop necessary
dkweiss31 Nov 27, 2024
47880be
added vmap test
dkweiss31 Nov 27, 2024
bb292b7
Merge remote-tracking branch 'origin/save_fix' into save_fix
dkweiss31 Dec 5, 2024
c8ca285
fix t1 out of bounds issue
dkweiss31 Dec 5, 2024
059f72b
fix for steps: don't want to update those values if t0==t1 since we d…
dkweiss31 Dec 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
mkdocs build # twice, see https://github.com/patrick-kidger/pytkdocs_tweaks

- name: Upload docs
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v4
with:
name: docs
path: site # where `mkdocs build` puts the built site
6 changes: 6 additions & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
AbstractDIRK as AbstractDIRK,
AbstractERK as AbstractERK,
AbstractESDIRK as AbstractESDIRK,
AbstractFosterLangevinSRK as AbstractFosterLangevinSRK,
AbstractImplicitSolver as AbstractImplicitSolver,
AbstractItoSolver as AbstractItoSolver,
AbstractRungeKutta as AbstractRungeKutta,
Expand All @@ -79,6 +80,7 @@
AbstractSRK as AbstractSRK,
AbstractStratonovichSolver as AbstractStratonovichSolver,
AbstractWrappedSolver as AbstractWrappedSolver,
ALIGN as ALIGN,
Bosh3 as Bosh3,
ButcherTableau as ButcherTableau,
CalculateJacobian as CalculateJacobian,
Expand All @@ -100,11 +102,13 @@
LeapfrogMidpoint as LeapfrogMidpoint,
Midpoint as Midpoint,
MultiButcherTableau as MultiButcherTableau,
QUICSORT as QUICSORT,
Ralston as Ralston,
ReversibleHeun as ReversibleHeun,
SEA as SEA,
SemiImplicitEuler as SemiImplicitEuler,
ShARK as ShARK,
ShOULD as ShOULD,
Sil3 as Sil3,
SlowRK as SlowRK,
SPaRK as SPaRK,
Expand All @@ -125,6 +129,8 @@
ControlTerm as ControlTerm,
MultiTerm as MultiTerm,
ODETerm as ODETerm,
UnderdampedLangevinDiffusionTerm as UnderdampedLangevinDiffusionTerm,
UnderdampedLangevinDriftTerm as UnderdampedLangevinDriftTerm,
WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm,
)

Expand Down
51 changes: 46 additions & 5 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import jax.numpy as jnp
import jax.tree_util as jtu
import lineax.internal as lxi
import numpy as np
import optimistix as optx
from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real

Expand Down Expand Up @@ -258,12 +259,10 @@ def _maybe_static(static_x: Optional[ArrayLike], x: ArrayLike) -> ArrayLike:
# Some values (made_jump and result) are not used in many common use-cases. If we
# detect that they're unused then we make sure they're non-Array Python values, so
# that we can special case on them at trace time and get a performance boost.
if isinstance(static_x, (bool, int, float, complex)):
if isinstance(static_x, (bool, int, float, complex, np.ndarray)):
return static_x
elif static_x is None:
return x
elif type(jax.core.get_aval(static_x)) is jax.core.ConcreteArray:
return static_x
else:
return x

Expand Down Expand Up @@ -776,9 +775,51 @@ def _save_t1(subsaveat, save_state):
save_state = _save(tfinal, yfinal, args, subsaveat.fn, save_state)
return save_state

save_state = jtu.tree_map(
_save_t1, saveat.subs, final_state.save_state, is_leaf=_is_subsaveat
def _save_ts_impl(ts, fn, _save_state):
def _cond_fun(__save_state):
return __save_state.saveat_ts_index < len(_save_state.ts)

def _body_fun(__save_state):
idx = __save_state.save_index
ts = __save_state.ts.at[idx].set(t0)
ys = jtu.tree_map(
lambda _y, _ys: _ys.at[idx].set(_y),
fn(t0, yfinal, args),
__save_state.ys,
)
return SaveState(
saveat_ts_index=idx + 1,
ts=ts,
ys=ys,
save_index=idx + 1,
)

return inner_while_loop(
_cond_fun,
_body_fun,
_save_state,
max_steps=len(ts),
buffers=_inner_buffers,
checkpoints=len(ts),
)

def _save_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.ts is not None:
save_state = _save_ts_impl(subsaveat.ts, subsaveat.fn, save_state)
return save_state

# if t0 == t1 then we don't enter the integration loop. In this case we have to
# manually update the saved ts and ys if we want to save at "intermediate"
# times specified by saveat.subs.ts
save_state = jax.lax.cond(
t0 == t1,
lambda _save_state: jtu.tree_map(
_save_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat
),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this t0 == t1 branch can be made much more efficient: I imagine you should be able to do the tree-mapped equivalent of ys = jnp.where(ts == t0, y0, jnp.inf). Does something go wrong with this approach that I'm missing?

(Regardless of the above, do also note that you are doing a tree-map-of-a-loop. If you can, it is usually much more efficient to do a loop-of-a-tree-map.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the note about tree-map-of-a-loop vs. loop-of-a-tree-map! This was a gotcha I was unaware of.

I am probably missing something here but it seems to me at least naively that things could get screwed up based on the way data is saved in SaveState. In my test, we would have for saveat

SaveAt(
  subs=(
    SubSaveAt(t0=False, t1=True, ts=f32[3], steps=False, fn=<function save_y>),
    SubSaveAt(t0=True, t1=False, ts=f32[3], steps=False, fn=<function save_y>)
  ),
  dense=False,
  solver_state=False,
  controller_state=False,
  made_jump=False
)

and for final_state.save_state we have

(SaveState(saveat_ts_index=i32[], ts=f32[4], ys=f32[4,1], save_index=i32[]), SaveState(saveat_ts_index=i32[], ts=f32[4], ys=f32[4,1], save_index=i32[]))

The issue is that for final_state.save_state[0], we want to fill in ys and ts starting at index 0, whereas for final_state.save_state[1] we want to start filling in at index 1 since saveat.subs[1].t0==True. It's this dependence on the value of saveat_ts_index or save_index that made me implement this in a loop and makes it non-obvious to me how a tree_map would be able to handle that.

If you agree, then I will try to refactor this as a loop-of-a-tree-map vs. the way I have it now. Very possible I am missing something simple though :)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. I think the only cases that affect this are t0 and t1? If we're saving neither of them then we should fill in ts == t0-many elements of the output. If we're saving just one of them then we want to save one more value. If we're saving both of them then we want to save two more values.

So untested but something like this:

mask = ts == t0
if t0 or t1:
    if t0 and t1:
        mask = lax.dynamic_update_slice_in_dim(mask, jnp.array([True, True]), jnp.argmin(mask), axis=0)
    else:
        mask = lax.dynamic_update_slice_in_dim(mask, jnp.array([True]), jnp.argmin(mask), axis=0)
ys = jnp.where(mask, y0, jnp.inf)

?

lambda _save_state: _save_state,
final_state.save_state,
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you wrap this whole cond in another cond, which has predicate eqxi.unvmap_any(t0 == t1)?

This is to avoid the vmap-of-cond-becomes-select issue: under vmap, then both branches will unconditionally trigger, which would really slow things down!

save_state = jtu.tree_map(_save_t1, saveat.subs, save_state, is_leaf=_is_subsaveat)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this pretty much all LGTM! I can see that you're using the invariant t0 <= ts < t1. That's much neater than the unnecessary where check that I was suggesting!

I think my only concerns now are how this edge case interacts with a couple of other edge cases:

  1. if both t0 == t1 and the conditions of _save_t1 are triggered, then I think the latter will actually attempt to write out-of-bounds?
  2. if we have a boolean-returning event that triggers immediately (which is another reason to never enter the integration loop), then do we do the right thing? (Whether t0 == t1 or t0 != t1.)

I think (1) at least might be solved by putting the new _save_ts after _save_t1, so that we essentially just overwrite the latter if we hit this case -- WDYT?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Oh also you'll see a force-push on this branch -- I updated dev and this was what gave a readable diff for me to review here. I realised after-the-fact that this means you'll now need to force-pull your own local branch... sorry about that!)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're absolutely right about point 1: fixed in most recent commit.

I'm having a hard time seeing what might go wrong with such a boolean event that immediately triggers. In this case, we should still just fill in y0 as we do now if subsaveat.ts==True and t0==t1. If t0!=t1 then the cond doesn't get triggered anyways and this discussion is moot. Do you agree?

I'm also wondering about what we should be doing if t0==t1 and subsaveat.steps==True. If it is, but subsaveat.ts is None, then we don't fill in y0 for any of the steps. This feels like the right behavior. On the other hand, if subsaveat.steps==True and subsaveat.ts is not None, then we fill in y0 for all values in save_state, including those for the "steps". Do you think this is the right behavior?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Also sorry for the long delay, been a crazy week!)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

including those for the "steps". Do you think this is the right behavior?

Oh, good observation! You're right. Ever-so-technically we shouldn't fill it in for the steps, I think.

(Also sorry for the long delay, been a crazy week!)

Oh no worries at all, I sympathize entirely :D

I'm having a hard time seeing what might go wrong with such a boolean event that immediately triggers. In this case, we should still just fill in y0 as we do now if subsaveat.ts==True and t0==t1. If t0!=t1 then the cond doesn't get triggered anyways and this discussion is moot. Do you agree?

I think you're correct, it's just quite a tricky thing to thread the needle on the logic. I'll probably just add a test for this once we have this PR in :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, updated to not update steps when t0==t1 and subsaveat.steps==True. Edge cases FTW!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also added a test for this case

final_state = eqx.tree_at(
lambda s: s.save_state, final_state, save_state, is_leaf=_is_none
)
Expand Down
9 changes: 3 additions & 6 deletions diffrax/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import optimistix as optx
from jaxtyping import Array, ArrayLike, PyTree, Shaped

Expand Down Expand Up @@ -146,12 +147,8 @@ def static_select(pred: BoolScalarLike, a: ArrayLike, b: ArrayLike) -> ArrayLike
# predicate is statically known.
# This in turn allows us to perform some trace-time optimisations that XLA isn't
# smart enough to do on its own.
if (
type(pred) is not bool
and type(jax.core.get_aval(pred)) is jax.core.ConcreteArray
):
with jax.ensure_compile_time_eval():
pred = pred.item()
if isinstance(pred, (np.ndarray, np.generic)) and pred.shape == ():
pred = pred.item()
if pred is True:
return a
elif pred is False:
Expand Down
4 changes: 4 additions & 0 deletions diffrax/_solver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .align import ALIGN as ALIGN
from .base import (
AbstractAdaptiveSolver as AbstractAdaptiveSolver,
AbstractImplicitSolver as AbstractImplicitSolver,
Expand All @@ -12,6 +13,7 @@
from .dopri8 import Dopri8 as Dopri8
from .euler import Euler as Euler
from .euler_heun import EulerHeun as EulerHeun
from .foster_langevin_srk import AbstractFosterLangevinSRK as AbstractFosterLangevinSRK
from .heun import Heun as Heun
from .implicit_euler import ImplicitEuler as ImplicitEuler
from .kencarp3 import KenCarp3 as KenCarp3
Expand All @@ -26,6 +28,7 @@
ItoMilstein as ItoMilstein,
StratonovichMilstein as StratonovichMilstein,
)
from .quicsort import QUICSORT as QUICSORT
from .ralston import Ralston as Ralston
from .reversible_heun import ReversibleHeun as ReversibleHeun
from .runge_kutta import (
Expand All @@ -42,6 +45,7 @@
from .semi_implicit_euler import SemiImplicitEuler as SemiImplicitEuler
from .shark import ShARK as ShARK
from .shark_general import GeneralShARK as GeneralShARK
from .should import ShOULD as ShOULD
from .sil3 import Sil3 as Sil3
from .slowrk import SlowRK as SlowRK
from .spark import SPaRK as SPaRK
Expand Down
191 changes: 191 additions & 0 deletions diffrax/_solver/align.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from jaxtyping import ArrayLike, PyTree

from .._custom_types import (
AbstractSpaceTimeLevyArea,
RealScalarLike,
)
from .._local_interpolation import LocalLinearInterpolation
from .._term import (
UnderdampedLangevinLeaf,
UnderdampedLangevinTuple,
UnderdampedLangevinX,
)
from .foster_langevin_srk import (
AbstractCoeffs,
AbstractFosterLangevinSRK,
UnderdampedLangevinArgs,
)


# For an explanation of the coefficients, see foster_langevin_srk.py
class _ALIGNCoeffs(AbstractCoeffs):
    """Container for the five step coefficients used by the `ALIGN` solver.

    Each coefficient is a PyTree of array-likes. `dtype` is derived in
    `__init__` as the promoted result type over every leaf of the five
    coefficient PyTrees, and is stored as a static field.
    """

    # NOTE: the declaration order of the fields below is deliberately kept
    # as-is, since it determines the dataclass field order.
    beta: PyTree[ArrayLike]
    a1: PyTree[ArrayLike]
    b1: PyTree[ArrayLike]
    aa: PyTree[ArrayLike]
    chh: PyTree[ArrayLike]
    dtype: jnp.dtype = eqx.field(static=True)

    def __init__(self, beta, a1, b1, aa, chh):
        self.beta, self.a1, self.b1, self.aa, self.chh = beta, a1, b1, aa, chh
        # Promote across every leaf of every coefficient to get a common dtype.
        coeff_leaves = jtu.tree_leaves([beta, a1, b1, aa, chh])
        self.dtype = jnp.result_type(*coeff_leaves)


# Error-estimate type that `ALIGN` passes as the second type parameter of
# `AbstractFosterLangevinSRK`.
_ErrorEstimate = UnderdampedLangevinTuple


class ALIGN(AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]):
    r"""The Adaptive Langevin via Interpolated Gradients and Noise method
    designed by James Foster. This is a second order solver for the
    Underdamped Langevin Diffusion, and accepts terms of the form
    `MultiTerm(UnderdampedLangevinDriftTerm, UnderdampedLangevinDiffusionTerm)`.
    Uses two evaluations of the vector
    field per step, but is FSAL, so in practice it only requires one.

    ??? cite "Reference"

        This is a modification of the Strang-Splitting method from Definition 4.2 of

        ```bibtex
        @misc{foster2021shiftedode,
            title={The shifted ODE method for underdamped Langevin MCMC},
            author={James Foster and Terry Lyons and Harald Oberhauser},
            year={2021},
            eprint={2101.03446},
            archivePrefix={arXiv},
            primaryClass={math.NA},
            url={https://arxiv.org/abs/2101.03446},
        }
        ```

    """

    interpolation_cls = LocalLinearInterpolation
    minimal_levy_area = AbstractSpaceTimeLevyArea
    taylor_threshold: float = eqx.field(static=True)
    _is_fsal = True

    def __init__(self, taylor_threshold: float = 0.1):
        r"""**Arguments:**

        - `taylor_threshold`: If the product `h*gamma` is less than this, then
            the Taylor expansion will be used to compute the coefficients.
            Otherwise they will be computed directly. When using float32, the
            empirically optimal value is 0.1, and for float64 about 0.01.
        """
        self.taylor_threshold = taylor_threshold

    def order(self, terms):
        """Return the order of the solver (2); `terms` is ignored."""
        del terms
        return 2

    def strong_order(self, terms):
        """Return the strong order of the solver (2.0); `terms` is ignored."""
        del terms
        return 2.0

    def _directly_compute_coeffs_leaf(
        self, h: RealScalarLike, c: UnderdampedLangevinLeaf
    ) -> _ALIGNCoeffs:
        """Evaluate the coefficients in closed form for one leaf `c` of gamma.

        This is the direct computation, as opposed to the Taylor expansion
        built by `_tay_coeffs_single`. `self` is not used.
        """
        gamma_h = c * h
        beta = jnp.exp(-gamma_h)
        a1 = (1 - beta) / c
        b1 = (beta + gamma_h - 1) / (c * gamma_h)
        aa = a1 / h

        gamma_h_sq = gamma_h**2
        chh = 6 * (beta * (gamma_h + 2) + gamma_h - 2) / (gamma_h_sq * c)

        return _ALIGNCoeffs(beta=beta, a1=a1, b1=b1, aa=aa, chh=chh)

    def _tay_coeffs_single(self, c: UnderdampedLangevinLeaf) -> _ALIGNCoeffs:
        """Build the Taylor-expansion coefficients for one leaf `c` of gamma.

        Every returned coefficient gains a trailing axis of length 6 holding
        the polynomial coefficients of the expansion in powers of `h`, ordered
        from the 5th power down to the 0th power. The descending order is
        because of `jnp.polyval`. `self` is not used.
        """

        def _descending_powers(*terms):
            # Stack the six per-power coefficients along a new last axis.
            return jnp.stack(list(terms), axis=-1)

        zero, one = jnp.zeros_like(c), jnp.ones_like(c)
        c2 = jnp.square(c)
        c3 = c2 * c
        c4 = c3 * c
        c5 = c4 * c

        beta = _descending_powers(-c5 / 120, c4 / 24, -c3 / 6, c2 / 2, -c, one)
        a1 = _descending_powers(c4 / 120, -c3 / 24, c2 / 6, -c / 2, one, zero)
        b1 = _descending_powers(
            c4 / 720, -c3 / 120, c2 / 24, -c / 6, one / 2, zero
        )
        aa = _descending_powers(-c5 / 720, c4 / 120, -c3 / 24, c2 / 6, -c / 2, one)
        chh = _descending_powers(
            c4 / 168, -c3 / 30, 3 * c2 / 20, -c / 2, one, zero
        )

        # Sanity check: each stacked coefficient is `c`'s shape plus the
        # polynomial axis.
        expected_shape = jnp.shape(c) + (6,)
        assert all(
            poly.shape == expected_shape for poly in (beta, a1, b1, aa, chh)
        )

        return _ALIGNCoeffs(beta=beta, a1=a1, b1=b1, aa=aa, chh=chh)

    def _compute_step(
        self,
        h: RealScalarLike,
        levy: AbstractSpaceTimeLevyArea,
        x0: UnderdampedLangevinX,
        v0: UnderdampedLangevinX,
        underdamped_langevin_args: UnderdampedLangevinArgs,
        coeffs: _ALIGNCoeffs,
        rho: UnderdampedLangevinX,
        prev_f: UnderdampedLangevinX,
    ) -> tuple[
        UnderdampedLangevinX,
        UnderdampedLangevinX,
        UnderdampedLangevinX,
        UnderdampedLangevinTuple,
    ]:
        """Perform a single ALIGN step from `(x0, v0)`.

        Returns `(x1, v1, f1, error_estimate)`, where `f1 = f(x1)` is the only
        new evaluation of `f` made here (`prev_f` is reused for the start of
        the step), and `error_estimate` is a pair whose first entry is zero
        (shaped like `x0`) and whose second entry is `-u * b1 * (f1 - prev_f)`.
        """
        gamma, u, f = underdamped_langevin_args

        # Cast the Brownian increment W and space-time Levy area H to the
        # dtypes of the corresponding leaves of x0.
        x_dtypes = jtu.tree_map(jnp.result_type, x0)

        def _as_x_dtype(tree):
            return jtu.tree_map(jnp.asarray, tree, x_dtypes)

        w: UnderdampedLangevinX = _as_x_dtype(levy.W)
        hh: UnderdampedLangevinX = _as_x_dtype(levy.H)

        u_times_h = (u**ω * h).ω

        # Position update.
        x_drift = (coeffs.b1**ω * u_times_h**ω * prev_f**ω).ω
        x_noise = (rho**ω * (coeffs.b1**ω * w**ω + coeffs.chh**ω * hh**ω)).ω
        x1 = (x0**ω + coeffs.a1**ω * v0**ω - x_drift**ω + x_noise**ω).ω

        f1 = f(x1)

        # Velocity update, combining the old and new evaluations of f.
        v_drift = (
            u**ω
            * ((coeffs.a1**ω - coeffs.b1**ω) * prev_f**ω + coeffs.b1**ω * f1**ω)
        ).ω
        v_noise = (
            rho**ω * (coeffs.aa**ω * w**ω - gamma**ω * coeffs.chh**ω * hh**ω)
        ).ω
        v1 = (coeffs.beta**ω * v0**ω - v_drift**ω + v_noise**ω).ω

        x_error = jtu.tree_map(jnp.zeros_like, x0)
        v_error = (-(u**ω) * coeffs.b1**ω * (f1**ω - prev_f**ω)).ω

        return x1, v1, f1, (x_error, v_error)
Loading