diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index c7cb2632a1..9810e14178 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -10,5 +10,6 @@ import pytensor.link.numba.dispatch.elemwise import pytensor.link.numba.dispatch.scan import pytensor.link.numba.dispatch.sparse +import pytensor.link.numba.dispatch.slinalg # isort: on diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py new file mode 100644 index 0000000000..ad8065defd --- /dev/null +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -0,0 +1,275 @@ +import ctypes + +import numba +import numpy as np +from numba.core import cgutils, types +from numba.extending import get_cython_function_address, intrinsic, overload +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack, get_blas_kind +from scipy import linalg + +from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch.basic import numba_funcify +from pytensor.tensor.slinalg import SolveTriangular + + +_PTR = ctypes.POINTER + +_dbl = ctypes.c_double +_float = ctypes.c_float +_char = ctypes.c_char +_int = ctypes.c_int + +_ptr_float = _PTR(_float) +_ptr_dbl = _PTR(_dbl) +_ptr_char = _PTR(_char) +_ptr_int = _PTR(_int) + + +@intrinsic +def val_to_dptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.float64)(types.float64) + return sig, impl + + +@intrinsic +def val_to_zptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.complex128)(types.complex128) + return sig, impl + + +@intrinsic +def val_to_sptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.float32)(types.float32) + return sig, impl + + +@intrinsic +def val_to_int_ptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.int32)(types.int32) + return sig, impl + + +@intrinsic +def int_ptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.int32(types.CPointer(types.int32)) + return sig, impl + + +@intrinsic +def dptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.float64(types.CPointer(types.float64)) + return sig, impl + + +@intrinsic +def sptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.float32(types.CPointer(types.float32)) + return sig, impl + + +def _get_float_pointer_for_dtype(blas_dtype): + if blas_dtype in ["s", "c"]: + return _ptr_float + elif blas_dtype in ["d", "z"]: + return _ptr_dbl + + +def _get_underlying_float(dtype): + s_dtype = str(dtype) + out_type = s_dtype + if s_dtype == "complex64": + out_type = "float32" + elif s_dtype == "complex128": + out_type = "float64" + + return np.dtype(out_type) + + +def _get_lapack_ptr_and_ptr_type(dtype, name): + d = get_blas_kind(dtype) + func_name = f"{d}{name}" + float_pointer = _get_float_pointer_for_dtype(d) + lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) + + return lapack_ptr, float_pointer + + +def _check_scipy_linalg_matrix(a, func_name): + """ + Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831 + """ + prefix = "scipy.linalg" + interp = (prefix, func_name) + # Unpack optional type + if isinstance(a, types.Optional): + a = a.type + if not isinstance(a, types.Array): + msg = "%s.%s() only supported for array types" % interp + raise numba.TypingError(msg, highlighting=False) + if a.ndim not in [1, 2]: + msg = "%s.%s() only supported on 1d or 2d arrays, found %s." % ( + interp + (a.ndim,) + ) + raise numba.TypingError(msg, highlighting=False) + if not isinstance(a.dtype, (types.Float, types.Complex)): + msg = "%s.%s() only supported on " "float and complex arrays." % interp + raise numba.TypingError(msg, highlighting=False) + + +class _LAPACK: + """ + Functions to return type signatures for wrapped LAPACK functions. + + Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74 + """ + + def __init__(self): + ensure_lapack() + + @classmethod + def numba_xtrtrs(cls, dtype): + """ + Called by scipy.linalg.solve_triangular + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs") + + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # TRANS + _ptr_int, # DIAG + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + _ptr_int, # INFO + ) + + return functype(lapack_ptr) + + +def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False): + return linalg.solve_triangular( + A, B, trans=trans, lower=lower, unit_diagonal=unit_diagonal + ) + + +@overload(_solve_triangular) +def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): + ensure_lapack() + + _check_scipy_linalg_matrix(A, "solve_triangular") + _check_scipy_linalg_matrix(B, "solve_triangular") + + dtype = A.dtype + if str(dtype).startswith("complex"): + raise ValueError( + "Complex inputs not currently supported by solve_triangular in Numba mode" + ) + + w_type = _get_underlying_float(dtype) + numba_trtrs = _LAPACK().numba_xtrtrs(dtype) + + def impl(A, B, trans=0, lower=False, unit_diagonal=False): + B_is_1d = B.ndim == 1 + + _N = np.int32(A.shape[-1]) + if A.shape[-2] != _N: + raise linalg.LinAlgError("Last 2 dimensions of A must be square") + + if A.shape[0] != B.shape[0]: + raise linalg.LinAlgError("Dimensions of A and B do not conform") + + if B_is_1d: + B_copy = np.asfortranarray(np.expand_dims(B, -1)) + else: + B_copy = _copy_to_fortran_order(B) + + if trans not in [0, 1, 2]: + raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2') + if trans == 0: + transval = ord("N") + elif trans == 1: + transval = ord("T") + else: + transval = ord("C") + + B_NDIM = 1 if B_is_1d else int(B.shape[1]) + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + TRANS = val_to_int_ptr(transval) + DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N")) + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(B_NDIM) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + INFO = val_to_int_ptr(0) + + numba_trtrs( + UPLO, + TRANS, + DIAG, + N, + NRHS, + np.asfortranarray(A).T.view(w_type).ctypes, + LDA, + B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + if B_is_1d: + return B_copy[..., 0] + return B_copy + + return impl + + +@numba_funcify.register(SolveTriangular) +def numba_funcify_SolveTriangular(op, node, **kwargs): + trans = op.trans + lower = op.lower + unit_diagonal = op.unit_diagonal + check_finite = op.check_finite + + @numba_basic.numba_njit(inline="always") + def solve_triangular(a, b): + res = _solve_triangular(a, b, trans, lower, unit_diagonal) + if check_finite: + if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))): + raise ValueError( + "Non-numeric values (nan or inf) returned by solve_triangular" + ) + return res + + return solve_triangular diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 857bd49152..71a2fb7998 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -110,40 +110,6 @@ def test_Solve(A, x, lower, exc): ) -@pytest.mark.parametrize( - "A, x, lower, exc", - [ - ( - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - set_test_value(at.dvector(), rng.random(size=(3,)).astype("float64")), - "sym", - UserWarning, - ), - ], -) -def test_SolveTriangular(A, x, lower, exc): - g = slinalg.SolveTriangular(lower=lower, b_ndim=1)(A, x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - @pytest.mark.parametrize( "x, exc", [ diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py new file mode 100644 index 0000000000..75e016f1e0 --- /dev/null +++ b/tests/link/numba/test_slinalg.py @@ -0,0 +1,104 @@ +import re + +import numpy as np +import pytest + +import pytensor +import pytensor.tensor as pt +from pytensor import config + + +numba = pytest.importorskip("numba") + + +ATOL = 0 if config.floatX.endswith("64") else 1e-6 +RTOL = 1e-7 if config.floatX.endswith("64") else 1e-6 +rng = np.random.default_rng(42849) + + +def transpose_func(x, trans): + if trans == 0: + return x + if trans == 1: + return x.conj().T + if trans == 2: + return x.T + + +@pytest.mark.parametrize( + "b_func, b_size", + [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], + ids=["b_col_vec", "b_matrix", "b_vec"], +) +@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) +@pytest.mark.parametrize("trans", [0, 1, 2], ids=["trans=N", "trans=C", "trans=T"]) +@pytest.mark.parametrize( + "unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"] +) +@pytest.mark.parametrize("complex", [True, False], ids=["complex", "real"]) +@pytest.mark.filterwarnings( + 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' +) +def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex): + if complex: + # TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous, + # why? + pytest.skip("Complex inputs currently not supported to solve_triangular") + + complex_dtype = "complex64" if config.floatX.endswith("32") else "complex128" + dtype = complex_dtype if complex else config.floatX + + A = pt.matrix("A", dtype=dtype) + b = b_func("b", dtype=dtype) + + X = pt.linalg.solve_triangular( + A, b, lower=lower, trans=trans, unit_diagonal=unit_diag + ) + f = pytensor.function([A, b], X, mode="NUMBA") + + A_val = np.random.normal(size=(5, 5)) + b = np.random.normal(size=b_size) + + if complex: + A_val = A_val + np.random.normal(size=(5, 5)) * 1j + b = b + np.random.normal(size=b_size) * 1j + A_sym = A_val @ A_val.conj().T + + A_tri = np.linalg.cholesky(A_sym).astype(dtype) + if unit_diag: + adj_mat = np.ones((5, 5)) + adj_mat[np.diag_indices(5)] = 1 / np.diagonal(A_tri) + A_tri = A_tri * adj_mat + + A_tri = A_tri.astype(dtype) + b = b.astype(dtype) + + if not lower: + A_tri = A_tri.T + + X_np = f(A_tri, b) + np.testing.assert_allclose( + transpose_func(A_tri, trans) @ X_np, b, atol=ATOL, rtol=RTOL + ) + + +@pytest.mark.parametrize("value", [np.nan, np.inf]) +@pytest.mark.filterwarnings( + 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' +) +def test_solve_triangular_raises_on_nan_inf(value): + A = pt.matrix("A") + b = pt.matrix("b") + + X = pt.linalg.solve_triangular(A, b, check_finite=True) + f = pytensor.function([A, b], X, mode="NUMBA") + A_val = np.random.normal(size=(5, 5)) + A_sym = A_val @ A_val.conj().T + + A_tri = np.linalg.cholesky(A_sym).astype(config.floatX) + b = np.full((5, 1), value) + + with pytest.raises( + ValueError, match=re.escape("Non-numeric values (nan or inf) returned ") + ): + f(A_tri, b)