From 10875501e99695e39f63f111e834eae7f4e6804f Mon Sep 17 00:00:00 2001
From: HangenYuu
Date: Wed, 10 Jul 2024 16:57:07 +0800
Subject: [PATCH] Reorganized JAX link folder structure

---
 pytensor/link/jax/dispatch/__init__.py | 18 ++++---
 pytensor/link/jax/dispatch/blas.py     | 14 +++++
 pytensor/link/jax/dispatch/math.py     | 60 ++++++++++++++++++++++
 pytensor/link/jax/dispatch/nlinalg.py  | 68 ------------------------
 tests/link/jax/test_blas.py            | 36 +++++++++++++
 tests/link/jax/test_math.py            | 52 +++++++++++++++++++
 tests/link/jax/test_nlinalg.py         | 71 +-------------------------
 7 files changed, 173 insertions(+), 146 deletions(-)
 create mode 100644 pytensor/link/jax/dispatch/blas.py
 create mode 100644 pytensor/link/jax/dispatch/math.py
 create mode 100644 tests/link/jax/test_blas.py
 create mode 100644 tests/link/jax/test_math.py

diff --git a/pytensor/link/jax/dispatch/__init__.py b/pytensor/link/jax/dispatch/__init__.py
index 7e27988cdf..1d8ae33104 100644
--- a/pytensor/link/jax/dispatch/__init__.py
+++ b/pytensor/link/jax/dispatch/__init__.py
@@ -2,18 +2,20 @@
 from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
 
 # Load dispatch specializations
-import pytensor.link.jax.dispatch.scalar
-import pytensor.link.jax.dispatch.tensor_basic
-import pytensor.link.jax.dispatch.subtensor
-import pytensor.link.jax.dispatch.shape
+import pytensor.link.jax.dispatch.blas
+import pytensor.link.jax.dispatch.blockwise
+import pytensor.link.jax.dispatch.elemwise
 import pytensor.link.jax.dispatch.extra_ops
+import pytensor.link.jax.dispatch.math
 import pytensor.link.jax.dispatch.nlinalg
-import pytensor.link.jax.dispatch.slinalg
 import pytensor.link.jax.dispatch.random
-import pytensor.link.jax.dispatch.elemwise
+import pytensor.link.jax.dispatch.scalar
 import pytensor.link.jax.dispatch.scan
-import pytensor.link.jax.dispatch.sparse
-import pytensor.link.jax.dispatch.blockwise
+import pytensor.link.jax.dispatch.shape
+import pytensor.link.jax.dispatch.slinalg
 import pytensor.link.jax.dispatch.sort
+import pytensor.link.jax.dispatch.sparse
+import pytensor.link.jax.dispatch.subtensor
+import pytensor.link.jax.dispatch.tensor_basic
 
 # isort: on
diff --git a/pytensor/link/jax/dispatch/blas.py b/pytensor/link/jax/dispatch/blas.py
new file mode 100644
index 0000000000..a0d0faeabb
--- /dev/null
+++ b/pytensor/link/jax/dispatch/blas.py
@@ -0,0 +1,14 @@
+import jax.numpy as jnp
+
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.blas import BatchedDot
+
+
+@jax_funcify.register(BatchedDot)
+def jax_funcify_BatchedDot(op, **kwargs):
+    def batched_dot(a, b):
+        if a.shape[0] != b.shape[0]:
+            raise TypeError("Shapes must match along the first dimension of BatchedDot")
+        return jnp.matmul(a, b)
+
+    return batched_dot
diff --git a/pytensor/link/jax/dispatch/math.py b/pytensor/link/jax/dispatch/math.py
new file mode 100644
index 0000000000..9aa6076ee7
--- /dev/null
+++ b/pytensor/link/jax/dispatch/math.py
@@ -0,0 +1,60 @@
+import jax.numpy as jnp
+import numpy as np
+
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.math import Argmax, Dot, Max
+
+
+@jax_funcify.register(Dot)
+def jax_funcify_Dot(op, **kwargs):
+    def dot(x, y):
+        return jnp.dot(x, y)
+
+    return dot
+
+
+@jax_funcify.register(Max)
+def jax_funcify_Max(op, **kwargs):
+    axis = op.axis
+
+    def max(x):
+        max_res = jnp.max(x, axis)
+
+        return max_res
+
+    return max
+
+
+@jax_funcify.register(Argmax)
+def jax_funcify_Argmax(op, **kwargs):
+    axis = op.axis
+
+    def argmax(x):
+        if axis is None:
+            axes = tuple(range(x.ndim))
+        else:
+            axes = tuple(int(ax) for ax in axis)
+
+        # NumPy does not support multiple axes for argmax; this is a
+        # work-around
+        keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
+        # Not-reduced axes in front
+        transposed_x = jnp.transpose(
+            x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
+        )
+        kept_shape = transposed_x.shape[: len(keep_axes)]
+        reduced_shape = transposed_x.shape[len(keep_axes) :]
+
+        # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
+        # Otherwise reshape would complain citing float arg
+        new_shape = (
+            *kept_shape,
+            np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
+        )
+        reshaped_x = transposed_x.reshape(tuple(new_shape))
+
+        max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
+
+        return max_idx_res
+
+    return argmax
diff --git a/pytensor/link/jax/dispatch/nlinalg.py b/pytensor/link/jax/dispatch/nlinalg.py
index 0235f3c5db..8b6fc62f2a 100644
--- a/pytensor/link/jax/dispatch/nlinalg.py
+++ b/pytensor/link/jax/dispatch/nlinalg.py
@@ -1,9 +1,6 @@
 import jax.numpy as jnp
-import numpy as np
 
 from pytensor.link.jax.dispatch import jax_funcify
-from pytensor.tensor.blas import BatchedDot
-from pytensor.tensor.math import Argmax, Dot, Max
 from pytensor.tensor.nlinalg import (
     SVD,
     Det,
@@ -80,14 +77,6 @@ def qr_full(x, mode=mode):
     return qr_full
 
 
-@jax_funcify.register(Dot)
-def jax_funcify_Dot(op, **kwargs):
-    def dot(x, y):
-        return jnp.dot(x, y)
-
-    return dot
-
-
 @jax_funcify.register(MatrixPinv)
 def jax_funcify_Pinv(op, **kwargs):
     def pinv(x):
@@ -96,66 +85,9 @@ def pinv(x):
     return pinv
 
 
-@jax_funcify.register(BatchedDot)
-def jax_funcify_BatchedDot(op, **kwargs):
-    def batched_dot(a, b):
-        if a.shape[0] != b.shape[0]:
-            raise TypeError("Shapes must match in the 0-th dimension")
-        return jnp.matmul(a, b)
-
-    return batched_dot
-
-
 @jax_funcify.register(KroneckerProduct)
 def jax_funcify_KroneckerProduct(op, **kwargs):
     def _kron(x, y):
         return jnp.kron(x, y)
 
     return _kron
-
-
-@jax_funcify.register(Max)
-def jax_funcify_Max(op, **kwargs):
-    axis = op.axis
-
-    def max(x):
-        max_res = jnp.max(x, axis)
-
-        return max_res
-
-    return max
-
-
-@jax_funcify.register(Argmax)
-def jax_funcify_Argmax(op, **kwargs):
-    axis = op.axis
-
-    def argmax(x):
-        if axis is None:
-            axes = tuple(range(x.ndim))
-        else:
-            axes = tuple(int(ax) for ax in axis)
-
-        # NumPy does not support multiple axes for argmax; this is a
-        # work-around
-        keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
-        # Not-reduced axes in front
-        transposed_x = jnp.transpose(
-            x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
-        )
-        kept_shape = transposed_x.shape[: len(keep_axes)]
-        reduced_shape = transposed_x.shape[len(keep_axes) :]
-
-        # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
-        # Otherwise reshape would complain citing float arg
-        new_shape = (
-            *kept_shape,
-            np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
-        )
-        reshaped_x = transposed_x.reshape(tuple(new_shape))
-
-        max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
-
-        return max_idx_res
-
-    return argmax
diff --git a/tests/link/jax/test_blas.py b/tests/link/jax/test_blas.py
new file mode 100644
index 0000000000..fe162d1d45
--- /dev/null
+++ b/tests/link/jax/test_blas.py
@@ -0,0 +1,36 @@
+import numpy as np
+import pytest
+
+from pytensor.compile.function import function
+from pytensor.compile.mode import Mode
+from pytensor.configdefaults import config
+from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import get_test_value +from pytensor.graph.rewriting.db import RewriteDatabaseQuery +from pytensor.link.jax import JAXLinker +from pytensor.tensor import blas as pt_blas +from pytensor.tensor.type import tensor3 +from tests.link.jax.test_basic import compare_jax_and_py + + +def test_jax_BatchedDot(): + # tensor3 . tensor3 + a = tensor3("a") + a.tag.test_value = ( + np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) + ) + b = tensor3("b") + b.tag.test_value = ( + np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) + ) + out = pt_blas.BatchedDot()(a, b) + fgraph = FunctionGraph([a, b], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + # A dimension mismatch should raise a TypeError for compatibility + inputs = [get_test_value(a)[:-1], get_test_value(b)] + opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) + jax_mode = Mode(JAXLinker(), opts) + pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode) + with pytest.raises(TypeError): + pytensor_jax_fn(*inputs) diff --git a/tests/link/jax/test_math.py b/tests/link/jax/test_math.py new file mode 100644 index 0000000000..0a1e91b4da --- /dev/null +++ b/tests/link/jax/test_math.py @@ -0,0 +1,52 @@ +import numpy as np +import pytest + +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import get_test_value +from pytensor.tensor.math import Argmax, Max, maximum +from pytensor.tensor.math import max as pt_max +from pytensor.tensor.type import dvector, matrix, scalar, vector +from tests.link.jax.test_basic import compare_jax_and_py + + +jax = pytest.importorskip("jax") + + +def test_jax_max_and_argmax(): + # Test that a single output of a multi-output `Op` can be used as input to + # another `Op` + x = dvector() + mx = Max([0])(x) + amx = Argmax([0])(x) + out = mx * amx + out_fg = FunctionGraph([x], [out]) + compare_jax_and_py(out_fg, [np.r_[1, 2]]) + + +def test_dot(): + y = vector("y") + y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + x = vector("x") + x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) + A = matrix("A") + A.tag.test_value = np.empty((2, 2), dtype=config.floatX) + alpha = scalar("alpha") + alpha.tag.test_value = np.array(3.0, dtype=config.floatX) + beta = scalar("beta") + beta.tag.test_value = np.array(5.0, dtype=config.floatX) + + # This should be converted into a `Gemv` `Op` when the non-JAX compatible + # optimizations are turned on; however, when using JAX mode, it should + # leave the expression alone. 
+ out = y.dot(alpha * A).dot(x) + beta * y + fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = maximum(y, x) + fgraph = FunctionGraph([y, x], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = pt_max(y) + fgraph = FunctionGraph([y], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 4340b395cb..cd6ca2ac71 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -2,46 +2,16 @@ import pytest from pytensor.compile.function import function -from pytensor.compile.mode import Mode from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import get_test_value -from pytensor.graph.rewriting.db import RewriteDatabaseQuery -from pytensor.link.jax import JAXLinker -from pytensor.tensor import blas as pt_blas from pytensor.tensor import nlinalg as pt_nlinalg -from pytensor.tensor.math import Argmax, Max, maximum -from pytensor.tensor.math import max as pt_max -from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector +from pytensor.tensor.type import matrix from tests.link.jax.test_basic import compare_jax_and_py jax = pytest.importorskip("jax") -def test_jax_BatchedDot(): - # tensor3 . tensor3 - a = tensor3("a") - a.tag.test_value = ( - np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) - ) - b = tensor3("b") - b.tag.test_value = ( - np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) - ) - out = pt_blas.BatchedDot()(a, b) - fgraph = FunctionGraph([a, b], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - # A dimension mismatch should raise a TypeError for compatibility - inputs = [get_test_value(a)[:-1], get_test_value(b)] - opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) - jax_mode = Mode(JAXLinker(), opts) - pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode) - with pytest.raises(TypeError): - pytensor_jax_fn(*inputs) - - def test_jax_basic_multiout(): rng = np.random.default_rng(213234) @@ -79,45 +49,6 @@ def assert_fn(x, y): compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) -def test_jax_max_and_argmax(): - # Test that a single output of a multi-output `Op` can be used as input to - # another `Op` - x = dvector() - mx = Max([0])(x) - amx = Argmax([0])(x) - out = mx * amx - out_fg = FunctionGraph([x], [out]) - compare_jax_and_py(out_fg, [np.r_[1, 2]]) - - -def test_tensor_basics(): - y = vector("y") - y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) - x = vector("x") - x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) - A = matrix("A") - A.tag.test_value = np.empty((2, 2), dtype=config.floatX) - alpha = scalar("alpha") - alpha.tag.test_value = np.array(3.0, dtype=config.floatX) - beta = scalar("beta") - beta.tag.test_value = np.array(5.0, dtype=config.floatX) - - # This should be converted into a `Gemv` `Op` when the non-JAX compatible - # optimizations are turned on; however, when using JAX mode, it should - # leave the expression alone. 
- out = y.dot(alpha * A).dot(x) + beta * y - fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - out = maximum(y, x) - fgraph = FunctionGraph([y, x], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - out = pt_max(y) - fgraph = FunctionGraph([y], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - def test_pinv(): x = matrix("x") x_inv = pt_nlinalg.pinv(x)
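Note on the reorganization: the hunks above are a code move. The BatchedDot, Dot, Max, and Argmax translations keep the same bodies (only the BatchedDot shape-mismatch message is reworded), and they stay registered through the same jax_funcify dispatcher that pytensor/link/jax/dispatch/__init__.py imports. The sketch below is illustrative only and not part of the patch: it shows one way to check end to end that the relocated dispatchers are still picked up when compiling with PyTensor's predefined "JAX" mode. It assumes jax is installed, and the variable names (batched_fn, reduce_fn, a, b, x) are made up for the example.

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor import blas as pt_blas

# BatchedDot's JAX translation now lives in pytensor/link/jax/dispatch/blas.py.
a = pt.tensor3("a")
b = pt.tensor3("b")
batched_fn = pytensor.function([a, b], pt_blas.BatchedDot()(a, b), mode="JAX")
res = batched_fn(np.ones((10, 5, 3), dtype=a.dtype), np.ones((10, 3, 2), dtype=b.dtype))
print(res.shape)  # (10, 5, 2)

# The Max/Argmax translations now live in pytensor/link/jax/dispatch/math.py.
x = pt.matrix("x")
reduce_fn = pytensor.function([x], [pt.max(x, axis=1), pt.argmax(x, axis=1)], mode="JAX")
print(reduce_fn(np.arange(6.0).reshape(2, 3).astype(x.dtype)))
# row-wise max [2., 5.] and argmax [2, 2]

Compiling with mode="JAX" routes the graph through JAXLinker, the same linker path the moved tests exercise via compare_jax_and_py.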