Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 251 additions & 0 deletions examples/associative_scan.ipynb

Large diffs are not rendered by default.

15 changes: 12 additions & 3 deletions tests/test_sampler.py
Comment thread
KaelanDt marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,27 @@


def test_sample_array_input():
    """Check `thermox.sample` on array inputs with irregular time steps.

    Asserts that the empirical mean and covariance of a long trajectory match
    the expected values within a loose tolerance, and that the sequential-scan
    and associative-scan code paths return the same samples for the same key.
    """
    # Enable float64 (presumably so the two scan variants can be compared at
    # atol=1e-6 below).
    jax.config.update("jax_enable_x64", True)
    key = jax.random.PRNGKey(0)
    dim = 2
    dt = 0.1
    ts = jnp.arange(0, 10_000, dt)

    # Add some noise to the time points to make the timesteps different
    ts += jax.random.uniform(key, (ts.shape[0],)) * dt
    ts = ts.sort()

    A = jnp.array([[3, 2.5], [2, 4.0]])
    b = jax.random.normal(jax.random.PRNGKey(1), (dim,))
    x0 = jax.random.normal(jax.random.PRNGKey(2), (dim,))
    D = 2 * jnp.eye(dim)

    samples = thermox.sample(key, ts, x0, A, b, D, associative_scan=False)

    samp_cov = jnp.cov(samples.T)
    samp_mean = jnp.mean(samples.T, axis=1)
    assert jnp.allclose(A @ samp_cov, jnp.eye(dim), atol=1e-1)
    assert jnp.allclose(samp_mean, b, atol=1e-1)

    # Same key => same Gaussian draws, so both code paths must agree.
    samples_as = thermox.sample(key, ts, x0, A, b, D, associative_scan=True)
    assert jnp.allclose(samples, samples_as, atol=1e-6)
81 changes: 31 additions & 50 deletions thermox/prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,58 @@
)


def log_prob(
    ts: Array,
    xs: Array,
    A: Array | ProcessedDriftMatrix,
    b: Array,
    D: Array | ProcessedDiffusionMatrix,
) -> Array:
    """Calculates log probability of samples from the Ornstein-Uhlenbeck process,
    defined as:

        dx = - A * (x - b) dt + sqrt(D) dW

    by using exact diagonalization.

    Assumes x(t_0) is given deterministically.

    Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs
    O(T * d^2), where T=len(ts).

    By default, this function does the preprocessing on A and D before the
    evaluation. However, the preprocessing can be done externally using
    thermox.preprocess, the output of which can be used as A and D here; this
    will skip the preprocessing.

    Args:
        ts: Times at which samples are collected. Includes time for x0.
        xs: States of the process stacked along axis 0 (presumably one row per
            entry of ts, the first row being x0).
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
            Note: If a thermox.ProcessedDriftMatrix instance is used as input,
            must be the transformed drift matrix, A_y, given by
            thermox.preprocess, not thermox.utils.preprocess_drift_matrix.
        b: Drift displacement vector.
        D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).

    Returns:
        Scalar log probability of given xs.
    """
    A_y, D = handle_matrix_inputs(A, D)

    # Map every state and the displacement through D.sqrt_inv so that the
    # identity-diffusion routine can be reused.
    ys = vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt_inv, xs)
    b_y = D.sqrt_inv @ b
    log_prob_ys = log_prob_identity_diffusion(ts, ys, A_y, b_y)

    # Change-of-variables correction, once per transition (len(ts) - 1 of them).
    # slogdet is used instead of log(det(.)) because the determinant itself can
    # underflow to 0 (giving -inf) in float32 once the dimension is moderate.
    _, D_sqrt_inv_log_det = jnp.linalg.slogdet(D.sqrt_inv)
    return log_prob_ys + D_sqrt_inv_log_det * (len(ts) - 1)


def log_prob_identity_diffusion(
ts: Array,
xs: Array,
A: Array | ProcessedDriftMatrix,
b: Array,
) -> float:
if isinstance(A, Array):
A = preprocess_drift_matrix(A)

Expand Down Expand Up @@ -76,49 +103,3 @@ def logpt(yt, y0, dt):
)

return log_prob_val.real


