Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions tests/test_sampler.py
Comment thread
KaelanDt marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@ def test_sample_array_input():
dt = 0.1
ts = jnp.arange(0, 10_000, dt)

A = jnp.array([[3, 2], [2, 4.0]])
b, x0 = jnp.zeros(dim), jnp.zeros(dim)
# Add some noise to the time points to make the timesteps different
ts += jax.random.uniform(key, (ts.shape[0],)) * dt
ts = ts.sort()

A = jnp.array([[3, 2.5], [2, 4.0]])
b = jax.random.normal(jax.random.PRNGKey(1), (dim,))
x0 = jax.random.normal(jax.random.PRNGKey(2), (dim,))
D = 2 * jnp.eye(dim)

samples = thermox.sample(key, ts, x0, A, b, D)
Expand Down
2 changes: 1 addition & 1 deletion thermox/prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def log_prob(
ts: Times at which samples are collected. Includes time for x0.
xs: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
Note : If a thermox.ProcessedDriftMatrix instance is used as input,
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
must be the transformed drift matrix, A_y, given by thermox.preprocess,
not thermox.utils.preprocess_drift_matrix.
b: Drift displacement vector.
Expand Down
45 changes: 28 additions & 17 deletions thermox/sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial
import jax
import jax.numpy as jnp
from jax.lax import scan
from jax import Array

from thermox.utils import (
Expand All @@ -27,6 +27,9 @@ def sample_identity_diffusion(
Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2)
where T=len(ts).

Uses jax.lax.associative_scan, so will run in time O(log(T) * d^2) on a GPU/TPU with
Comment thread
KaelanDt marked this conversation as resolved.
Outdated
O(T) cores.

Args:
key: Jax PRNGKey.
ts: Times at which samples are collected. Includes time for x0.
Expand All @@ -48,30 +51,35 @@ def expm_vp(v, dt):
out = A.eigvecs @ out
return out.real

def transition_mean(x, dt):
return b + expm_vp(x - b, dt)

def transition_cov_sqrt_vp(v, dt):
diag = ((1 - jnp.exp(-2 * A.sym_eigvals * dt)) / (2 * A.sym_eigvals)) ** 0.5
out = diag * v
out = A.sym_eigvecs @ out
return out.real

def next_x(x, dt, tkey):
randv = jax.random.normal(tkey, shape=x.shape)
return transition_mean(x, dt) + transition_cov_sqrt_vp(randv, dt)
dts = jnp.diff(ts)

def scan_body(x_and_key, dt):
x, rk = x_and_key
rk, rk_use = jax.random.split(rk)
x = next_x(x, dt, rk_use)
return (x, rk), x
# transition_mean(x, dt) = b + expm_vp(x - b, dt)
def position_dep_mean_component(x, dt):
return expm_vp(x, dt)

dts = jnp.diff(ts)
gauss_samps = jax.random.normal(key, (len(dts),) + x0.shape)
position_indep_terms = jax.vmap(transition_cov_sqrt_vp)(gauss_samps, dts) + b

xs = scan(scan_body, (x0, key), dts)[1]
xs = jnp.concatenate([jnp.expand_dims(x0, axis=0), xs], axis=0)
return xs
@partial(jax.vmap, in_axes=(0, 0))
def binary_associative_operator(elem_a, elem_b):
Comment thread
SamDuffield marked this conversation as resolved.
Outdated
t_a, x_a = elem_a
t_b, x_b = elem_b
return t_a + t_b, position_dep_mean_component(x_a, t_b) + x_b

scan_times = jnp.concatenate([ts[:1], dts], dtype=float) # [t0, dt1, dt2, ...]
scan_input_values = (
jnp.concatenate([x0[None], position_indep_terms], axis=0) - b
) # Shift by b to ensure expm_vp(x - b, dt) is calculated at each step
scan_elems = (scan_times, scan_input_values)

scan_output = jax.lax.associative_scan(binary_associative_operator, scan_elems)
return scan_output[1] + b


def sample(
Expand All @@ -91,6 +99,9 @@ def sample(
Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2),
where T=len(ts).

Uses jax.lax.associative_scan, so will run in time O(log(T) * d^2) on a GPU/TPU with
O(T) cores.

By default, this function does the preprocessing on A and D before the evaluation.
However, the preprocessing can be done externally using thermox.preprocess
the output of which can be used as A and D here, this will skip the preprocessing.
Expand All @@ -100,7 +111,7 @@ def sample(
ts: Times at which samples are collected. Includes time for x0.
x0: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
Note : If a thermox.ProcessedDriftMatrix instance is used as input,
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
must be the transformed drift matrix, A_y, given by thermox.preprocess,
not thermox.utils.preprocess_drift_matrix.
b: Drift displacement vector.
Expand Down