Skip to content

Commit 2d94407

Browse files
Add numba overload for solve_triangular (#423)
* Add numba overload for `solve_triangular` * Overload dummy function instead of scipy.linalg * Add tolerance for float32 tests * Add tolerance for float32 tests * Remove overload test * Allow b to be 1d array Remove test_SolveTriangular from numba\test_nlinalg.py * Allow b to be 1d array Remove test_SolveTriangular from numba\test_nlinalg.py * revert local change to pyproject.toml * add numba importorskip to test_slinalg.py * Test all parameterizations of solve_triangular * Raise when inputs are complex Add informative message to error raised by check_finite=True * simplify check for complex input types * simplify check for complex input types * Rename _get_addr_and_float_pointer to _get_lapack_ptr_and_ptr_type Rename addr to lapack_ptr * Don't copy A matrix in overload func Don't copy B matrix when B is array in overload func
1 parent 071eadd commit 2d94407

File tree

4 files changed

+380
-34
lines changed

4 files changed

+380
-34
lines changed

pytensor/link/numba/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@
1010
import pytensor.link.numba.dispatch.elemwise
1111
import pytensor.link.numba.dispatch.scan
1212
import pytensor.link.numba.dispatch.sparse
13+
import pytensor.link.numba.dispatch.slinalg
1314

1415
# isort: on
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
import ctypes
2+
3+
import numba
4+
import numpy as np
5+
from numba.core import cgutils, types
6+
from numba.extending import get_cython_function_address, intrinsic, overload
7+
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack, get_blas_kind
8+
from scipy import linalg
9+
10+
from pytensor.link.numba.dispatch import basic as numba_basic
11+
from pytensor.link.numba.dispatch.basic import numba_funcify
12+
from pytensor.tensor.slinalg import SolveTriangular
13+
14+
15+
# Short aliases for the ctypes scalar types used in the LAPACK prototypes below.
_PTR = ctypes.POINTER
_dbl, _float, _char, _int = (
    ctypes.c_double,
    ctypes.c_float,
    ctypes.c_char,
    ctypes.c_int,
)

# Pointer-to-scalar types built from the aliases above.
_ptr_float, _ptr_dbl, _ptr_char, _ptr_int = (
    _PTR(_float),
    _PTR(_dbl),
    _PTR(_char),
    _PTR(_int),
)
26+
27+
28+
@intrinsic
def val_to_dptr(typingctx, data):
    """Numba intrinsic: store a float64 value in a fresh slot and return a pointer to it.

    The typing signature is fixed to ``CPointer(float64)(float64)``; ``data`` is
    not inspected.
    """
    sig = types.CPointer(types.float64)(types.float64)

    def codegen(context, builder, signature, args):
        # Allocate a slot initialised with the argument and hand back its address.
        return cgutils.alloca_once_value(builder, args[0])

    return sig, codegen
36+
37+
38+
@intrinsic
def val_to_zptr(typingctx, data):
    """Numba intrinsic: store a complex128 value in a fresh slot and return a pointer to it.

    The typing signature is fixed to ``CPointer(complex128)(complex128)``;
    ``data`` is not inspected.
    """
    sig = types.CPointer(types.complex128)(types.complex128)

    def codegen(context, builder, signature, args):
        # Allocate a slot initialised with the argument and hand back its address.
        return cgutils.alloca_once_value(builder, args[0])

    return sig, codegen
46+
47+
48+
@intrinsic
def val_to_sptr(typingctx, data):
    """Numba intrinsic: store a float32 value in a fresh slot and return a pointer to it.

    The typing signature is fixed to ``CPointer(float32)(float32)``; ``data`` is
    not inspected.
    """
    sig = types.CPointer(types.float32)(types.float32)

    def codegen(context, builder, signature, args):
        # Allocate a slot initialised with the argument and hand back its address.
        return cgutils.alloca_once_value(builder, args[0])

    return sig, codegen
56+
57+
58+
@intrinsic
def val_to_int_ptr(typingctx, data):
    """Numba intrinsic: store an int32 value in a fresh slot and return a pointer to it.

    The typing signature is fixed to ``CPointer(int32)(int32)``; ``data`` is not
    inspected.
    """
    sig = types.CPointer(types.int32)(types.int32)

    def codegen(context, builder, signature, args):
        # Allocate a slot initialised with the argument and hand back its address.
        return cgutils.alloca_once_value(builder, args[0])

    return sig, codegen
66+
67+
68+
@intrinsic
def int_ptr_to_val(typingctx, data):
    """Numba intrinsic: load and return the int32 value an int32 pointer refers to.

    The typing signature is fixed to ``int32(CPointer(int32))``; ``data`` is not
    inspected.
    """
    sig = types.int32(types.CPointer(types.int32))

    def codegen(context, builder, signature, args):
        # Dereference the pointer argument.
        return builder.load(args[0])

    return sig, codegen
76+
77+
78+
@intrinsic
def dptr_to_val(typingctx, data):
    """Numba intrinsic: load and return the float64 value a float64 pointer refers to.

    The typing signature is fixed to ``float64(CPointer(float64))``; ``data`` is
    not inspected.
    """
    sig = types.float64(types.CPointer(types.float64))

    def codegen(context, builder, signature, args):
        # Dereference the pointer argument.
        return builder.load(args[0])

    return sig, codegen
86+
87+
88+
@intrinsic
def sptr_to_val(typingctx, data):
    """Numba intrinsic: load and return the float32 value a float32 pointer refers to.

    The typing signature is fixed to ``float32(CPointer(float32))``; ``data`` is
    not inspected.
    """
    sig = types.float32(types.CPointer(types.float32))

    def codegen(context, builder, signature, args):
        # Dereference the pointer argument.
        return builder.load(args[0])

    return sig, codegen
96+
97+
98+
def _get_float_pointer_for_dtype(blas_dtype):
    """Return the ctypes pointer type matching a BLAS kind character.

    Parameters
    ----------
    blas_dtype : str
        One of ``"s"``, ``"c"`` (mapped to the float pointer type) or ``"d"``,
        ``"z"`` (mapped to the double pointer type).

    Returns
    -------
    The module-level ``_ptr_float`` or ``_ptr_dbl`` pointer type.

    Raises
    ------
    ValueError
        If ``blas_dtype`` is not one of the four kinds above. Previously this
        fell through and returned ``None``, which only failed later, and less
        clearly, when the ctypes prototype was built.
    """
    if blas_dtype in ("s", "c"):
        return _ptr_float
    if blas_dtype in ("d", "z"):
        return _ptr_dbl
    raise ValueError(
        f"Unsupported BLAS kind {blas_dtype!r}; expected one of 's', 'd', 'c' or 'z'"
    )
103+
104+
105+
def _get_underlying_float(dtype):
    """Return the NumPy dtype of the real component behind ``dtype``.

    ``complex64`` maps to ``float32`` and ``complex128`` to ``float64``; any
    other dtype is returned unchanged (as an ``np.dtype``). The lookup is done
    on ``str(dtype)``.
    """
    dtype_name = str(dtype)
    real_counterpart = {"complex64": "float32", "complex128": "float64"}
    return np.dtype(real_counterpart.get(dtype_name, dtype_name))
114+
115+
116+
def _get_lapack_ptr_and_ptr_type(dtype, name):
    """Look up a LAPACK routine in ``scipy.linalg.cython_lapack`` for ``dtype``.

    The routine name is the BLAS kind character for ``dtype`` (from
    ``get_blas_kind``) followed by ``name``.

    Returns
    -------
    tuple
        ``(function_address, float_pointer_type)`` where the pointer type is the
        ctypes pointer type matching the BLAS kind.
    """
    kind = get_blas_kind(dtype)
    routine = f"{kind}{name}"
    pointer_type = _get_float_pointer_for_dtype(kind)
    address = get_cython_function_address("scipy.linalg.cython_lapack", routine)

    return address, pointer_type
123+
124+
125+
def _check_scipy_linalg_matrix(a, func_name):
    """
    Validate, at typing time, that ``a`` is a 1d or 2d float/complex array type.

    Raises ``numba.TypingError`` naming ``scipy.linalg.<func_name>`` otherwise.
    An ``Optional`` type is unwrapped before being checked.

    Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
    """
    qualified = ("scipy.linalg", func_name)

    # Unpack optional type
    array_type = a.type if isinstance(a, types.Optional) else a

    if not isinstance(array_type, types.Array):
        raise numba.TypingError(
            "%s.%s() only supported for array types" % qualified,
            highlighting=False,
        )

    if array_type.ndim not in (1, 2):
        raise numba.TypingError(
            "%s.%s() only supported on 1d or 2d arrays, found %s."
            % (qualified + (array_type.ndim,)),
            highlighting=False,
        )

    if not isinstance(array_type.dtype, (types.Float, types.Complex)):
        raise numba.TypingError(
            "%s.%s() only supported on " "float and complex arrays." % qualified,
            highlighting=False,
        )
145+
146+
147+
class _LAPACK:
    """
    Builders of ctypes callables for wrapped LAPACK functions.

    Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74
    """

    def __init__(self):
        ensure_lapack()

    @classmethod
    def numba_xtrtrs(cls, dtype):
        """
        Return a ctypes callable for the ``?trtrs`` routine matching ``dtype``.

        Called by scipy.linalg.solve_triangular
        """
        lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs")

        # Argument types in LAPACK order; every argument is passed by pointer.
        # Note the character flags are declared here as int pointers.
        argument_types = (
            _ptr_int,  # UPLO
            _ptr_int,  # TRANS
            _ptr_int,  # DIAG
            _ptr_int,  # N
            _ptr_int,  # NRHS
            float_pointer,  # A
            _ptr_int,  # LDA
            float_pointer,  # B
            _ptr_int,  # LDB
            _ptr_int,  # INFO
        )
        prototype = ctypes.CFUNCTYPE(None, *argument_types)

        return prototype(lapack_ptr)
179+
180+
181+
def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False):
    """Solve a triangular system by delegating to ``scipy.linalg.solve_triangular``.

    This plain-Python function is the target that the Numba overload in this
    module is registered against.
    """
    solver_options = {
        "trans": trans,
        "lower": lower,
        "unit_diagonal": unit_diagonal,
    }
    return linalg.solve_triangular(A, B, **solver_options)