def log_prob(
    ts: Array,
    xs: Array,
    A: Array | ProcessedDriftMatrix,
    b: Array,
    D: Array | ProcessedDiffusionMatrix,
) -> Array:
    """Calculates log probability of samples from the Ornstein-Uhlenbeck process,
    defined as:

        dx = - A * (x - b) dt + sqrt(D) dW

    by using exact diagonalization.

    Assumes x(t_0) is given deterministically.

    Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs
    O(T * d^2), where T=len(ts).

    By default, this function does the preprocessing on A and D before the
    evaluation. However, the preprocessing can be done externally using
    thermox.preprocess, the output of which can be used as A and D here; this
    will skip the preprocessing.

    Args:
        ts: Times at which samples are collected. Includes time for x0.
        xs: States of the process stacked along axis 0 (presumably one row per
            entry of ts, the first row being x0).
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
            Note: If a thermox.ProcessedDriftMatrix instance is used as input,
            must be the transformed drift matrix, A_y, given by
            thermox.preprocess, not thermox.utils.preprocess_drift_matrix.
        b: Drift displacement vector.
        D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).

    Returns:
        Scalar log probability of given xs.
    """
    A_y, D = handle_matrix_inputs(A, D)

    # Map every state and the displacement through D.sqrt_inv so that the
    # identity-diffusion routine can be reused.
    ys = vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt_inv, xs)
    b_y = D.sqrt_inv @ b
    log_prob_ys = log_prob_identity_diffusion(ts, ys, A_y, b_y)

    # Change-of-variables correction, once per transition (len(ts) - 1 of them).
    # slogdet is used instead of log(det(.)) because the determinant itself can
    # underflow to 0 (giving -inf) in float32 once the dimension is moderate.
    _, D_sqrt_inv_log_det = jnp.linalg.slogdet(D.sqrt_inv)
    return log_prob_ys + D_sqrt_inv_log_det * (len(ts) - 1)
133 changes: 91 additions & 42 deletions thermox/sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial
import jax
import jax.numpy as jnp
from jax.lax import scan
from jax import Array

from thermox.utils import (
Expand All @@ -11,34 +11,77 @@
)


def sample(
    key: Array,
    ts: Array,
    x0: Array,
    A: Array | ProcessedDriftMatrix,
    b: Array,
    D: Array | ProcessedDiffusionMatrix,
    associative_scan: bool = True,
) -> Array:
    """Collects samples from the Ornstein-Uhlenbeck process, defined as:

        dx = - A * (x - b) dt + sqrt(D) dW

    by using exact diagonalization.

    Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2),
    where T=len(ts).

    If associative_scan=True then jax.lax.associative_scan is used which will
    run in time O((T/p + log(T)) * d^2) on a GPU/TPU with p cores, still with
    O(d^3) preprocessing.

    By default, this function does the preprocessing on A and D before the
    evaluation. However, the preprocessing can be done externally using
    thermox.preprocess, the output of which can be used as A and D here; this
    will skip the preprocessing.

    Args:
        key: Jax PRNGKey.
        ts: Times at which samples are collected. Includes time for x0.
        x0: Initial state of the process.
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
            Note: If a thermox.ProcessedDriftMatrix instance is used as input,
            must be the transformed drift matrix, A_y, given by
            thermox.preprocess, not thermox.utils.preprocess_drift_matrix.
        b: Drift displacement vector.
        D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
        associative_scan: If True, uses jax.lax.associative_scan.

    Returns:
        Array-like, desired samples.
            shape: (len(ts), ) + x0.shape
    """
    A_y, D = handle_matrix_inputs(A, D)

    # Move the initial state and displacement into the transformed coordinates
    # (multiplication by D.sqrt_inv) used by the identity-diffusion sampler.
    y0 = D.sqrt_inv @ x0
    b_y = D.sqrt_inv @ b
    ys = sample_identity_diffusion(key, ts, y0, A_y, b_y, associative_scan)

    # Map every sample back to the original coordinates with D.sqrt.
    return jax.vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt, ys)


def sample_identity_diffusion(
    key: Array,
    ts: Array,
    x0: Array,
    A: Array | ProcessedDriftMatrix,
    b: Array,
    associative_scan: bool = True,
) -> Array:
    """Collects samples from the Ornstein-Uhlenbeck process with identity
    diffusion, defined as:

        dx = - A * (x - b) dt + dW

    This function only selects the implementation; all arguments are forwarded
    unchanged.

    Args:
        key: Jax PRNGKey.
        ts: Times at which samples are collected. Includes time for x0.
        x0: Initial state of the process.
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
        b: Drift displacement vector.
        associative_scan: If True, dispatches to the associative-scan
            implementation, otherwise to the sequential-scan implementation.

    Returns:
        Whatever the selected implementation returns (the samples).
    """
    if associative_scan:
        return _sample_identity_diffusion_associative_scan(key, ts, x0, A, b)
    return _sample_identity_diffusion_scan(key, ts, x0, A, b)


