-
Notifications
You must be signed in to change notification settings - Fork 133
Add numba overload for solve_triangular
#423
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
4348da7
c9f5f4f
08de8fa
a323a1d
268f583
3837a28
74f3965
3e1c85a
15e23b9
7649d59
c7a1f28
296fec3
a779687
0f4f197
dd8cfed
519b1c5
7ccd3df
42b5b5f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,267 @@ | ||
import ctypes | ||
|
||
import numba | ||
import numpy as np | ||
import scipy | ||
from numba.core import cgutils, types | ||
from numba.extending import get_cython_function_address, intrinsic, overload | ||
from numba.np.linalg import _blas_kinds, _copy_to_fortran_order, ensure_lapack | ||
from scipy import linalg | ||
|
||
from pytensor.link.numba.dispatch import basic as numba_basic | ||
from pytensor.link.numba.dispatch.basic import numba_funcify | ||
from pytensor.tensor.slinalg import SolveTriangular | ||
|
||
|
||
_PTR = ctypes.POINTER | ||
|
||
_dbl = ctypes.c_double | ||
_float = ctypes.c_float | ||
_char = ctypes.c_char | ||
_int = ctypes.c_int | ||
|
||
_ptr_float = _PTR(_float) | ||
_ptr_dbl = _PTR(_dbl) | ||
_ptr_char = _PTR(_char) | ||
_ptr_int = _PTR(_int) | ||
|
||
|
||
@intrinsic
def val_to_dptr(typingctx, data):
    """Intrinsic: store a float64 value in a stack slot and return a pointer to it."""

    def codegen(context, builder, signature, args):
        # The slot is initialised with the incoming value; its address is the
        # result of the intrinsic.
        value = args[0]
        return cgutils.alloca_once_value(builder, value)

    return types.CPointer(types.float64)(types.float64), codegen
|
||
|
||
@intrinsic
def val_to_zptr(typingctx, data):
    """Intrinsic: store a complex128 value in a stack slot and return a pointer to it."""

    def codegen(context, builder, signature, args):
        # The slot is initialised with the incoming value; its address is the
        # result of the intrinsic.
        value = args[0]
        return cgutils.alloca_once_value(builder, value)

    return types.CPointer(types.complex128)(types.complex128), codegen
|
||
|
||
@intrinsic
def val_to_sptr(typingctx, data):
    """Intrinsic: store a float32 value in a stack slot and return a pointer to it."""

    def codegen(context, builder, signature, args):
        # The slot is initialised with the incoming value; its address is the
        # result of the intrinsic.
        value = args[0]
        return cgutils.alloca_once_value(builder, value)

    return types.CPointer(types.float32)(types.float32), codegen
|
||
|
||
@intrinsic
def val_to_int_ptr(typingctx, data):
    """Intrinsic: store an int32 value in a stack slot and return a pointer to it."""

    def codegen(context, builder, signature, args):
        # The slot is initialised with the incoming value; its address is the
        # result of the intrinsic.
        value = args[0]
        return cgutils.alloca_once_value(builder, value)

    return types.CPointer(types.int32)(types.int32), codegen
|
||
|
||
@intrinsic
def int_ptr_to_val(typingctx, data):
    """Intrinsic: load and return the int32 stored behind an int32 pointer."""

    def codegen(context, builder, signature, args):
        pointer = args[0]
        return builder.load(pointer)

    return types.int32(types.CPointer(types.int32)), codegen
|
||
|
||
@intrinsic
def dptr_to_val(typingctx, data):
    """Intrinsic: load and return the float64 stored behind a float64 pointer."""

    def codegen(context, builder, signature, args):
        pointer = args[0]
        return builder.load(pointer)

    return types.float64(types.CPointer(types.float64)), codegen
|
||
|
||
@intrinsic
def sptr_to_val(typingctx, data):
    """Intrinsic: load and return the float32 stored behind a float32 pointer."""

    def codegen(context, builder, signature, args):
        pointer = args[0]
        return builder.load(pointer)

    return types.float32(types.CPointer(types.float32)), codegen
|
||
|
||
def _get_float_pointer_for_dtype(blas_dtype): | ||
if blas_dtype in ["s", "c"]: | ||
return _ptr_float | ||
elif blas_dtype in ["d", "z"]: | ||
return _ptr_dbl | ||
|
||
|
||
def _get_underlying_float(dtype): | ||
s_dtype = str(dtype) | ||
out_type = s_dtype | ||
if s_dtype == "complex64": | ||
out_type = "float32" | ||
elif s_dtype == "complex128": | ||
out_type = "float64" | ||
|
||
return np.dtype(out_type) | ||
|
||
|
||
def _get_addr_and_float_pointer(dtype, name):
    """
    Look up a LAPACK routine in ``scipy.linalg.cython_lapack``.

    The routine name is the kind prefix for ``dtype`` (taken from numba's
    ``_blas_kinds``) followed by ``name``.

    Returns
    -------
    tuple
        ``(address, float_pointer)``: the function address and the ctypes
        pointer type to use for the routine's array arguments.
    """
    kind = _blas_kinds[dtype]
    pointer_type = _get_float_pointer_for_dtype(kind)
    address = get_cython_function_address(
        "scipy.linalg.cython_lapack", f"{kind}{name}"
    )
    return address, pointer_type
