Reorganized JAX link folder structure #913

Merged
merged 1 commit on Jul 15, 2024
18 changes: 10 additions & 8 deletions pytensor/link/jax/dispatch/__init__.py
@@ -2,18 +2,20 @@
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify

# Load dispatch specializations
import pytensor.link.jax.dispatch.scalar
import pytensor.link.jax.dispatch.tensor_basic
import pytensor.link.jax.dispatch.subtensor
import pytensor.link.jax.dispatch.shape
import pytensor.link.jax.dispatch.blas
import pytensor.link.jax.dispatch.blockwise
import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.extra_ops
import pytensor.link.jax.dispatch.math
import pytensor.link.jax.dispatch.nlinalg
import pytensor.link.jax.dispatch.slinalg
import pytensor.link.jax.dispatch.random
import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.scalar
import pytensor.link.jax.dispatch.scan
import pytensor.link.jax.dispatch.sparse
import pytensor.link.jax.dispatch.blockwise
import pytensor.link.jax.dispatch.shape
import pytensor.link.jax.dispatch.slinalg
import pytensor.link.jax.dispatch.sort
import pytensor.link.jax.dispatch.sparse
import pytensor.link.jax.dispatch.subtensor
import pytensor.link.jax.dispatch.tensor_basic

# isort: on
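
For readers skimming this file: each of these imports registers converters as a side effect, and the PR adds the new blas and math modules to the list while sorting it alphabetically. A minimal, self-contained sketch of that registration pattern, assuming a functools.singledispatch-style registry analogous to jax_funcify (the names to_jax and HypotheticalDot are illustrative, not part of PyTensor):

from functools import singledispatch

import jax.numpy as jnp


@singledispatch
def to_jax(op, **kwargs):
    # Fallback when no specialization has been registered for this Op type.
    raise NotImplementedError(f"No JAX conversion for {type(op).__name__}")


class HypotheticalDot:
    """Stand-in for a PyTensor Op, used only for illustration."""


@to_jax.register(HypotheticalDot)
def to_jax_dot(op, **kwargs):
    def dot(x, y):
        return jnp.dot(x, y)

    return dot


# Importing a dispatch module (as __init__.py does above) simply executes the
# @register decorators, populating the registry as a side effect.
dot_fn = to_jax(HypotheticalDot())
print(dot_fn(jnp.ones((2, 3)), jnp.ones((3, 4))).shape)  # (2, 4)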
14 changes: 14 additions & 0 deletions pytensor/link/jax/dispatch/blas.py
@@ -0,0 +1,14 @@
import jax.numpy as jnp

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blas import BatchedDot


@jax_funcify.register(BatchedDot)
def jax_funcify_BatchedDot(op, **kwargs):
def batched_dot(a, b):
if a.shape[0] != b.shape[0]:
raise TypeError("Shapes must match along the first dimension of BatchedDot")
return jnp.matmul(a, b)

return batched_dot
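
A quick usage sketch of the converter above; only the shape check and jnp.matmul come from the diff, while the example arrays are illustrative:

import jax.numpy as jnp


def batched_dot(a, b):
    # Mirrors the registered converter: leading (batch) dimensions must agree,
    # then each pair of matrices is multiplied.
    if a.shape[0] != b.shape[0]:
        raise TypeError("Shapes must match along the first dimension of BatchedDot")
    return jnp.matmul(a, b)


a = jnp.ones((10, 5, 3))
b = jnp.ones((10, 3, 2))
print(batched_dot(a, b).shape)       # (10, 5, 2)
batched_dot(jnp.ones((9, 5, 3)), b)  # raises TypeError: batch sizes differ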
60 changes: 60 additions & 0 deletions pytensor/link/jax/dispatch/math.py
@@ -0,0 +1,60 @@
import jax.numpy as jnp
import numpy as np

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.math import Argmax, Dot, Max


@jax_funcify.register(Dot)
def jax_funcify_Dot(op, **kwargs):
def dot(x, y):
return jnp.dot(x, y)

return dot


@jax_funcify.register(Max)
def jax_funcify_Max(op, **kwargs):
axis = op.axis

def max(x):
max_res = jnp.max(x, axis)

return max_res

return max


@jax_funcify.register(Argmax)
def jax_funcify_Argmax(op, **kwargs):
axis = op.axis

def argmax(x):
if axis is None:
axes = tuple(range(x.ndim))

Codecov / codecov/patch: added line #L34 (pytensor/link/jax/dispatch/math.py#L34) was not covered by tests
else:
axes = tuple(int(ax) for ax in axis)

# NumPy does not support multiple axes for argmax; this is a
# work-around
keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
# Not-reduced axes in front
transposed_x = jnp.transpose(
x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
)
kept_shape = transposed_x.shape[: len(keep_axes)]
reduced_shape = transposed_x.shape[len(keep_axes) :]

# Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
# Otherwise reshape would complain citing float arg
new_shape = (
*kept_shape,
np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
)
reshaped_x = transposed_x.reshape(tuple(new_shape))

max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")

return max_idx_res

return argmax
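
Since the multi-axis argmax workaround is the least obvious part of the new math.py, here is a standalone sketch of the transpose-and-flatten trick; the example array and axes below are illustrative:

import jax.numpy as jnp
import numpy as np

# Move the reduced axes to the back, flatten them into a single axis, then
# take an ordinary single-axis argmax over that flattened block.
x = jnp.arange(24.0).reshape(2, 3, 4)
axes = (0, 2)

keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
transposed = jnp.transpose(
    x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
)
kept_shape = transposed.shape[: len(keep_axes)]
flat = transposed.reshape((*kept_shape, -1))

# Each entry indexes into the flattened (2 * 4,) block of reduced axes.
print(jnp.argmax(flat, axis=-1))  # [7 7 7]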
68 changes: 0 additions & 68 deletions pytensor/link/jax/dispatch/nlinalg.py
@@ -1,9 +1,6 @@
import jax.numpy as jnp
import numpy as np

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Argmax, Dot, Max
from pytensor.tensor.nlinalg import (
SVD,
Det,
@@ -80,14 +77,6 @@ def qr_full(x, mode=mode):
return qr_full


@jax_funcify.register(Dot)
def jax_funcify_Dot(op, **kwargs):
def dot(x, y):
return jnp.dot(x, y)

return dot


@jax_funcify.register(MatrixPinv)
def jax_funcify_Pinv(op, **kwargs):
def pinv(x):
@@ -96,66 +85,9 @@ def pinv(x):
return pinv


@jax_funcify.register(BatchedDot)
def jax_funcify_BatchedDot(op, **kwargs):
def batched_dot(a, b):
if a.shape[0] != b.shape[0]:
raise TypeError("Shapes must match in the 0-th dimension")
return jnp.matmul(a, b)

return batched_dot


@jax_funcify.register(KroneckerProduct)
def jax_funcify_KroneckerProduct(op, **kwargs):
def _kron(x, y):
return jnp.kron(x, y)

return _kron


@jax_funcify.register(Max)
def jax_funcify_Max(op, **kwargs):
axis = op.axis

def max(x):
max_res = jnp.max(x, axis)

return max_res

return max


@jax_funcify.register(Argmax)
def jax_funcify_Argmax(op, **kwargs):
axis = op.axis

def argmax(x):
if axis is None:
axes = tuple(range(x.ndim))
else:
axes = tuple(int(ax) for ax in axis)

# NumPy does not support multiple axes for argmax; this is a
# work-around
keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
# Not-reduced axes in front
transposed_x = jnp.transpose(
x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
)
kept_shape = transposed_x.shape[: len(keep_axes)]
reduced_shape = transposed_x.shape[len(keep_axes) :]

# Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
# Otherwise reshape would complain citing float arg
new_shape = (
*kept_shape,
np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
)
reshaped_x = transposed_x.reshape(tuple(new_shape))

max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")

return max_idx_res

return argmax
36 changes: 36 additions & 0 deletions tests/link/jax/test_blas.py
@@ -0,0 +1,36 @@
import numpy as np
import pytest

from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker
from pytensor.tensor import blas as pt_blas
from pytensor.tensor.type import tensor3
from tests.link.jax.test_basic import compare_jax_and_py


def test_jax_BatchedDot():
# tensor3 . tensor3
a = tensor3("a")
a.tag.test_value = (
np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
)
b = tensor3("b")
b.tag.test_value = (
np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
)
out = pt_blas.BatchedDot()(a, b)
fgraph = FunctionGraph([a, b], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

# A dimension mismatch should raise a TypeError for compatibility
inputs = [get_test_value(a)[:-1], get_test_value(b)]
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
with pytest.raises(TypeError):
pytensor_jax_fn(*inputs)
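
For context beyond the unit test, a minimal sketch of exercising the new converter from user code; it assumes the public batched_dot helper and the registered "JAX" compilation mode, neither of which is touched by this PR:

import numpy as np

import pytensor
import pytensor.tensor as pt

a = pt.tensor3("a")
b = pt.tensor3("b")
out = pt.batched_dot(a, b)

# mode="JAX" routes each Op through the jax_funcify registry, so BatchedDot
# is lowered by the new pytensor/link/jax/dispatch/blas.py converter.
fn = pytensor.function([a, b], out, mode="JAX")

a_val = np.ones((10, 5, 3), dtype=pytensor.config.floatX)
b_val = np.ones((10, 3, 2), dtype=pytensor.config.floatX)
print(fn(a_val, b_val).shape)  # (10, 5, 2)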
52 changes: 52 additions & 0 deletions tests/link/jax/test_math.py
@@ -0,0 +1,52 @@
import numpy as np
import pytest

from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor.math import Argmax, Max, maximum
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.type import dvector, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py


jax = pytest.importorskip("jax")


def test_jax_max_and_argmax():
# Test that a single output of a multi-output `Op` can be used as input to
# another `Op`
x = dvector()
mx = Max([0])(x)
amx = Argmax([0])(x)
out = mx * amx
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(out_fg, [np.r_[1, 2]])


def test_dot():
y = vector("y")
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
x = vector("x")
x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
A = matrix("A")
A.tag.test_value = np.empty((2, 2), dtype=config.floatX)
alpha = scalar("alpha")
alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
beta = scalar("beta")
beta.tag.test_value = np.array(5.0, dtype=config.floatX)

# This should be converted into a `Gemv` `Op` when the non-JAX compatible
# optimizations are turned on; however, when using JAX mode, it should
# leave the expression alone.
out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

out = maximum(y, x)
fgraph = FunctionGraph([y, x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

out = pt_max(y)
fgraph = FunctionGraph([y], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
71 changes: 1 addition & 70 deletions tests/link/jax/test_nlinalg.py
@@ -2,46 +2,16 @@
import pytest

from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker
from pytensor.tensor import blas as pt_blas
from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor.math import Argmax, Max, maximum
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector
from pytensor.tensor.type import matrix
from tests.link.jax.test_basic import compare_jax_and_py


jax = pytest.importorskip("jax")


def test_jax_BatchedDot():
# tensor3 . tensor3
a = tensor3("a")
a.tag.test_value = (
np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
)
b = tensor3("b")
b.tag.test_value = (
np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
)
out = pt_blas.BatchedDot()(a, b)
fgraph = FunctionGraph([a, b], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

# A dimension mismatch should raise a TypeError for compatibility
inputs = [get_test_value(a)[:-1], get_test_value(b)]
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
with pytest.raises(TypeError):
pytensor_jax_fn(*inputs)


def test_jax_basic_multiout():
rng = np.random.default_rng(213234)

@@ -79,45 +49,6 @@ def assert_fn(x, y):
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)


def test_jax_max_and_argmax():
# Test that a single output of a multi-output `Op` can be used as input to
# another `Op`
x = dvector()
mx = Max([0])(x)
amx = Argmax([0])(x)
out = mx * amx
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(out_fg, [np.r_[1, 2]])


def test_tensor_basics():
y = vector("y")
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
x = vector("x")
x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
A = matrix("A")
A.tag.test_value = np.empty((2, 2), dtype=config.floatX)
alpha = scalar("alpha")
alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
beta = scalar("beta")
beta.tag.test_value = np.array(5.0, dtype=config.floatX)

# This should be converted into a `Gemv` `Op` when the non-JAX compatible
# optimizations are turned on; however, when using JAX mode, it should
# leave the expression alone.
out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

out = maximum(y, x)
fgraph = FunctionGraph([y, x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

out = pt_max(y)
fgraph = FunctionGraph([y], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])


def test_pinv():
x = matrix("x")
x_inv = pt_nlinalg.pinv(x)