Skip to content

Add numba overload for solve_triangular #423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Sep 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytensor/link/numba/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@
import pytensor.link.numba.dispatch.elemwise
import pytensor.link.numba.dispatch.scan
import pytensor.link.numba.dispatch.sparse
import pytensor.link.numba.dispatch.slinalg

# isort: on
267 changes: 267 additions & 0 deletions pytensor/link/numba/dispatch/slinalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
import ctypes

import numba
import numpy as np
import scipy
from numba.core import cgutils, types
from numba.extending import get_cython_function_address, intrinsic, overload
from numba.np.linalg import _blas_kinds, _copy_to_fortran_order, ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor.slinalg import SolveTriangular


_PTR = ctypes.POINTER

_dbl = ctypes.c_double
_float = ctypes.c_float
_char = ctypes.c_char
_int = ctypes.c_int

_ptr_float = _PTR(_float)
_ptr_dbl = _PTR(_dbl)
_ptr_char = _PTR(_char)
_ptr_int = _PTR(_int)


@intrinsic
def val_to_dptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr

sig = types.CPointer(types.float64)(types.float64)
return sig, impl


@intrinsic
def val_to_zptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr

sig = types.CPointer(types.complex128)(types.complex128)
return sig, impl


@intrinsic
def val_to_sptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr

sig = types.CPointer(types.float32)(types.float32)
return sig, impl


@intrinsic
def val_to_int_ptr(typingctx, data):
def impl(context, builder, signature, args):
ptr = cgutils.alloca_once_value(builder, args[0])
return ptr

sig = types.CPointer(types.int32)(types.int32)
return sig, impl


@intrinsic
def int_ptr_to_val(typingctx, data):
def impl(context, builder, signature, args):
val = builder.load(args[0])
return val

sig = types.int32(types.CPointer(types.int32))
return sig, impl


@intrinsic
def dptr_to_val(typingctx, data):
def impl(context, builder, signature, args):
val = builder.load(args[0])
return val

sig = types.float64(types.CPointer(types.float64))
return sig, impl


@intrinsic
def sptr_to_val(typingctx, data):
def impl(context, builder, signature, args):
val = builder.load(args[0])
return val

sig = types.float32(types.CPointer(types.float32))
return sig, impl


def _get_float_pointer_for_dtype(blas_dtype):
if blas_dtype in ["s", "c"]:
return _ptr_float
elif blas_dtype in ["d", "z"]:
return _ptr_dbl


def _get_underlying_float(dtype):
s_dtype = str(dtype)
out_type = s_dtype
if s_dtype == "complex64":
out_type = "float32"
elif s_dtype == "complex128":
out_type = "float64"

return np.dtype(out_type)


def _get_addr_and_float_pointer(dtype, name):
d = _blas_kinds[dtype]
func_name = f"{d}{name}"
float_pointer = _get_float_pointer_for_dtype(d)
addr = get_cython_function_address("scipy.linalg.cython_lapack", func_name)

return addr, float_pointer


def _check_scipy_linalg_matrix(a, func_name):
"""
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
"""
prefix = "scipy.linalg"
interp = (prefix, func_name)
# Unpack optional type
if isinstance(a, types.Optional):
a = a.type
if not isinstance(a, types.Array):
msg = "%s.%s() only supported for array types" % interp
raise numba.TypingError(msg, highlighting=False)
if not a.ndim == 2:
msg = "%s.%s() only supported on 2-D arrays." % interp
raise numba.TypingError(msg, highlighting=False)
if not isinstance(a.dtype, (types.Float, types.Complex)):
msg = "%s.%s() only supported on " "float and complex arrays." % interp
raise numba.TypingError(msg, highlighting=False)


class _LAPACK:
"""
Functions to return type signatures for wrapped LAPACK functions.

Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74
"""

def __init__(self):
ensure_lapack()

@classmethod
def test_blas_kinds(cls, dtype):
return _blas_kinds[dtype]

