Skip to content

Commit ba51e7d

Browse files
brandonwillardricardoV94
authored andcommitted
Fix PyTensor-to-Numba type resolution for sparse variables
1 parent e31655f commit ba51e7d

File tree

3 files changed

+43
-3
lines changed

3 files changed

+43
-3
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@
2525
from pytensor.graph.fg import FunctionGraph
2626
from pytensor.graph.type import Type
2727
from pytensor.ifelse import IfElse
28+
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
2829
from pytensor.link.utils import (
2930
compile_function_src,
3031
fgraph_to_python,
3132
unique_name_generator,
3233
)
3334
from pytensor.scalar.basic import ScalarType
3435
from pytensor.scalar.math import Softplus
36+
from pytensor.sparse import SparseTensorType
3537
from pytensor.tensor.blas import BatchedDot
3638
from pytensor.tensor.math import Dot
3739
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@@ -105,6 +107,15 @@ def get_numba_type(
105107
dtype = np.dtype(pytensor_type.dtype)
106108
numba_dtype = numba.from_dtype(dtype)
107109
return numba_dtype
110+
elif isinstance(pytensor_type, SparseTensorType):
111+
dtype = pytensor_type.numpy_dtype
112+
numba_dtype = numba.from_dtype(dtype)
113+
if pytensor_type.format == "csr":
114+
return CSRMatrixType(numba_dtype)
115+
if pytensor_type.format == "csc":
116+
return CSCMatrixType(numba_dtype)
117+
118+
raise NotImplementedError()
108119
else:
109120
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
110121

pytensor/link/numba/dispatch/sparse.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ class CSMatrixType(types.Type):
1919
"""A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`."""
2020

2121
name: str
22-
instance_class: type
22+
23+
@staticmethod
24+
def instance_class(data, indices, indptr, shape):
25+
raise NotImplementedError()
2326

2427
def __init__(self, dtype):
2528
self.dtype = dtype
@@ -29,6 +32,10 @@ def __init__(self, dtype):
2932
self.shape = types.UniTuple(types.int64, 2)
3033
super().__init__(self.name)
3134

35+
@property
36+
def key(self):
37+
return (self.name, self.dtype)
38+
3239

3340
make_attribute_wrapper(CSMatrixType, "data", "data")
3441
make_attribute_wrapper(CSMatrixType, "indices", "indices")
@@ -152,7 +159,6 @@ def overload_sparse_shape(x):
152159

153160
@overload_attribute(CSMatrixType, "ndim")
154161
def overload_sparse_ndim(inst):
155-
156162
if not isinstance(inst, CSMatrixType):
157163
return
158164

tests/link/numba/test_sparse.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
import numba
22
import numpy as np
3+
import pytest
34
import scipy as sp
45

5-
# Load Numba customizations
6+
# Make sure the Numba customizations are loaded
67
import pytensor.link.numba.dispatch.sparse # noqa: F401
8+
from pytensor import config
9+
from pytensor.sparse import Dot, SparseTensorType
10+
from tests.link.numba.test_basic import compare_numba_and_py
11+
12+
13+
pytestmark = pytest.mark.filterwarnings("error")
714

815

916
def test_sparse_unboxing():
@@ -62,3 +69,19 @@ def test_fn(x):
6269
res = test_fn(x_val)
6370

6471
assert res == 2
72+
73+
74+
def test_sparse_objmode():
75+
x = SparseTensorType("csc", dtype=config.floatX)()
76+
y = SparseTensorType("csc", dtype=config.floatX)()
77+
78+
out = Dot()(x, y)
79+
80+
x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
81+
y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
82+
83+
with pytest.warns(
84+
UserWarning,
85+
match="Numba will use object mode to run SparseDot's perform method",
86+
):
87+
compare_numba_and_py(((x, y), (out,)), [x_val, y_val])

0 commit comments

Comments
 (0)