Save fix for t0==t1
#494
```diff
@@ -20,6 +20,7 @@
 import jax.numpy as jnp
 import jax.tree_util as jtu
 import lineax.internal as lxi
+import numpy as np
 import optimistix as optx
 from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real
```
```diff
@@ -258,12 +259,10 @@ def _maybe_static(static_x: Optional[ArrayLike], x: ArrayLike) -> ArrayLike:
     # Some values (made_jump and result) are not used in many common use-cases. If we
     # detect that they're unused then we make sure they're non-Array Python values, so
     # that we can special case on them at trace time and get a performance boost.
-    if isinstance(static_x, (bool, int, float, complex)):
+    if isinstance(static_x, (bool, int, float, complex, np.ndarray)):
         return static_x
     elif static_x is None:
         return x
-    elif type(jax.core.get_aval(static_x)) is jax.core.ConcreteArray:
-        return static_x
     else:
         return x
```
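For context on the `_maybe_static` change above, here is a small standalone sketch (my own, not from the PR) of why a `np.ndarray` can be treated like the Python scalars in that check: at JAX trace time, plain numpy arrays remain concrete Python values, whereas the traced arguments are tracers.

```python
# Sketch (assumption, mirroring the patched check): numpy arrays stay
# concrete at JAX trace time, unlike the traced function arguments.
import jax
import jax.numpy as jnp
import numpy as np


def is_concrete(x):
    # Same isinstance test as the patched _maybe_static branch.
    return isinstance(x, (bool, int, float, complex, np.ndarray))


@jax.jit
def f(x):
    static = np.ones(3)          # a concrete numpy value inside the trace
    assert is_concrete(static)   # numpy array: still a plain Python value
    assert not is_concrete(x)    # x is a tracer here, not concrete
    return x + static


out = f(jnp.zeros(3))
print(out)  # [1. 1. 1.]
```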
```diff
@@ -776,9 +775,51 @@ def _save_t1(subsaveat, save_state):
         save_state = _save(tfinal, yfinal, args, subsaveat.fn, save_state)
         return save_state

-    save_state = jtu.tree_map(
-        _save_t1, saveat.subs, final_state.save_state, is_leaf=_is_subsaveat
-    )
+    def _save_ts_impl(ts, fn, _save_state):
+        def _cond_fun(__save_state):
+            return __save_state.saveat_ts_index < len(_save_state.ts)
+
+        def _body_fun(__save_state):
+            idx = __save_state.save_index
+            ts = __save_state.ts.at[idx].set(t0)
+            ys = jtu.tree_map(
+                lambda _y, _ys: _ys.at[idx].set(_y),
+                fn(t0, yfinal, args),
+                __save_state.ys,
+            )
+            return SaveState(
+                saveat_ts_index=idx + 1,
+                ts=ts,
+                ys=ys,
+                save_index=idx + 1,
+            )
+
+        return inner_while_loop(
+            _cond_fun,
+            _body_fun,
+            _save_state,
+            max_steps=len(ts),
+            buffers=_inner_buffers,
+            checkpoints=len(ts),
+        )
+
+    def _save_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
+        if subsaveat.ts is not None:
+            save_state = _save_ts_impl(subsaveat.ts, subsaveat.fn, save_state)
+        return save_state
+
+    # if t0 == t1 then we don't enter the integration loop. In this case we have
+    # to manually update the saved ts and ys if we want to save at
+    # "intermediate" times specified by saveat.subs.ts
+    save_state = jax.lax.cond(
+        t0 == t1,
+        lambda _save_state: jtu.tree_map(
+            _save_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat
+        ),
+        lambda _save_state: _save_state,
+        final_state.save_state,
+    )
```
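The fill pattern in the hunk above can be illustrated with a generic sketch (a plain `jax.lax.while_loop`, not diffrax's checkpointed `inner_while_loop`): starting from a save index, write a value into each remaining slot of a preallocated buffer.

```python
# Generic sketch (my own, simplified): fill the tail of a preallocated
# buffer starting from a given index, as the _body_fun above does with
# SaveState's ts/ys buffers.
import jax
import jax.numpy as jnp


def fill_from(buf, start_index, value):
    def cond_fun(state):
        i, _ = state
        return i < buf.shape[0]

    def body_fun(state):
        i, b = state
        # Write `value` into slot i, then advance the index.
        return i + 1, b.at[i].set(value)

    _, filled = jax.lax.while_loop(cond_fun, body_fun, (start_index, buf))
    return filled


out = fill_from(jnp.full(4, jnp.inf), 1, 0.0)
print(out)  # [inf  0.  0.  0.]
```

Slot 0 is left untouched, mimicking a sub-save-state whose `save_index` already advanced past an initial `t0` save.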
**Review comment:** Can you wrap this whole … ? This is to avoid the vmap-of-cond-becomes-select issue: under vmap, both branches will unconditionally trigger, which would really slow things down!
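The vmap-of-cond-becomes-select issue mentioned here can be demonstrated with a small standalone sketch (not from the PR): when the predicate of a `jax.lax.cond` is batched under `jax.vmap`, the cond is lowered to an elementwise select, so both branches are computed for every batch element.

```python
# Sketch of the behaviour the reviewer warns about: a batched predicate
# turns lax.cond into a select, evaluating both branches per element.
import jax
import jax.numpy as jnp


def expensive(x):
    return x * 2.0  # stands in for a costly branch


def cheap(x):
    return x


def f(pred, x):
    return jax.lax.cond(pred, expensive, cheap, x)


preds = jnp.array([True, False])
xs = jnp.array([1.0, 3.0])
out = jax.vmap(f)(preds, xs)
print(out)  # [2. 3.]

# Inspecting the jaxpr typically shows a select op rather than a cond,
# i.e. both branches were lowered unconditionally:
print(str(jax.make_jaxpr(jax.vmap(f))(preds, xs)))
```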
```diff
+
+    save_state = jtu.tree_map(_save_t1, saveat.subs, save_state, is_leaf=_is_subsaveat)
```
**Review comment:** So this pretty much all LGTM! I can see that you're using the invariant … I think my only concerns now are how this edge case interacts with a couple of other edge cases: …

I think (1) at least might be solved by putting the new …

**Review comment:** (Oh, also: you'll see a force-push on this branch -- I updated … )

**Comment:** You're absolutely right about point 1: fixed in the most recent commit. I'm having a hard time seeing what might go wrong with such a boolean event that immediately triggers. In this case, we should still just fill in … I'm also wondering about what we should be doing if …

(Also, sorry for the long delay -- been a crazy week!)

**Comment:** Oh, good observation! You're right. Ever-so-technically we shouldn't fill it in for the steps, I think.

**Comment:** Oh, no worries at all -- I sympathize entirely :D

**Comment:** I think you're correct; it's just quite a tricky thing to thread the needle on the logic. I'll probably just add a test for this once we have this PR in :)

**Comment:** OK, updated to not update steps when …

**Comment:** I also added a test for this case.
```diff

     final_state = eqx.tree_at(
         lambda s: s.save_state, final_state, save_state, is_leaf=_is_none
     )
```
New file (@@ -0,0 +1,191 @@):

```python
import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from jaxtyping import ArrayLike, PyTree

from .._custom_types import (
    AbstractSpaceTimeLevyArea,
    RealScalarLike,
)
from .._local_interpolation import LocalLinearInterpolation
from .._term import (
    UnderdampedLangevinLeaf,
    UnderdampedLangevinTuple,
    UnderdampedLangevinX,
)
from .foster_langevin_srk import (
    AbstractCoeffs,
    AbstractFosterLangevinSRK,
    UnderdampedLangevinArgs,
)


# For an explanation of the coefficients, see foster_langevin_srk.py
class _ALIGNCoeffs(AbstractCoeffs):
    beta: PyTree[ArrayLike]
    a1: PyTree[ArrayLike]
    b1: PyTree[ArrayLike]
    aa: PyTree[ArrayLike]
    chh: PyTree[ArrayLike]
    dtype: jnp.dtype = eqx.field(static=True)

    def __init__(self, beta, a1, b1, aa, chh):
        self.beta = beta
        self.a1 = a1
        self.b1 = b1
        self.aa = aa
        self.chh = chh
        all_leaves = jtu.tree_leaves([self.beta, self.a1, self.b1, self.aa, self.chh])
        self.dtype = jnp.result_type(*all_leaves)


_ErrorEstimate = UnderdampedLangevinTuple


class ALIGN(AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]):
    r"""The Adaptive Langevin via Interpolated Gradients and Noise method
    designed by James Foster. This is a second-order solver for the
    Underdamped Langevin Diffusion, and accepts terms of the form
    `MultiTerm(UnderdampedLangevinDriftTerm, UnderdampedLangevinDiffusionTerm)`.
    It uses two evaluations of the vector field per step, but is FSAL,
    so in practice it only requires one.

    ??? cite "Reference"

        This is a modification of the Strang-splitting method from Definition 4.2 of

        ```bibtex
        @misc{foster2021shiftedode,
            title={The shifted ODE method for underdamped Langevin MCMC},
            author={James Foster and Terry Lyons and Harald Oberhauser},
            year={2021},
            eprint={2101.03446},
            archivePrefix={arXiv},
            primaryClass={math.NA},
            url={https://arxiv.org/abs/2101.03446},
        }
        ```
    """

    interpolation_cls = LocalLinearInterpolation
    minimal_levy_area = AbstractSpaceTimeLevyArea
    taylor_threshold: float = eqx.field(static=True)
    _is_fsal = True

    def __init__(self, taylor_threshold: float = 0.1):
        r"""**Arguments:**

        - `taylor_threshold`: If the product `h*gamma` is less than this, then
            the Taylor expansion will be used to compute the coefficients.
            Otherwise they will be computed directly. When using float32, the
            empirically optimal value is 0.1, and for float64 about 0.01.
        """
        self.taylor_threshold = taylor_threshold

    def order(self, terms):
        del terms
        return 2

    def strong_order(self, terms):
        del terms
        return 2.0

    def _directly_compute_coeffs_leaf(
        self, h: RealScalarLike, c: UnderdampedLangevinLeaf
    ) -> _ALIGNCoeffs:
        del self
        # c is a leaf of gamma
        # compute the coefficients directly (as opposed to via Taylor expansion)
        al = c * h
        beta = jnp.exp(-al)
        a1 = (1 - beta) / c
        b1 = (beta + al - 1) / (c * al)
        aa = a1 / h

        al2 = al**2
        chh = 6 * (beta * (al + 2) + al - 2) / (al2 * c)

        return _ALIGNCoeffs(
            beta=beta,
            a1=a1,
            b1=b1,
            aa=aa,
            chh=chh,
        )

    def _tay_coeffs_single(self, c: UnderdampedLangevinLeaf) -> _ALIGNCoeffs:
        del self
        # c is a leaf of gamma
        zero = jnp.zeros_like(c)
        one = jnp.ones_like(c)
        c2 = jnp.square(c)
        c3 = c2 * c
        c4 = c3 * c
        c5 = c4 * c

        # Coefficients of the Taylor expansion, starting from 5th power
        # to 0th power. The descending power order is because of jnp.polyval
        beta = jnp.stack([-c5 / 120, c4 / 24, -c3 / 6, c2 / 2, -c, one], axis=-1)
        a1 = jnp.stack([c4 / 120, -c3 / 24, c2 / 6, -c / 2, one, zero], axis=-1)
        b1 = jnp.stack([c4 / 720, -c3 / 120, c2 / 24, -c / 6, one / 2, zero], axis=-1)
        aa = jnp.stack([-c5 / 720, c4 / 120, -c3 / 24, c2 / 6, -c / 2, one], axis=-1)
        chh = jnp.stack([c4 / 168, -c3 / 30, 3 * c2 / 20, -c / 2, one, zero], axis=-1)

        correct_shape = jnp.shape(c) + (6,)
        assert (
            beta.shape == a1.shape == b1.shape == aa.shape == chh.shape == correct_shape
        )

        return _ALIGNCoeffs(
            beta=beta,
            a1=a1,
            b1=b1,
            aa=aa,
            chh=chh,
        )

    def _compute_step(
        self,
        h: RealScalarLike,
        levy: AbstractSpaceTimeLevyArea,
        x0: UnderdampedLangevinX,
        v0: UnderdampedLangevinX,
        underdamped_langevin_args: UnderdampedLangevinArgs,
        coeffs: _ALIGNCoeffs,
        rho: UnderdampedLangevinX,
        prev_f: UnderdampedLangevinX,
    ) -> tuple[
        UnderdampedLangevinX,
        UnderdampedLangevinX,
        UnderdampedLangevinX,
        UnderdampedLangevinTuple,
    ]:
        dtypes = jtu.tree_map(jnp.result_type, x0)
        w: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
        hh: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)

        gamma, u, f = underdamped_langevin_args

        uh = (u**ω * h).ω
        f0 = prev_f
        x1 = (
            x0**ω
            + coeffs.a1**ω * v0**ω
            - coeffs.b1**ω * uh**ω * f0**ω
            + rho**ω * (coeffs.b1**ω * w**ω + coeffs.chh**ω * hh**ω)
        ).ω
        f1 = f(x1)
        v1 = (
            coeffs.beta**ω * v0**ω
            - u**ω * ((coeffs.a1**ω - coeffs.b1**ω) * f0**ω + coeffs.b1**ω * f1**ω)
            + rho**ω * (coeffs.aa**ω * w**ω - gamma**ω * coeffs.chh**ω * hh**ω)
        ).ω

        error_estimate = (
            jtu.tree_map(jnp.zeros_like, x0),
            (-(u**ω) * coeffs.b1**ω * (f1**ω - f0**ω)).ω,
        )

        return x1, v1, f1, error_estimate
```
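The `_tay_coeffs_single` comment notes that coefficients are stored from highest to lowest power "because of jnp.polyval". A quick standalone check (mine, not part of the PR) of that convention: the `beta` row `[-c^5/120, c^4/24, -c^3/6, c^2/2, -c, 1]` evaluated at `h` is the 5th-order Taylor approximation of `exp(-c*h)`, matching the direct formula `beta = jnp.exp(-al)` for small `c*h`.

```python
# Verify the descending-power jnp.polyval convention used for the Taylor
# coefficients: beta's row approximates exp(-c*h) when c*h is small.
import jax.numpy as jnp

c, h = 2.0, 0.01  # c*h = 0.02, well below the taylor_threshold of 0.1
beta_coeffs = jnp.array([-c**5 / 120, c**4 / 24, -c**3 / 6, c**2 / 2, -c, 1.0])
approx = jnp.polyval(beta_coeffs, h)   # highest power first
exact = jnp.exp(-c * h)                # the "directly computed" value
print(float(jnp.abs(approx - exact)))  # tiny: truncation error is O((c*h)^6)
```

In float32 the difference is dominated by rounding, on the order of 1e-7.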
**Review comment:** I think this `t0 == t1` branch can be made much more efficient: I imagine you should be able to do the tree-mapped equivalent of `ys = jnp.where(ts == t0, y0, jnp.inf)`. Does something go wrong with this approach that I'm missing?

(Regardless of the above, do also note that you are doing a tree-map-of-a-loop. If you can, it is usually much more efficient to do a loop-of-a-tree-map.)
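The reviewer's suggestion might look roughly like the following sketch (hypothetical, with simplified stand-in names rather than diffrax's actual `SaveState` internals): every requested save time equal to `t0` gets `y0`, and the rest stay marked unfilled with `inf`, all in one vectorised pass instead of a loop.

```python
# Hypothetical sketch of the tree-mapped jnp.where fill. `fill_saved_ys`
# and its arguments are simplified stand-ins, not diffrax internals.
import jax.numpy as jnp
import jax.tree_util as jtu


def fill_saved_ys(ts, y0_tree, t0):
    # Save times equal to t0 are "hit"; everything else stays inf/unfilled.
    filled_ts = jnp.where(ts == t0, t0, jnp.inf)
    filled_ys = jtu.tree_map(
        # Broadcast the hit mask over each leaf's trailing dimensions.
        lambda y0: jnp.where((ts == t0)[:, None], y0, jnp.inf),
        y0_tree,
    )
    return filled_ts, filled_ys


ts = jnp.array([0.0, 0.0, 1.0])
y0 = {"y": jnp.array([2.0, 3.0])}
fts, fys = fill_saved_ys(ts, y0, 0.0)
print(fts)        # [ 0.  0. inf]
print(fys["y"])   # rows 0 and 1 are y0; row 2 is [inf inf]
```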
**Comment:** Thanks for the note about tree-map-of-a-loop vs. loop-of-a-tree-map! This was a gotcha I was unaware of.

I am probably missing something here, but it seems to me, at least naively, that things could get screwed up based on the way data is saved in SaveState. In my test, for `saveat` we would have … and for `final_state.save_state` we have …

The issue is that for `final_state.save_state[0]` we want to fill in `ys` and `ts` starting at index 0, whereas for `final_state.save_state[1]` we want to start filling in at index 1, since `saveat.subs[1].t0 == True`. It's this dependence on the value of `saveat_ts_index` or `save_index` that made me implement this in a loop, and it makes it non-obvious to me how a `tree_map` would be able to handle that.

If you agree, then I will try to refactor this as a loop-of-a-tree-map vs. the way I have it now. Very possible I am missing something simple though :)
**Review comment:** Hmm. I think the only cases that affect this are `t0` and `t1`? If we're saving neither of them then we should fill in `ts == t0`-many elements of the output. If we're saving just one of them then we want to save one more value. If we're saving both of them then we want to save two more values.

So, untested, but something like this: … ?
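The counting rule described above can be sketched as a tiny helper (my own hypothetical naming, not the reviewer's elided snippet): the number of entries to fill depends only on how many requested save times coincide with `t0 == t1`, plus one for each of `t0`/`t1` saved directly.

```python
# Untested sketch of the counting logic: entries of subsaveat.ts that
# equal t0, plus one per endpoint (t0/t1) that is saved directly.
import jax.numpy as jnp


def num_to_fill(ts, t0, save_t0, save_t1):
    n = jnp.sum(ts == t0)  # "ts == t0"-many intermediate saves
    return n + int(save_t0) + int(save_t1)


ts = jnp.array([0.0, 0.0, 1.0])
print(int(num_to_fill(ts, 0.0, False, False)))  # 2
print(int(num_to_fill(ts, 0.0, True, True)))    # 4
```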