def _sample_identity_diffusion_scan(
key: Array,
ts: Array,
x0: Array,
A: Array | ProcessedDriftMatrix,
b: Array,
) -> Array:
if isinstance(A, Array):
A = preprocess_drift_matrix(A)

Expand All @@ -57,62 +100,68 @@ def transition_cov_sqrt_vp(v, dt):
out = A.sym_eigvecs @ out
return out.real

def next_x(x, dt, tkey):
randv = jax.random.normal(tkey, shape=x.shape)
return transition_mean(x, dt) + transition_cov_sqrt_vp(randv, dt)
def next_x(x, dt, rv):
return transition_mean(x, dt) + transition_cov_sqrt_vp(rv, dt)

def scan_body(x_and_key, dt):
x, rk = x_and_key
rk, rk_use = jax.random.split(rk)
x = next_x(x, dt, rk_use)
return (x, rk), x
def scan_body(carry, dt_and_rv):
x = carry
dt, rv = dt_and_rv
new_x = next_x(x, dt, rv)
return new_x, new_x

dts = jnp.diff(ts)
gauss_samps = jax.random.normal(key, (len(dts),) + x0.shape)

xs = scan(scan_body, (x0, key), dts)[1]
# Bundle dts and gauss_samps in a tuple; scan slices both along their leading axis
dt_and_rv = (dts, gauss_samps)

_, xs = jax.lax.scan(scan_body, x0, dt_and_rv)
xs = jnp.concatenate([jnp.expand_dims(x0, axis=0), xs], axis=0)
return xs


def sample(
def _sample_identity_diffusion_associative_scan(
key: Array,
ts: Array,
x0: Array,
A: Array | ProcessedDriftMatrix,
b: Array,
D: Array | ProcessedDiffusionMatrix,
) -> Array:
"""Collects samples from the Ornstein-Uhlenbeck process, defined as:
if isinstance(A, Array):
A = preprocess_drift_matrix(A)

dx = - A * (x - b) dt + sqrt(D) dW
def expm_vp(v, dt):
out = A.eigvecs_inv @ v
out = jnp.exp(-A.eigvals * dt) * out
out = A.eigvecs @ out
return out.real

by using exact diagonalization.
def transition_cov_sqrt_vp(v, dt):
diag = ((1 - jnp.exp(-2 * A.sym_eigvals * dt)) / (2 * A.sym_eigvals)) ** 0.5
out = diag * v
out = A.sym_eigvecs @ out
return out.real

Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2),
where T=len(ts).
dts = jnp.diff(ts)

By default, this function does the preprocessing on A and D before the evaluation.
However, the preprocessing can be done externally using thermox.preprocess
the output of which can be used as A and D here, this will skip the preprocessing.
# transition_mean(x, dt) = b + expm_vp(x - b, dt)
def position_dep_mean_component(x, dt):
return expm_vp(x, dt)

Args:
key: Jax PRNGKey.
ts: Times at which samples are collected. Includes time for x0.
x0: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
Note : If a thermox.ProcessedDriftMatrix instance is used as input,
must be the transformed drift matrix, A_y, given by thermox.preprocess,
not thermox.utils.preprocess_drift_matrix.
b: Drift displacement vector.
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
gauss_samps = jax.random.normal(key, (len(dts),) + x0.shape)
position_indep_terms = jax.vmap(transition_cov_sqrt_vp)(gauss_samps, dts)

Returns:
Array-like, desired samples.
shape: (len(ts), ) + x0.shape
"""
A_y, D = handle_matrix_inputs(A, D)
@partial(jax.vmap, in_axes=(0, 0))
def binary_associative_operator(elem_a, elem_b):
t_a, x_a = elem_a
t_b, x_b = elem_b
return t_a + t_b, position_dep_mean_component(x_a, t_b) + x_b

y0 = D.sqrt_inv @ x0
b_y = D.sqrt_inv @ b
ys = sample_identity_diffusion(key, ts, y0, A_y, b_y)
return jax.vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt, ys)
scan_times = jnp.concatenate([ts[:1], dts], dtype=float) # [t0, dt1, dt2, ...]
scan_input_values = jnp.concatenate(
[x0[None] - b, position_indep_terms], axis=0
) # Shift input by b
scan_elems = (scan_times, scan_input_values)

scan_output = jax.lax.associative_scan(binary_associative_operator, scan_elems)
return scan_output[1] + b # Shift back by b