@classmethod
def numba_xtrtrs(cls, dtype):
"""
Called by scipy.linalg.solve_triangular
"""
d = _blas_kinds[dtype]
func_name = f"{d}trtrs"
float_pointer = _get_float_pointer_for_dtype(d)

addr = get_cython_function_address("scipy.linalg.cython_lapack", func_name)
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO
_ptr_int, # TRANS
_ptr_int, # DIAG
_ptr_int, # N
_ptr_int, # NRHS
float_pointer, # A
_ptr_int, # LDA
float_pointer, # B
_ptr_int, # LDB
_ptr_int,
) # INFO

return functype(addr)


@overload(scipy.linalg.solve_triangular)
def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
ensure_lapack()

_check_scipy_linalg_matrix(A, "solve_triangular")
_check_scipy_linalg_matrix(B, "solve_triangular")

dtype = A.dtype
w_type = _get_underlying_float(dtype)

numba_trtrs = _LAPACK().numba_xtrtrs(dtype)

def impl(A, B, trans=0, lower=False, unit_diagonal=False):
_N = np.int32(A.shape[-1])
if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")

if A.shape[0] != B.shape[0]:
raise linalg.LinAlgError("Dimensions of A and B do not conform")

A_copy = _copy_to_fortran_order(A)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid the copy if it is c-order by flipping transval? I think we could also have a special overload for when trans, lower and unit_diag are literals, and we statically know that A and B are C or Fortran continuous.
I think that would really be only an optimization of the current code though, this here should be fine as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does setting an array to fortran contiguous actually transpose the matrix, or does it just re-order the pointers to the internal flat representation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After testing, we can avoid copying A in all cases.

Re: the other point, do you mean checking the values of trans, lower, and unit_diag inside the wrapper function, then returning a specialized impl function based on their values? Similar to how I'm doing dispatching to real/complex versions here?

B_copy = _copy_to_fortran_order(B)

# if isinstance(trans, str):
# if trans not in ['N', 'C', 'T']:
# raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2')
# transval = ord(trans)

# else:
if trans not in [0, 1, 2]:
raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2')
if trans == 0:
transval = ord("N")
elif trans == 1:
transval = ord("T")
else:
transval = ord("C")

UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
TRANS = val_to_int_ptr(transval)
DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N"))
N = val_to_int_ptr(_N)
NRHS = val_to_int_ptr(B.shape[1])
LDA = val_to_int_ptr(_N)
LDB = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)

numba_trtrs(
UPLO,
TRANS,
DIAG,
N,
NRHS,
A_copy.view(w_type).ctypes,
LDA,
B_copy.view(w_type).ctypes,
LDB,
INFO,
)

return B_copy

return impl


@numba_funcify.register(SolveTriangular)
def numba_funcify_SolveTriangular(op, node, **kwargs):
trans = op.trans
lower = op.lower
unit_diagonal = op.unit_diagonal
check_finite = op.check_finite

@numba_basic.numba_njit(inline="always")
def solve_triangular(a, b):
res = scipy.linalg.solve_triangular(a, b, trans, lower, unit_diagonal)
if check_finite:
if np.any(np.isinf(res)):
raise ValueError
return res

return solve_triangular
22 changes: 22 additions & 0 deletions tests/link/numba/test_slinalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor import config


def test_solve_triangular():
A = pt.matrix("A")
b = pt.matrix("b")

X = pt.linalg.solve_triangular(A, b, lower=True)
f = pytensor.function([A, b], X, mode="NUMBA")

A_val = np.random.normal(size=(5, 5)).astype(config.floatX)
A_sym = A_val @ A_val.T
A_tri = np.linalg.cholesky(A_sym)

b = np.random.normal(size=(5, 1)).astype(config.floatX)

X_np = f(A_tri, b)
np.testing.assert_allclose(A_tri @ X_np, b)