Create a NumPyroNUTS Op #4646

Merged · 1 commit · May 12, 2021
314 changes: 162 additions & 152 deletions pymc3/sampling_jax.py
@@ -3,150 +3,133 @@
import re
import warnings

from collections import defaultdict

xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)])

import aesara.graph.fg
import aesara.tensor as at
import arviz as az
import jax
import numpy as np
import pandas as pd

from aesara.link.jax.jax_dispatch import jax_funcify

import pymc3 as pm
from aesara.compile import SharedVariable
from aesara.graph.basic import Apply, Constant, clone, graph_inputs
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import MergeOptimizer
from aesara.link.jax.dispatch import jax_funcify
from aesara.tensor.type import TensorType

from pymc3 import modelcontext

warnings.warn("This module is experimental.")

# Disable C compilation by default
# aesara.config.cxx = ""
# This will make the JAX Linker the default
# aesara.config.mode = "JAX"

class NumPyroNUTS(Op):
def __init__(
self,
inputs,
outputs,
target_accept=0.8,
draws=1000,
tune=1000,
chains=4,
seed=None,
progress_bar=True,
):
self.draws = draws
self.tune = tune
self.chains = chains
self.target_accept = target_accept
self.progress_bar = progress_bar
self.seed = seed

def sample_tfp_nuts(
draws=1000,
tune=1000,
chains=4,
target_accept=0.8,
random_seed=10,
model=None,
num_tuning_epoch=2,
num_compute_step_size=500,
):
import jax
self.inputs, self.outputs = clone(inputs, outputs, copy_inputs=False)
self.inputs_type = tuple([input.type for input in inputs])
self.outputs_type = tuple([output.type for output in outputs])
self.nin = len(inputs)
self.nout = len(outputs)
self.nshared = len([v for v in inputs if isinstance(v, SharedVariable)])
self.samples_bcast = [self.chains == 1, self.draws == 1]

from tensorflow_probability.substrates import jax as tfp
self.fgraph = FunctionGraph(self.inputs, self.outputs, clone=False)
MergeOptimizer().optimize(self.fgraph)

model = modelcontext(model)
super().__init__()

seed = jax.random.PRNGKey(random_seed)
def make_node(self, *inputs):

fgraph = model.logp.f.maker.fgraph
fns = jax_funcify(fgraph)
logp_fn_jax = fns[0]
# The samples for each variable
outputs = [
TensorType(v.dtype, self.samples_bcast + list(v.broadcastable))() for v in inputs
]

rv_names = [rv.name for rv in model.free_RVs]
init_state = [model.initial_point[rv_name] for rv_name in rv_names]
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
# The leapfrog statistics
outputs += [TensorType("int64", self.samples_bcast)()]

@jax.pmap
def _sample(init_state, seed):
def gen_kernel(step_size):
hmc = tfp.mcmc.NoUTurnSampler(target_log_prob_fn=logp_fn_jax, step_size=step_size)
return tfp.mcmc.DualAveragingStepSizeAdaptation(
hmc, tune // num_tuning_epoch, target_accept_prob=target_accept
)
all_inputs = list(inputs)
if self.nshared > 0:
all_inputs += self.inputs[-self.nshared :]

def trace_fn(_, pkr):
return pkr.new_step_size

def get_tuned_stepsize(samples, step_size):
return step_size[-1] * jax.numpy.std(samples[-num_compute_step_size:])

step_size = jax.tree_map(jax.numpy.ones_like, init_state)
for i in range(num_tuning_epoch - 1):
tuning_hmc = gen_kernel(step_size)
init_samples, tuning_result, kernel_results = tfp.mcmc.sample_chain(
num_results=tune // num_tuning_epoch,
current_state=init_state,
kernel=tuning_hmc,
trace_fn=trace_fn,
return_final_kernel_results=True,
seed=seed,
)
return Apply(self, all_inputs, outputs)
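
make_node prepends (chains, draws) dimensions to each input's type, so every output is the matching stack of posterior samples. A standalone sketch of that type arithmetic (the vector variable is illustrative):

import aesara.tensor as at
from aesara.tensor.type import TensorType

v = at.vector("v")              # one value variable
samples_bcast = [False, False]  # chains > 1, draws > 1
out_type = TensorType(v.dtype, samples_bcast + list(v.broadcastable))
print(out_type.ndim)            # 3: (chains, draws) + the vector's own axis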

step_size = jax.tree_multimap(get_tuned_stepsize, list(init_samples), tuning_result)
init_state = [x[-1] for x in init_samples]

# Run inference
sample_kernel = gen_kernel(step_size)
mcmc_samples, leapfrog_num = tfp.mcmc.sample_chain(
num_results=draws,
num_burnin_steps=tune // num_tuning_epoch,
current_state=init_state,
kernel=sample_kernel,
trace_fn=lambda _, pkr: pkr.inner_results.leapfrogs_taken,
seed=seed,
)
def do_constant_folding(self, *args):
return False

return mcmc_samples, leapfrog_num
def perform(self, node, inputs, outputs):
raise NotImplementedError()
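
perform is deliberately left unimplemented: the Op never runs in Python, it only exists so the JAX linker can dispatch on it via the registration below. The same pattern on a toy Op, as a hedged sketch (ToyExp and its JAX function are illustrative, not part of this PR; assumes an Aesara version with the JAX linker, as used here):

import jax.numpy as jnp

import aesara
import aesara.tensor as at

from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.link.jax.dispatch import jax_funcify


class ToyExp(Op):
    def make_node(self, x):
        x = at.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, outputs):
        # Only the JAX implementation below is ever used.
        raise NotImplementedError()


@jax_funcify.register(ToyExp)
def jax_funcify_ToyExp(op, node=None, **kwargs):
    def toy_exp(x):
        return jnp.exp(x)

    return toy_exp


x = at.vector("x")
f = aesara.function([x], ToyExp()(x), mode="JAX")  # routed through toy_exp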

print("Compiling...")
tic2 = pd.Timestamp.now()
map_seed = jax.random.split(seed, chains)
mcmc_samples, leapfrog_num = _sample(init_state_batched, map_seed)

# map_seed = jax.random.split(seed, chains)
# mcmc_samples = _sample(init_state_batched, map_seed)
# tic4 = pd.Timestamp.now()
# print("Sampling time = ", tic4 - tic3)

posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}

az_trace = az.from_dict(posterior=posterior)
tic3 = pd.Timestamp.now()
print("Compilation + sampling time = ", tic3 - tic2)
return az_trace # , leapfrog_num, tic3 - tic2


def sample_numpyro_nuts(
draws=1000,
tune=1000,
chains=4,
target_accept=0.8,
random_seed=10,
model=None,
progress_bar=True,
keep_untransformed=False,
):
@jax_funcify.register(NumPyroNUTS)
def jax_funcify_NumPyroNUTS(op, node, **kwargs):
from numpyro.infer import MCMC, NUTS

from pymc3 import modelcontext
draws = op.draws
tune = op.tune
chains = op.chains
target_accept = op.target_accept
progress_bar = op.progress_bar
seed = op.seed

# Compile the "inner" log-likelihood function. This will have extra shared
# variable inputs as the last arguments
logp_fn = jax_funcify(op.fgraph, **kwargs)

if isinstance(logp_fn, (list, tuple)):
# This handles the new JAX backend, which always returns a tuple
logp_fn = logp_fn[0]

def _sample(*inputs):

if op.nshared > 0:
current_state = inputs[: -op.nshared]
shared_inputs = tuple(op.fgraph.inputs[-op.nshared :])
else:
current_state = inputs
shared_inputs = ()

def log_fn_wrap(x):
res = logp_fn(
*(
x
# We manually obtain the shared values and add them
# as arguments to our compiled "inner" function
+ tuple(
v.get_value(borrow=True, return_internal_type=True) for v in shared_inputs
)
)
)

model = modelcontext(model)
if isinstance(res, (list, tuple)):
# This handles the new JAX backend, which always returns a tuple
res = res[0]

seed = jax.random.PRNGKey(random_seed)
return -res
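
log_fn_wrap closes over the compiled inner function and appends the current shared-variable values as trailing arguments. How those values are fetched, as a standalone sketch (the shared variable w is illustrative):

import numpy as np

import aesara

w = aesara.shared(np.array([1.0, 2.0]), name="w")

# Same call used above: borrow the container's value without copying.
w_val = w.get_value(borrow=True, return_internal_type=True)
print(w_val)  # [1. 2.]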

fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
fns = jax_funcify(fgraph)
logp_fn_jax = fns[0]

rv_names = [rv.name for rv in model.free_RVs]
init_state = [model.initial_point[rv_name] for rv_name in rv_names]
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

@jax.jit
def _sample(current_state, seed):
step_size = jax.tree_map(jax.numpy.ones_like, init_state)
nuts_kernel = NUTS(
potential_fn=lambda x: -logp_fn_jax(*x),
# model=model,
potential_fn=log_fn_wrap,
target_accept_prob=target_accept,
adapt_step_size=True,
adapt_mass_matrix=True,
@@ -166,60 +149,87 @@ def _sample(current_state, seed):
pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
samples = pmap_numpyro.get_samples(group_by_chain=True)
leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
return samples, leapfrogs_taken

print("Compiling...")
tic2 = pd.Timestamp.now()
map_seed = jax.random.split(seed, chains)
mcmc_samples, leapfrogs_taken = _sample(init_state_batched, map_seed)
# map_seed = jax.random.split(seed, chains)
# mcmc_samples = _sample(init_state_batched, map_seed)
# tic4 = pd.Timestamp.now()
# print("Sampling time = ", tic4 - tic3)
return tuple(samples) + (leapfrogs_taken,)

posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
tic3 = pd.Timestamp.now()
posterior = _transform_samples(posterior, model, keep_untransformed=keep_untransformed)
tic4 = pd.Timestamp.now()
return _sample
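
The returned _sample is what ultimately drives NumPyro. Stripped of the graph plumbing, the underlying NumPyro call looks roughly like this (the standard-normal potential is illustrative, not from the PR):

import jax
import jax.numpy as jnp

from numpyro.infer import MCMC, NUTS


def potential_fn(z):
    return 0.5 * jnp.sum(z ** 2)  # -log p(z) of a standard normal


kernel = NUTS(potential_fn=potential_fn, target_accept_prob=0.8)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1, progress_bar=False)
mcmc.run(jax.random.PRNGKey(0), init_params=jnp.zeros(3), extra_fields=("num_steps",))

samples = mcmc.get_samples(group_by_chain=True)
leapfrogs_taken = mcmc.get_extra_fields(group_by_chain=True)["num_steps"]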

az_trace = az.from_dict(posterior=posterior)
print("Compilation + sampling time = ", tic3 - tic2)
print("Transformation time = ", tic4 - tic3)

return az_trace # , leapfrogs_taken, tic3 - tic2
def sample_numpyro_nuts(
draws=1000,
tune=1000,
chains=4,
target_accept=0.8,
random_seed=10,
model=None,
progress_bar=True,
keep_untransformed=False,
):
model = modelcontext(model)

seed = jax.random.PRNGKey(random_seed)

def _transform_samples(samples, model, keep_untransformed=False):
rv_names = [rv.name for rv in model.value_vars]
init_state = [model.initial_point[rv_name] for rv_name in rv_names]
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
init_state_batched_at = [at.as_tensor(v) for v in init_state_batched]
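
The initial point is replicated once per chain along a new leading axis before being wrapped as Aesara tensors. The batching step on its own, with illustrative shapes:

import numpy as np

import jax

chains = 4
init_state = [np.zeros(3), np.array(0.5)]  # e.g. one vector RV, one scalar RV
init_state_batched = jax.tree_map(
    lambda x: np.repeat(x[None, ...], chains, axis=0), init_state
)
print([v.shape for v in init_state_batched])  # [(4, 3), (4,)]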

# Find out which RVs we need to compute:
free_rv_names = {x.name for x in model.free_RVs}
unobserved_names = {x.name for x in model.unobserved_RVs}
nuts_inputs = sorted(
[v for v in graph_inputs([model.logpt]) if not isinstance(v, Constant)],
key=lambda x: isinstance(x, SharedVariable),
)
map_seed = jax.random.split(seed, chains)
numpyro_samples = NumPyroNUTS(
nuts_inputs,
[model.logpt],
target_accept=target_accept,
draws=draws,
tune=tune,
chains=chains,
seed=map_seed,
progress_bar=progress_bar,
)(*init_state_batched_at)

# Un-transform the transformed variables in JAX
sample_outputs = []
for i, (value_var, rv_samples) in enumerate(zip(model.value_vars, numpyro_samples[:-1])):
rv = model.values_to_rvs[value_var]
transform = getattr(value_var.tag, "transform", None)
if transform is not None:
untrans_value_var = transform.backward(rv, rv_samples)
untrans_value_var.name = rv.name
sample_outputs.append(untrans_value_var)

if keep_untransformed:
rv_samples.name = value_var.name
sample_outputs.append(rv_samples)
else:
rv_samples.name = rv.name
sample_outputs.append(rv_samples)
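
For a log-transformed variable, the backward transform is just exp applied symbolically; a minimal standalone sketch of what this loop adds to the graph (names are illustrative):

import aesara
import aesara.tensor as at

sigma_log__ = at.vector("sigma_log__")  # sampler output on the real line
sigma = at.exp(sigma_log__)             # backward pass of a log transform
sigma.name = "sigma"

f = aesara.function([sigma_log__], sigma)
print(f([0.0, 1.0]))                    # [1.         2.71828183]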

names_to_compute = unobserved_names - free_rv_names
ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute]
print("Compiling...")

# Create function graph for these:
fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute)
tic1 = pd.Timestamp.now()
_sample = aesara.function(
[],
sample_outputs + [numpyro_samples[-1]],
allow_input_downcast=True,
on_unused_input="ignore",
accept_inplace=True,
mode="JAX",
)
tic2 = pd.Timestamp.now()

# Jaxify, which returns a list of functions, one for each op
jax_fns = jax_funcify(fgraph)
print("Compilation time = ", tic2 - tic1)

# Put together the inputs
inputs = [samples[x.name] for x in model.free_RVs]
print("Sampling...")

for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns):
*mcmc_samples, leapfrogs_taken = _sample()
tic3 = pd.Timestamp.now()

# We need a function taking a single argument to run vmap, while the
# jax_fn takes a list, so:
result = jax.vmap(jax.vmap(cur_jax_fn))(*inputs)
print("Sampling time = ", tic3 - tic2)

# Add to sample dict
samples[cur_op.name] = result
posterior = {k.name: v for k, v in zip(sample_outputs, mcmc_samples)}

# Discard unwanted transformed variables, if desired:
vars_to_keep = set(
pm.util.get_default_varnames(list(samples.keys()), include_transformed=keep_untransformed)
)
samples = {x: y for x, y in samples.items() if x in vars_to_keep}
az_trace = az.from_dict(posterior=posterior)

return samples
return az_trace
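
End-to-end, the new sampler would be used like this; a hedged sketch against the v4 development branch this PR targets (the model is illustrative):

import numpy as np

import pymc3 as pm
import pymc3.sampling_jax

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu, sigma, observed=np.array([0.1, -0.3, 0.2]))

    trace = pm.sampling_jax.sample_numpyro_nuts(draws=500, tune=500, chains=2)

print(trace.posterior)  # ArviZ InferenceData built by az.from_dict above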