Skip to content

Commit b757e7e

Browse files
brandonwillardtwiecki
authored andcommitted
Fix PyTensor-to-Numba type resolution for sparse variables
1 parent bf678ca commit b757e7e

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

pytensor/link/numba/dispatch/sparse.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,20 @@
1313
typeof_impl,
1414
unbox,
1515
)
16+
from numba.np.numpy_support import from_dtype
17+
18+
from pytensor.link.numba.dispatch.basic import get_numba_type
19+
from pytensor.sparse.type import SparseTensorType
1620

1721

1822
class CSMatrixType(types.Type):
1923
"""A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`."""
2024

2125
name: str
22-
instance_class: type
26+
27+
@staticmethod
28+
def instance_class(data, indices, indptr, shape):
29+
raise NotImplementedError()
2330

2431
def __init__(self, dtype):
2532
self.dtype = dtype
@@ -29,6 +36,10 @@ def __init__(self, dtype):
2936
self.shape = types.UniTuple(types.int64, 2)
3037
super().__init__(self.name)
3138

39+
@property
40+
def key(self):
41+
return (self.name, self.dtype)
42+
3243

3344
make_attribute_wrapper(CSMatrixType, "data", "data")
3445
make_attribute_wrapper(CSMatrixType, "indices", "indices")
@@ -161,3 +172,15 @@ def ndim(inst):
161172
return 2
162173

163174
return ndim
175+
176+
177+
@get_numba_type.register(SparseTensorType)
178+
def get_numba_type_SparseType(pytensor_type, **kwargs):
179+
dtype = from_dtype(np.dtype(pytensor_type.dtype))
180+
181+
if pytensor_type.format == "csr":
182+
return CSRMatrixType(dtype)
183+
if pytensor_type.format == "csc":
184+
return CSCMatrixType(dtype)
185+
186+
raise NotImplementedError()

tests/link/numba/test_sparse.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
import numba
22
import numpy as np
3+
import pytest
34
import scipy as sp
45

5-
# Load Numba customizations
6+
# Make sure the Numba customizations are loaded
67
import pytensor.link.numba.dispatch.sparse # noqa: F401
8+
from pytensor import config
9+
from pytensor.sparse import Dot, SparseTensorType
10+
from tests.link.numba.test_basic import compare_numba_and_py
11+
12+
13+
pytestmark = pytest.mark.filterwarnings("error")
714

815

916
def test_sparse_unboxing():
@@ -62,3 +69,17 @@ def test_fn(x):
6269
res = test_fn(x_val)
6370

6471
assert res == 2
72+
73+
74+
def test_sparse_objmode():
75+
76+
x = SparseTensorType("csc", dtype=config.floatX)()
77+
y = SparseTensorType("csc", dtype=config.floatX)()
78+
79+
out = Dot()(x, y)
80+
81+
x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
82+
y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
83+
84+
with pytest.warns(UserWarning):
85+
compare_numba_and_py(((x, y), (out,)), [x_val, y_val])

0 commit comments

Comments
 (0)