185+
186+
187+
@overload(_solve_triangular)
def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
    """Numba overload of ``_solve_triangular`` backed by LAPACK ``?trtrs``.

    Typing-time checks: ``A`` and ``B`` must be 1d/2d float arrays of the same
    dtype; complex inputs are rejected.

    The compiled implementation solves ``op(A) X = B`` on a Fortran-ordered copy
    of ``B`` and returns that copy (squeezed back to 1d when ``B`` was 1d).

    Raises (at run time, inside the compiled function)
    ------
    LinAlgError
        If ``A`` is not square, if the leading dimensions of ``A`` and ``B``
        differ, or if LAPACK reports a singular ``A`` (``INFO > 0``).
    ValueError
        If ``trans`` is not 0, 1 or 2, or if LAPACK reports an illegal argument
        (``INFO < 0``).
    """
    ensure_lapack()

    _check_scipy_linalg_matrix(A, "solve_triangular")
    _check_scipy_linalg_matrix(B, "solve_triangular")

    dtype = A.dtype
    if str(dtype).startswith("complex"):
        raise ValueError(
            "Complex inputs not currently supported by solve_triangular in Numba mode"
        )

    # B's buffer is reinterpreted below using A's dtype before it is handed to
    # LAPACK; with mismatched dtypes that reinterpretation would silently
    # produce garbage, so reject the combination while typing.
    if B.dtype != dtype:
        raise numba.TypingError(
            "scipy.linalg.solve_triangular() requires A and B to have the same "
            "dtype in Numba mode",
            highlighting=False,
        )

    w_type = _get_underlying_float(dtype)
    numba_trtrs = _LAPACK().numba_xtrtrs(dtype)

    def impl(A, B, trans=0, lower=False, unit_diagonal=False):
        B_is_1d = B.ndim == 1

        _N = np.int32(A.shape[-1])
        if A.shape[-2] != _N:
            raise linalg.LinAlgError("Last 2 dimensions of A must be square")

        if A.shape[0] != B.shape[0]:
            raise linalg.LinAlgError("Dimensions of A and B do not conform")

        # LAPACK works on a Fortran-ordered (N, NRHS) right-hand side.
        if B_is_1d:
            B_copy = np.asfortranarray(np.expand_dims(B, -1))
        else:
            B_copy = _copy_to_fortran_order(B)

        if trans not in [0, 1, 2]:
            raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2')
        if trans == 0:
            transval = ord("N")
        elif trans == 1:
            transval = ord("T")
        else:
            transval = ord("C")

        # Number of right-hand sides (columns of B).
        B_NDIM = 1 if B_is_1d else int(B.shape[1])

        # LAPACK requires LDA >= max(1, N) and LDB >= max(1, N); using N
        # directly is an illegal argument for an empty system.
        _LD = _N if _N > 0 else np.int32(1)

        UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
        TRANS = val_to_int_ptr(transval)
        DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N"))
        N = val_to_int_ptr(_N)
        NRHS = val_to_int_ptr(B_NDIM)
        LDA = val_to_int_ptr(_LD)
        LDB = val_to_int_ptr(_LD)
        INFO = val_to_int_ptr(0)

        numba_trtrs(
            UPLO,
            TRANS,
            DIAG,
            N,
            NRHS,
            np.asfortranarray(A).T.view(w_type).ctypes,
            LDA,
            B_copy.view(w_type).ctypes,
            LDB,
            INFO,
        )

        # ?trtrs reports failure only through INFO. When A is singular it
        # returns without solving, so ignoring INFO would hand back B unchanged
        # as if it were the solution.
        info = int_ptr_to_val(INFO)
        if info > 0:
            raise linalg.LinAlgError(
                "Singular matrix: a diagonal element of A is exactly zero"
            )
        if info < 0:
            raise ValueError(
                "Illegal value in an argument of the internal LAPACK trtrs call"
            )

        if B_is_1d:
            return B_copy[..., 0]
        return B_copy

    return impl