|
||
|
||
def _check_scipy_linalg_matrix(a, func_name):
    """
    Check that the numba type ``a`` is a 2-D float or complex array.

    Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831

    Parameters
    ----------
    a : numba type
        Type of the argument being validated. An ``Optional`` is unwrapped
        and its inner type is checked.
    func_name : str
        Name of the ``scipy.linalg`` function, used in the error messages.

    Raises
    ------
    numba.TypingError
        If ``a`` is not an array type, is not 2-D, or does not have a float
        or complex dtype.
    """
    names = ("scipy.linalg", func_name)

    # An Optional wraps the real type; validate what is inside it.
    array_type = a.type if isinstance(a, types.Optional) else a

    if not isinstance(array_type, types.Array):
        raise numba.TypingError(
            "%s.%s() only supported for array types" % names, highlighting=False
        )
    if array_type.ndim != 2:
        raise numba.TypingError(
            "%s.%s() only supported on 2-D arrays." % names, highlighting=False
        )
    if not isinstance(array_type.dtype, (types.Float, types.Complex)):
        raise numba.TypingError(
            "%s.%s() only supported on float and complex arrays." % names,
            highlighting=False,
        )
|
||
|
||
class _LAPACK:
    """
    Factory for ctypes wrappers around LAPACK routines exposed by SciPy.

    Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74
    """

    def __init__(self):
        ensure_lapack()

    @classmethod
    def test_blas_kinds(cls, dtype):
        """Return the BLAS/LAPACK kind prefix that numba associates with ``dtype``."""
        return _blas_kinds[dtype]

    @classmethod
    def numba_xtrtrs(cls, dtype):
        """
        Build a callable for the ``?trtrs`` routine matching ``dtype``.

        Called by scipy.linalg.solve_triangular.

        Every argument is passed by pointer. Array arguments use the pointer
        type of the real component of ``dtype``.
        """
        address, real_ptr = _get_addr_and_float_pointer(dtype, "trtrs")

        # NOTE(review): UPLO, TRANS and DIAG are declared as int pointers and
        # the caller passes character codes through them. Presumably the
        # routine only reads the first byte of each; confirm this holds on
        # every supported platform.
        signature = ctypes.CFUNCTYPE(
            None,  # no return value; status is reported through INFO
            _ptr_int,  # UPLO
            _ptr_int,  # TRANS
            _ptr_int,  # DIAG
            _ptr_int,  # N
            _ptr_int,  # NRHS
            real_ptr,  # A
            _ptr_int,  # LDA
            real_ptr,  # B
            _ptr_int,  # LDB
            _ptr_int,  # INFO
        )
        return signature(address)
