Skip to content

🔄 From Aesara: 1337: Add basic overloads for Numba sparse types #102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type
from pytensor.ifelse import IfElse
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
from pytensor.link.utils import (
compile_function_src,
fgraph_to_python,
unique_name_generator,
)
from pytensor.scalar.basic import ScalarType
from pytensor.scalar.math import Softplus
from pytensor.sparse import SparseTensorType
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
Expand Down Expand Up @@ -105,6 +107,15 @@ def get_numba_type(
dtype = np.dtype(pytensor_type.dtype)
numba_dtype = numba.from_dtype(dtype)
return numba_dtype
elif isinstance(pytensor_type, SparseTensorType):
dtype = pytensor_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype)
if pytensor_type.format == "csr":
return CSRMatrixType(numba_dtype)
if pytensor_type.format == "csc":
return CSCMatrixType(numba_dtype)

raise NotImplementedError()
else:
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")

Expand Down
67 changes: 66 additions & 1 deletion pytensor/link/numba/dispatch/sparse.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import numpy as np
import scipy as sp
import scipy.sparse
from numba.core import cgutils, types
from numba.core.imputils import impl_ret_borrowed
from numba.extending import (
NativeValue,
box,
intrinsic,
make_attribute_wrapper,
models,
overload,
overload_attribute,
overload_method,
register_model,
typeof_impl,
unbox,
Expand All @@ -16,7 +22,10 @@ class CSMatrixType(types.Type):
"""A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`."""

name: str
instance_class: type

@staticmethod
def instance_class(data, indices, indptr, shape):
raise NotImplementedError()

def __init__(self, dtype):
self.dtype = dtype
Expand All @@ -26,6 +35,10 @@ def __init__(self, dtype):
self.shape = types.UniTuple(types.int64, 2)
super().__init__(self.name)

@property
def key(self):
return (self.name, self.dtype)


make_attribute_wrapper(CSMatrixType, "data", "data")
make_attribute_wrapper(CSMatrixType, "indices", "indices")
Expand Down Expand Up @@ -139,3 +152,55 @@ def box_matrix(typ, val, c):
c.pyapi.decref(shape_obj)

return obj


@overload(np.shape)
def overload_sparse_shape(x):
if isinstance(x, CSMatrixType):
return lambda x: x.shape


@overload_attribute(CSMatrixType, "ndim")
def overload_sparse_ndim(inst):
if not isinstance(inst, CSMatrixType):
return

def ndim(inst):
return 2

return ndim


@intrinsic
def _sparse_copy(typingctx, inst, data, indices, indptr, shape):
def _construct(context, builder, sig, args):
typ = sig.return_type
struct = cgutils.create_struct_proxy(typ)(context, builder)
_, data, indices, indptr, shape = args
struct.data = data
struct.indices = indices
struct.indptr = indptr
struct.shape = shape
return impl_ret_borrowed(
context,
builder,
sig.return_type,
struct._getvalue(),
)

sig = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape)

return sig, _construct


@overload_method(CSMatrixType, "copy")
def overload_sparse_copy(inst):
if not isinstance(inst, CSMatrixType):
return

def copy(inst):
return _sparse_copy(
inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape
)

return copy
2 changes: 1 addition & 1 deletion pytensor/link/numba/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def output_filter(self, var: "Variable", out: Any) -> Any:
if not isinstance(var, np.ndarray) and isinstance(
var.type, pytensor.tensor.TensorType
):
return np.asarray(out, dtype=var.type.dtype)
return var.type.filter(out, allow_downcast=True)

return out

Expand Down
4 changes: 1 addition & 3 deletions tests/link/numba/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ def set_test_value(x, v):


def compare_shape_dtype(x, y):
(x,) = x
(y,) = y
return x.shape == y.shape and x.dtype == y.dtype


Expand Down Expand Up @@ -286,7 +284,7 @@ def assert_fn(x, y):
for j, p in zip(numba_res, py_res):
assert_fn(j, p)
else:
assert_fn(numba_res, py_res)
assert_fn(numba_res[0], py_res[0])

return pytensor_numba_fn, numba_res

Expand Down
62 changes: 61 additions & 1 deletion tests/link/numba/test_sparse.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import numba
import numpy as np
import pytest
import scipy as sp

# Load Numba customizations
# Make sure the Numba customizations are loaded
import pytensor.link.numba.dispatch.sparse # noqa: F401
from pytensor import config
from pytensor.sparse import Dot, SparseTensorType
from tests.link.numba.test_basic import compare_numba_and_py


pytestmark = pytest.mark.filterwarnings("error")


def test_sparse_unboxing():
Expand Down Expand Up @@ -38,3 +45,56 @@ def test_boxing(x, y):
assert np.array_equal(res_y_val.indices, y_val.indices)
assert np.array_equal(res_y_val.indptr, y_val.indptr)
assert res_y_val.shape == y_val.shape


def test_sparse_shape():
@numba.njit
def test_fn(x):
return np.shape(x)

x_val = sp.sparse.csr_matrix(np.eye(100))

res = test_fn(x_val)

assert res == (100, 100)


def test_sparse_ndim():
@numba.njit
def test_fn(x):
return x.ndim

x_val = sp.sparse.csr_matrix(np.eye(100))

res = test_fn(x_val)

assert res == 2


def test_sparse_copy():
@numba.njit
def test_fn(x):
y = x.copy()
return (
y is not x and np.all(x.data == y.data) and np.all(x.indices == y.indices)
)

x_val = sp.sparse.csr_matrix(np.eye(100))

assert test_fn(x_val)


def test_sparse_objmode():
x = SparseTensorType("csc", dtype=config.floatX)()
y = SparseTensorType("csc", dtype=config.floatX)()

out = Dot()(x, y)

x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)
y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX)

with pytest.warns(
UserWarning,
match="Numba will use object mode to run SparseDot's perform method",
):
compare_numba_and_py(((x, y), (out,)), [x_val, y_val])