256+
257+
258+
@numba_funcify.register(SolveTriangular)
def numba_funcify_SolveTriangular(op, node, **kwargs):
    """Build the Numba implementation of a ``SolveTriangular`` op.

    The op's ``trans``, ``lower``, ``unit_diagonal`` and ``check_finite``
    settings are captured as closure constants of the returned jitted function.

    Returns
    -------
    A jitted ``solve_triangular(a, b)`` that raises ``ValueError`` when
    ``check_finite`` is set and either input contains nan or inf.
    """
    # scipy.linalg.solve_triangular accepts "N"/"T"/"C" as aliases for 0/1/2,
    # whereas the Numba implementation only handles the integer codes.
    # Normalise here, in plain Python, using the same mapping scipy applies.
    trans = {"N": 0, "T": 1, "C": 2}.get(op.trans, op.trans)
    lower = op.lower
    unit_diagonal = op.unit_diagonal
    check_finite = op.check_finite

    @numba_basic.numba_njit(inline="always")
    def solve_triangular(a, b):
        # check_finite is documented by scipy as a check on the *input*
        # matrices, so validate a and b before solving rather than inspecting
        # the result (finite inputs may legitimately overflow in the solve).
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
                raise ValueError(
                    "Non-numeric values (nan or inf) in input A to solve_triangular"
                )
            if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
                raise ValueError(
                    "Non-numeric values (nan or inf) in input b to solve_triangular"
                )

        return _solve_triangular(a, b, trans, lower, unit_diagonal)

    return solve_triangular

