diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py
index 860f6af676..cd2f816aab 100644
--- a/pytensor/link/jax/dispatch/scan.py
+++ b/pytensor/link/jax/dispatch/scan.py
@@ -1,159 +1,193 @@
 import jax
 import jax.numpy as jnp
 
-from pytensor.graph.fg import FunctionGraph
 from pytensor.link.jax.dispatch.basic import jax_funcify
 from pytensor.scan.op import Scan
-from pytensor.scan.utils import ScanArgs
 
 
 @jax_funcify.register(Scan)
-def jax_funcify_Scan(op, **kwargs):
-    inner_fg = FunctionGraph(op.inputs, op.outputs)
-    jax_at_inner_func = jax_funcify(inner_fg, **kwargs)
+def jax_funcify_Scan(op: Scan, **kwargs):
+    info = op.info
 
-    def scan(*outer_inputs):
-        scan_args = ScanArgs(
-            list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info
+    if info.as_while:
+        raise NotImplementedError("While Scan cannot yet be converted to JAX")
+
+    if info.n_mit_mot:
+        raise NotImplementedError(
+            "Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
         )
 
-        # `outer_inputs` is a list with the following composite form:
-        # [n_steps]
-        # + outer_in_seqs
-        # + outer_in_mit_mot
-        # + outer_in_mit_sot
-        # + outer_in_sit_sot
-        # + outer_in_shared
-        # + outer_in_nit_sot
-        # + outer_in_non_seqs
-        n_steps = scan_args.n_steps
-        seqs = scan_args.outer_in_seqs
-
-        # TODO: mit_mots
-        mit_mot_in_slices = []
-
-        mit_sot_in_slices = []
-        for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
-            neg_taps = [abs(t) for t in tap if t < 0]
-            pos_taps = [abs(t) for t in tap if t > 0]
-            max_neg = max(neg_taps) if neg_taps else 0
-            max_pos = max(pos_taps) if pos_taps else 0
-            init_slice = seq[: max_neg + max_pos]
-            mit_sot_in_slices.append(init_slice)
-
-        sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]
+    # Optimize inner graph
+    rewriter = op.mode_instance.optimizer
+    rewriter(op.fgraph)
+    scan_inner_func = jax_funcify(op.fgraph, **kwargs)
+
+    def scan(*outer_inputs):
+        # Extract JAX scan inputs
+        outer_inputs = list(outer_inputs)
+        n_steps = outer_inputs[0]  # JAX `length`
+        seqs = op.outer_seqs(outer_inputs)  # JAX `xs`
+
+        mit_sot_init = []
+        for tap, seq in zip(op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs)):
+            init_slice = seq[: abs(min(tap))]
+            mit_sot_init.append(init_slice)
+
+        sit_sot_init = [seq[0] for seq in op.outer_sitsot(outer_inputs)]
 
         init_carry = (
-            mit_mot_in_slices,
-            mit_sot_in_slices,
-            sit_sot_in_slices,
-            scan_args.outer_in_shared,
-            scan_args.outer_in_non_seqs,
-        )
+            mit_sot_init,
+            sit_sot_init,
+            op.outer_shared(outer_inputs),
+            op.outer_non_seqs(outer_inputs),
+        )  # JAX `init`
+
+        def jax_args_to_inner_func_args(carry, x):
+            """Convert JAX scan arguments into the format expected by scan_inner_func.
+
+            scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs)
+            """
 
-        def jax_args_to_inner_scan(op, carry, x):
-            # `carry` contains all inner-output taps, non_seqs, and shared
-            # terms
+            # `carry` contains all inner taps, shared terms, and non_seqs
             (
-                inner_in_mit_mot,
-                inner_in_mit_sot,
-                inner_in_sit_sot,
-                inner_in_shared,
-                inner_in_non_seqs,
+                inner_mit_sot,
+                inner_sit_sot,
+                inner_shared,
+                inner_non_seqs,
             ) = carry
 
-            # `x` contains the in_seqs
-            inner_in_seqs = x
-
-            # `inner_scan_inputs` is a list with the following composite form:
-            # inner_in_seqs
-            # + sum(inner_in_mit_mot, [])
-            # + sum(inner_in_mit_sot, [])
-            # + inner_in_sit_sot
-            # + inner_in_shared
-            # + inner_in_non_seqs
-            inner_in_mit_sot_flatten = []
-            for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
-                inner_in_mit_sot_flatten.extend(array[jnp.array(index)])
-
-            inner_scan_inputs = sum(
-                [
-                    inner_in_seqs,
-                    inner_in_mit_mot,
-                    inner_in_mit_sot_flatten,
-                    inner_in_sit_sot,
-                    inner_in_shared,
-                    inner_in_non_seqs,
-                ],
-                [],
-            )
+            # `x` contains the inner sequences
+            inner_seqs = x
+
+            mit_sot_flatten = []
+            for array, index in zip(inner_mit_sot, op.info.mit_sot_in_slices):
+                mit_sot_flatten.extend(array[jnp.array(index)])
+
+            inner_scan_inputs = [
+                *inner_seqs,
+                *mit_sot_flatten,
+                *inner_sit_sot,
+                *inner_shared,
+                *inner_non_seqs,
+            ]
 
             return inner_scan_inputs
 
-        def inner_scan_outs_to_jax_outs(
-            op,
+        def inner_func_outs_to_jax_outs(
             old_carry,
             inner_scan_outs,
         ):
+            """Convert scan_inner_func outputs into the format expected by JAX scan.
+
+            old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
+            """
             (
-                inner_in_mit_mot,
-                inner_in_mit_sot,
-                inner_in_sit_sot,
-                inner_in_shared,
-                inner_in_non_seqs,
+                inner_mit_sot,
+                inner_sit_sot,
+                inner_shared,
+                inner_non_seqs,
             ) = old_carry
 
-            def update_mit_sot(mit_sot, new_val):
-                return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)
-
-            inner_out_mit_sot = [
-                update_mit_sot(mit_sot, new_val)
-                for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
+            inner_mit_sot_outs = op.inner_mitsot_outs(inner_scan_outs)
+            inner_sit_sot_outs = op.inner_sitsot_outs(inner_scan_outs)
+            inner_nit_sot_outs = op.inner_nitsot_outs(inner_scan_outs)
+            inner_shared_outs = op.inner_shared_outs(inner_scan_outs)
+
+            # Replace the oldest mit_sot tap with the newest value
+            inner_mit_sot_new = [
+                jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0)
+                for old_mit_sot, new_val in zip(
+                    inner_mit_sot,
+                    inner_mit_sot_outs,
+                )
             ]
 
-            # This should contain all inner-output taps, non_seqs, and shared
-            # terms
-            if not inner_in_sit_sot:
-                inner_out_sit_sot = []
-            else:
-                inner_out_sit_sot = inner_scan_outs
+            # Nothing needs to be done with sit_sot
+            inner_sit_sot_new = inner_sit_sot_outs
+
+            inner_shared_new = inner_shared
+            # Replace old shared inputs with new shared outputs
+            inner_shared_new[: len(inner_shared_outs)] = inner_shared_outs
+
             new_carry = (
-                inner_in_mit_mot,
+                inner_mit_sot_new,
+                inner_sit_sot_new,
+                inner_shared_new,
+                inner_non_seqs,
+            )
+
+            # Shared variables and non_seqs are not traced
+            traced_outs = [
+                *inner_mit_sot_outs,
+                *inner_sit_sot_outs,
+                *inner_nit_sot_outs,
+            ]
+
+            return new_carry, traced_outs
+
+        def jax_inner_func(carry, x):
+            inner_args = jax_args_to_inner_func_args(carry, x)
+            inner_scan_outs = list(scan_inner_func(*inner_args))
+            new_carry, traced_outs = inner_func_outs_to_jax_outs(carry, inner_scan_outs)
+            return new_carry, traced_outs
+
+        # Extract PyTensor scan outputs
+        final_carry, traces = jax.lax.scan(
+            jax_inner_func, init_carry, seqs, length=n_steps
+        )
+
+        def get_partial_traces(traces):
+            """Convert JAX scan traces to PyTensor traces.
+
+            We need to:
+                1. Prepend initial states to JAX output traces
+                2. Slice final traces if Scan was instructed to only keep a portion
+            """
+
+            init_states = mit_sot_init + sit_sot_init + [None] * op.info.n_nit_sot
+            buffers = (
+                op.outer_mitsot(outer_inputs)
+                + op.outer_sitsot(outer_inputs)
+                + op.outer_nitsot(outer_inputs)
+            )
+            partial_traces = []
+            for init_state, trace, buffer in zip(init_states, traces, buffers):
+                if init_state is not None:
+                    # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
+                    full_trace = jnp.concatenate(
+                        [jnp.atleast_1d(init_state), jnp.atleast_1d(trace)],
+                        axis=0,
+                    )
+                    buffer_size = buffer.shape[0]
+                else:
+                    # NIT-SOT: Buffer is just the number of entries that should be returned
+                    full_trace = jnp.atleast_1d(trace)
+                    buffer_size = buffer
+
+                partial_trace = full_trace[-buffer_size:]
+                partial_traces.append(partial_trace)
+
+            return partial_traces
+
+        def get_shared_outs(final_carry):
+            """Retrieve the last state of shared_outs from final_carry.
+
+            These outputs cannot be traced in PyTensor Scan.
+            """
+            (
                 inner_out_mit_sot,
                 inner_out_sit_sot,
-                inner_in_shared,
+                inner_out_shared,
                 inner_in_non_seqs,
-            )
+            ) = final_carry
 
-            return new_carry
+            shared_outs = inner_out_shared[: info.n_shared_outs]
+            return list(shared_outs)
 
-        def jax_inner_func(carry, x):
-            inner_args = jax_args_to_inner_scan(op, carry, x)
-            inner_scan_outs = list(jax_at_inner_func(*inner_args))
-            new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
-            return new_carry, inner_scan_outs
-
-        _, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
-
-        # We need to prepend the initial values so that the JAX output will
-        # match the raw `Scan` `Op` output and, thus, work with a downstream
-        # `Subtensor` `Op` introduced by the `scan` helper function.
-        def append_scan_out(scan_in_part, scan_out_part):
-            return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)
-
-        if scan_args.outer_in_mit_sot:
-            scan_out_final = [
-                append_scan_out(init, out)
-                for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
-            ]
-        elif scan_args.outer_in_sit_sot:
-            scan_out_final = [
-                append_scan_out(init, out)
-                for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
-            ]
+        scan_outs_final = get_partial_traces(traces) + get_shared_outs(final_carry)
 
-        if len(scan_out_final) == 1:
-            scan_out_final = scan_out_final[0]
-        return scan_out_final
+        if len(scan_outs_final) == 1:
+            scan_outs_final = scan_outs_final[0]
+        return scan_outs_final
 
     return scan
diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py
index a249453609..588bb9e538 100644
--- a/tests/link/jax/test_basic.py
+++ b/tests/link/jax/test_basic.py
@@ -5,15 +5,13 @@
 import pytest
 
 from pytensor.compile.function import function
-from pytensor.compile.mode import Mode
+from pytensor.compile.mode import get_mode
 from pytensor.compile.sharedvalue import SharedVariable, shared
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op, get_test_value
-from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.ifelse import ifelse
-from pytensor.link.jax import JAXLinker
 from pytensor.raise_op import assert_op
 from pytensor.tensor.type import dscalar, scalar, vector
 
@@ -27,12 +25,9 @@ def set_pytensor_flags():
 
 jax = pytest.importorskip("jax")
 
-jax_mode = Mode(
-    JAXLinker(), RewriteDatabaseQuery(include=["jax"], exclude=["cxx_only", "BlasOpt"])
-)
-py_mode = Mode(
-    "py", RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
-)
+# We assume that the JAX mode includes all the rewrites needed to transpile graphs to JAX
+jax_mode = get_mode("JAX")
+py_mode = get_mode("FAST_COMPILE")
 
 
 def compare_jax_and_py(
@@ -40,6 +35,8 @@ def compare_jax_and_py(
     test_inputs: Iterable,
     assert_fn: Optional[Callable] = None,
     must_be_device_array: bool = True,
+    jax_mode=jax_mode,
+    py_mode=py_mode,
 ):
     """Function to compare python graph output and jax compiled output for testing equality
 
@@ -87,7 +84,7 @@ def compare_jax_and_py(
     else:
         assert_fn(jax_res, py_res)
 
-    return jax_res
+    return pytensor_jax_fn, jax_res
 
 
 def test_jax_FunctionGraph_once():
diff --git a/tests/link/jax/test_scan.py b/tests/link/jax/test_scan.py
index b944aa951d..058a6b23c1 100644
--- a/tests/link/jax/test_scan.py
+++ b/tests/link/jax/test_scan.py
@@ -1,24 +1,202 @@
+import re
+
 import numpy as np
 import pytest
-from packaging.version import parse as version_parse
 
 import pytensor.tensor as at
+from pytensor import function, shared
+from pytensor.compile import get_mode
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
+from pytensor.scan import until
 from pytensor.scan.basic import scan
+from pytensor.scan.op import Scan
+from pytensor.tensor import random
 from pytensor.tensor.math import gammaln, log
-from pytensor.tensor.type import ivector, lscalar, scalar
+from pytensor.tensor.type import lscalar, scalar, vector
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
 jax = pytest.importorskip("jax")
 
 
-@pytest.mark.xfail(
-    version_parse(jax.__version__) >= version_parse("0.2.12"),
-    reason="Omnistaging cannot be disabled",
-)
-def test_jax_scan_multiple_output():
+@pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)])
+def test_scan_sit_sot(view):
+    x0 = at.scalar("x0", dtype="float64")
+    xs, _ = scan(
+        lambda xtm1: xtm1 + 1,
+        outputs_info=[x0],
+        n_steps=10,
+    )
+    if view:
+        xs = xs[view]
+    fg = FunctionGraph([x0], [xs])
+    test_input_vals = [np.e]
+    compare_jax_and_py(fg, test_input_vals)
+
+
+@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
+def test_scan_mit_sot(view):
+    x0 = at.vector("x0", dtype="float64", shape=(3,))
+    xs, _ = scan(
+        lambda xtm3, xtm1: xtm3 + xtm1 + 1,
+        outputs_info=[{"initial": x0, "taps": [-3, -1]}],
+        n_steps=10,
+    )
+    if view:
+        xs = xs[view]
+    fg = FunctionGraph([x0], [xs])
+    test_input_vals = [np.full((3,), np.e)]
+    compare_jax_and_py(fg, test_input_vals)
+
+
+@pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)])
+@pytest.mark.parametrize("view_y", [None, (-1,), slice(-4, -1, None)])
+def test_scan_multiple_mit_sot(view_x, view_y):
+    x0 = at.vector("x0", dtype="float64", shape=(3,))
+    y0 = at.vector("y0", dtype="float64", shape=(4,))
+
+    def step(xtm3, xtm1, ytm4, ytm2):
+        return xtm3 + ytm4 + 1, xtm1 + ytm2 + 2
+
+    [xs, ys], _ = scan(
+        fn=step,
+        outputs_info=[
+            {"initial": x0, "taps": [-3, -1]},
+            {"initial": y0, "taps": [-4, -2]},
+        ],
+        n_steps=10,
+    )
+    if view_x:
+        xs = xs[view_x]
+    if view_y:
+        ys = ys[view_y]
+
+    fg = FunctionGraph([x0, y0], [xs, ys])
+    test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)]
+    compare_jax_and_py(fg, test_input_vals)
+
+
+@pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)])
+def test_scan_nit_sot(view):
+    rng = np.random.default_rng(seed=49)
+
+    xs = at.vector("x0", dtype="float64", shape=(10,))
+
+    ys, _ = scan(
+        lambda x: at.exp(x),
+        outputs_info=[None],
+        sequences=[xs],
+    )
+    if view:
+        ys = ys[view]
+    fg = FunctionGraph([xs], [ys])
+    test_input_vals = [rng.normal(size=10)]
+    # We need to remove pushout rewrites, or the whole scan would just be
+    # converted to an Elemwise on xs
+    jax_fn, _ = compare_jax_and_py(
+        fg, test_input_vals, jax_mode=get_mode("JAX").excluding("scan_pushout")
+    )
+    scan_nodes = [
+        node for node in jax_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
+    ]
+    assert len(scan_nodes) == 1
+
+
+@pytest.mark.xfail(raises=NotImplementedError)
+def test_scan_mit_mot():
+    xs = at.vector("xs", shape=(10,))
+    ys, _ = scan(
+        lambda xtm2, xtm1: (xtm2 + xtm1),
+        outputs_info=[{"initial": xs, "taps": [-2, -1]}],
+        n_steps=10,
+    )
+    grads_wrt_xs = at.grad(ys.sum(), wrt=xs)
+    fg = FunctionGraph([xs], [grads_wrt_xs])
+    compare_jax_and_py(fg, [np.arange(10)])
+
+
+def test_scan_update():
+    sh_static = shared(np.array(0.0), name="sh_static")
+    sh_update = shared(np.array(1.0), name="sh_update")
+
+    xs, update = scan(
+        lambda sh_static, sh_update: (
+            sh_static + sh_update,
+            {sh_update: sh_update * 2},
+        ),
+        outputs_info=[None],
+        non_sequences=[sh_static, sh_update],
+        strict=True,
+        n_steps=7,
+    )
+
+    jax_fn = function([], xs, updates=update, mode="JAX")
+    np.testing.assert_array_equal(jax_fn(), np.array([1, 2, 4, 8, 16, 32, 64]) + 0.0)
+
+    sh_static.set_value(1.0)
+    np.testing.assert_array_equal(
+        jax_fn(), np.array([128, 256, 512, 1024, 2048, 4096, 8192]) + 1.0
+    )
+
+    sh_static.set_value(2.0)
+    sh_update.set_value(1.0)
+    np.testing.assert_array_equal(jax_fn(), np.array([1, 2, 4, 8, 16, 32, 64]) + 2.0)
+
+
+def test_scan_rng_update():
+    rng = shared(np.random.default_rng(190), name="rng")
+
+    def update_fn(rng):
+        new_rng, x = random.normal(rng=rng).owner.outputs
+        return x, {rng: new_rng}
+
+    xs, update = scan(
+        update_fn,
+        outputs_info=[None],
+        non_sequences=[rng],
+        strict=True,
+        n_steps=10,
+    )
+
+    # Without updates
+    with pytest.warns(
+        UserWarning,
+        match=re.escape("[rng] will not be used in the compiled JAX graph"),
+    ):
+        jax_fn = function([], [xs], updates=None, mode="JAX")
+
+    res1, res2 = jax_fn(), jax_fn()
+    assert np.unique(res1).size == 10
+    assert np.unique(res2).size == 10
+    np.testing.assert_array_equal(res1, res2)
+
+    # With updates
+    with pytest.warns(
+        UserWarning,
+        match=re.escape("[rng] will not be used in the compiled JAX graph"),
+    ):
+        jax_fn = function([], [xs], updates=update, mode="JAX")
+
+    res1, res2 = jax_fn(), jax_fn()
+    assert np.unique(res1).size == 10
+    assert np.unique(res2).size == 10
+    assert np.all(np.not_equal(res1, res2))
+
+
+@pytest.mark.xfail(raises=NotImplementedError)
+def test_scan_while():
+    xs, _ = scan(
+        lambda x: (x + 1, until(x < 10)),
+        outputs_info=[at.zeros(())],
+        n_steps=100,
+    )
+
+    fg = FunctionGraph([], [xs])
+    compare_jax_and_py(fg, [])
+
+
+def test_scan_SEIR():
     """Test a scan implementation of a SEIR model.
 
     SEIR model definition:
@@ -38,8 +216,8 @@ def binom_log_prob(n, p, value):
         return binomln(n, value) + value * log(p) + (n - value) * log(1 - p)
 
     # sequences
-    at_C = ivector("C_t")
-    at_D = ivector("D_t")
+    at_C = vector("C_t", dtype="int32", shape=(8,))
+    at_D = vector("D_t", dtype="int32", shape=(8,))
     # outputs_info (initial conditions)
     st0 = lscalar("s_t0")
     et0 = lscalar("e_t0")
@@ -108,11 +286,7 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
     compare_jax_and_py(out_fg, test_input_vals)
 
 
-@pytest.mark.xfail(
-    version_parse(jax.__version__) >= version_parse("0.2.12"),
-    reason="Omnistaging cannot be disabled",
-)
-def test_jax_scan_tap_output():
+def test_scan_mitsot_with_nonseq():
     a_at = scalar("a")
 
     def input_step_fn(y_tm1, y_tm3, a):
diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py
index fdab0f5962..dbe755c592 100644
--- a/tests/link/jax/test_slinalg.py
+++ b/tests/link/jax/test_slinalg.py
@@ -33,7 +33,7 @@ def test_jax_basic():
         np.tile(np.arange(10), (10, 1)).astype(config.floatX),
         np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
     ]
-    (jax_res,) = compare_jax_and_py(out_fg, test_input_vals)
+    _, [jax_res] = compare_jax_and_py(out_fg, test_input_vals)
 
     # Confirm that the `Subtensor` slice operations are correct
     assert jax_res.shape == (5, 3)
diff --git a/tests/link/jax/test_subtensor.py b/tests/link/jax/test_subtensor.py
index 04d6a93cfc..3c60914a2d 100644
--- a/tests/link/jax/test_subtensor.py
+++ b/tests/link/jax/test_subtensor.py
@@ -71,13 +71,15 @@ def test_jax_Subtensor_dynamic():
 
 def test_jax_Subtensor_boolean_mask():
     """JAX does not support resizing arrays with boolean masks."""
-    x_at = at.arange(-5, 5)
+    x_at = at.vector("x", dtype="float64")
     out_at = x_at[x_at < 0]
     assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor)
 
+    out_fg = FunctionGraph([x_at], [out_at])
+
+    x_at_test = np.arange(-5, 5)
     with pytest.raises(NotImplementedError, match="resizing arrays with boolean"):
-        out_fg = FunctionGraph([], [out_at])
-        compare_jax_and_py(out_fg, [])
+        compare_jax_and_py(out_fg, [x_at_test])
 
 
 def test_jax_Subtensor_boolean_mask_reexpressible():
diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py
index 56a23908dc..3f061dbe31 100644
--- a/tests/link/jax/test_tensor_basic.py
+++ b/tests/link/jax/test_tensor_basic.py
@@ -15,7 +15,7 @@ def test_jax_Alloc():
     x = at.alloc(0.0, 2, 3)
     x_fg = FunctionGraph([], [x])
 
-    (jax_res,) = compare_jax_and_py(x_fg, [])
+    _, [jax_res] = compare_jax_and_py(x_fg, [])
 
     assert jax_res.shape == (2, 3)
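
Illustration of the MIT-SOT carry convention that jax_funcify_Scan builds above,
reduced to standalone JAX: the carry is a fixed-size rolling buffer, each step
gathers the taps with array[jnp.array(index)] (as in jax_args_to_inner_func_args)
and appends the newest value with jnp.concatenate (as in inner_func_outs_to_jax_outs).
A minimal sketch, assuming a single state with taps [-3, -1] and a hypothetical
inner function:

import jax
import jax.numpy as jnp

taps = jnp.array([-3, -1])

def step(window, _):
    # Hypothetical inner function: x(t) = x(t-3) + x(t-1) + 1
    new_val = window[taps].sum() + 1.0
    # Drop the oldest entry and append the newest (rolling-buffer update)
    new_window = jnp.concatenate([window[1:], new_val[None]], axis=0)
    return new_window, new_val

init = jnp.full((3,), jnp.e)  # same initial buffer as test_scan_mit_sot
final_window, trace = jax.lax.scan(step, init, None, length=10)
# As in get_partial_traces, the initial states are prepended to the trace
# before the result is sliced down to the requested buffer size
full_trace = jnp.concatenate([init, trace], axis=0)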
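
End-to-end usage sketch (assuming only the public API exercised by the tests
above), mirroring test_scan_sit_sot: once this dispatch is registered, a Scan
graph compiles to JAX simply by requesting mode="JAX":

import numpy as np
import pytensor.tensor as at
from pytensor import function
from pytensor.scan.basic import scan

x0 = at.scalar("x0", dtype="float64")
xs, _ = scan(lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=10)

# Compilation is routed through jax_funcify_Scan via the JAX mode
jax_fn = function([x0], xs, mode="JAX")
np.testing.assert_allclose(jax_fn(np.e), np.e + np.arange(1, 11))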