|
||
|
||
@overload(scipy.linalg.solve_triangular)
def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
    """
    Numba overload of ``scipy.linalg.solve_triangular`` using LAPACK ``?trtrs``.

    Typing-time requirements: ``A`` and ``B`` must both be 2-D float or
    complex arrays with the same dtype.

    Parameters (of the compiled implementation)
    -------------------------------------------
    A : 2-D array
        Square triangular matrix.
    B : 2-D array
        Right-hand side with as many rows as ``A``.
    trans : int
        0, 1 or 2, mapped to LAPACK "N", "T" and "C". String values are not
        supported by this overload.
    lower : bool
        Use the lower triangle of ``A`` when true, the upper one otherwise.
    unit_diagonal : bool
        Tell LAPACK to treat the diagonal of ``A`` as all ones.

    Returns
    -------
    2-D array
        A Fortran-ordered copy of ``B`` overwritten with the solution.

    Raises
    ------
    numba.TypingError
        At compile time, if the inputs are not 2-D float/complex arrays or
        their dtypes differ.
    LinAlgError
        If ``A`` is not square, the shapes of ``A`` and ``B`` do not conform,
        or LAPACK reports that ``A`` is singular.
    ValueError
        If ``trans`` is not 0, 1 or 2, or LAPACK reports an illegal argument.
    """
    ensure_lapack()

    _check_scipy_linalg_matrix(A, "solve_triangular")
    _check_scipy_linalg_matrix(B, "solve_triangular")

    a_type = A.type if isinstance(A, types.Optional) else A
    b_type = B.type if isinstance(B, types.Optional) else B

    # Both buffers are passed to the same typed LAPACK routine, so a dtype
    # mismatch would make it misinterpret one of them.
    if a_type.dtype != b_type.dtype:
        raise numba.TypingError(
            "scipy.linalg.solve_triangular() only supports inputs that have "
            "homogeneous dtypes.",
            highlighting=False,
        )

    dtype = a_type.dtype
    w_type = _get_underlying_float(dtype)

    numba_trtrs = _LAPACK().numba_xtrtrs(dtype)

    def impl(A, B, trans=0, lower=False, unit_diagonal=False):
        _N = np.int32(A.shape[-1])
        if A.shape[-2] != _N:
            raise linalg.LinAlgError("Last 2 dimensions of A must be square")

        if A.shape[0] != B.shape[0]:
            raise linalg.LinAlgError("Dimensions of A and B do not conform")

        # NOTE(review): during review it was asked whether the copy of A can
        # be avoided for C-ordered input by flipping the transpose flag, and
        # whether a specialised overload is worthwhile when trans, lower and
        # unit_diagonal are literals and the memory layout is known
        # statically. The author reported that copying A can be avoided in
        # all cases. That is not applied here; TODO confirm and follow up.
        A_copy = _copy_to_fortran_order(A)
        # B is always copied: LAPACK overwrites it with the solution.
        B_copy = _copy_to_fortran_order(B)

        if trans not in [0, 1, 2]:
            raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2')
        if trans == 0:
            transval = ord("N")
        elif trans == 1:
            transval = ord("T")
        else:
            transval = ord("C")

        # LAPACK takes every argument by reference, scalars included.
        UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
        TRANS = val_to_int_ptr(transval)
        DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N"))
        N = val_to_int_ptr(_N)
        NRHS = val_to_int_ptr(B.shape[1])
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        INFO = val_to_int_ptr(0)

        numba_trtrs(
            UPLO,
            TRANS,
            DIAG,
            N,
            NRHS,
            # Complex data is handed over as its underlying real type.
            A_copy.view(w_type).ctypes,
            LDA,
            B_copy.view(w_type).ctypes,
            LDB,
            INFO,
        )

        # A non-zero INFO means B_copy does not hold a solution. Report it,
        # mirroring the exception types SciPy raises for the same cases.
        info = int_ptr_to_val(INFO)
        if info < 0:
            raise ValueError("Illegal value in an argument of internal trtrs")
        if info > 0:
            raise linalg.LinAlgError(
                "Singular matrix: solve_triangular found a zero on the diagonal"
            )

        return B_copy

    return impl
|
||
|
||
@numba_funcify.register(SolveTriangular)
def numba_funcify_SolveTriangular(op, node, **kwargs):
    """
    Build the numba implementation of a ``SolveTriangular`` op.

    The op's ``trans``, ``lower``, ``unit_diagonal`` and ``check_finite``
    settings are captured as compile-time constants of the returned function.

    Returns
    -------
    callable
        A jitted ``solve_triangular(a, b)`` that returns the solution and,
        when ``check_finite`` is set, raises ``ValueError`` if the result
        contains NaN or infinity.
    """
    trans = op.trans
    lower = op.lower
    unit_diagonal = op.unit_diagonal
    check_finite = op.check_finite

    # The numba overload only accepts the integer codes 0, 1 and 2.
    # NOTE(review): presumably the op may also carry the SciPy-style strings
    # "N", "T" and "C"; they are mapped here. Confirm against SolveTriangular.
    if isinstance(trans, str):
        trans_codes = {"N": 0, "T": 1, "C": 2}
        if trans.upper() not in trans_codes:
            raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2')
        trans = trans_codes[trans.upper()]

    @numba_basic.numba_njit(inline="always")
    def solve_triangular(a, b):
        res = scipy.linalg.solve_triangular(a, b, trans, lower, unit_diagonal)
        if check_finite:
            # Reject NaN as well as infinity.
            if not np.all(np.isfinite(res)):
                raise ValueError(
                    "Non-numeric values (nan or inf) returned by solve_triangular"
                )
        return res

    return solve_triangular
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import numpy as np | ||
|
||
import pytensor | ||
import pytensor.tensor as pt | ||
from pytensor import config | ||
|
||
|
||
def test_solve_triangular():
    """Check the NUMBA-mode lower-triangular solve by verifying ``A @ X == b``."""
    A = pt.matrix("A")
    b = pt.matrix("b")

    X = pt.linalg.solve_triangular(A, b, lower=True)
    f = pytensor.function([A, b], X, mode="NUMBA")

    # Seeded generator so a failure is reproducible.
    rng = np.random.default_rng(423)

    A_val = rng.normal(size=(5, 5)).astype(config.floatX)
    # Adding the identity bounds the smallest eigenvalue away from zero, so
    # the Cholesky factor is well conditioned in both float32 and float64.
    A_sym = A_val @ A_val.T + np.eye(5, dtype=config.floatX)
    A_tri = np.linalg.cholesky(A_sym)

    # Distinct name: ``b`` is the symbolic input above.
    b_val = rng.normal(size=(5, 1)).astype(config.floatX)

    X_np = f(A_tri, b_val)

    # The default rtol of assert_allclose (1e-7) is too tight for float32.
    tol = 1e-8 if config.floatX == "float64" else 1e-4
    np.testing.assert_allclose(A_tri @ X_np, b_val, atol=tol, rtol=tol)
Uh oh!
There was an error while loading. Please reload this page.