tests/link/numba/test_nlinalg.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -110,40 +110,6 @@ def test_Solve(A, x, lower, exc):
110110
)
111111

112112

113-
@pytest.mark.parametrize(
    "A, x, lower, exc",
    [
        (
            set_test_value(
                at.dmatrix(),
                (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
            ),
            set_test_value(at.dvector(), rng.random(size=(3,)).astype("float64")),
            "sym",
            UserWarning,
        ),
    ],
)
def test_SolveTriangular(A, x, lower, exc):
    """Compare Numba and Python results for ``SolveTriangular`` with ``b_ndim=1``.

    When ``exc`` is not None the comparison is expected to emit that warning.
    """
    g = slinalg.SolveTriangular(lower=lower, b_ndim=1)(A, x)

    # The op result may be a single variable or a list of outputs.
    g_fg = FunctionGraph(outputs=g if isinstance(g, list) else [g])

    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
    with cm:
        # Shared variables and constants carry their own values; only the
        # remaining graph inputs need test values supplied.
        test_inputs = [
            i.tag.test_value
            for i in g_fg.inputs
            if not isinstance(i, (SharedVariable, Constant))
        ]
        compare_numba_and_py(g_fg, test_inputs)
145-
146-
147113
@pytest.mark.parametrize(
148114
"x, exc",
149115
[

0 commit comments

Comments
 (0)