Skip to content

Commit 220442a

Browse files
brandonwillardricardoV94
authored andcommitted
Implement a copy method for Numba sparse types
1 parent eb6fc66 commit 220442a

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

pytensor/link/numba/dispatch/sparse.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
import scipy as sp
33
import scipy.sparse
44
from numba.core import cgutils, types
5+
from numba.core.imputils import impl_ret_borrowed
56
from numba.extending import (
67
NativeValue,
78
box,
9+
intrinsic,
810
make_attribute_wrapper,
911
models,
1012
overload,
1113
overload_attribute,
14+
overload_method,
1215
register_model,
1316
typeof_impl,
1417
unbox,
@@ -166,3 +169,38 @@ def ndim(inst):
166169
return 2
167170

168171
return ndim
172+
173+
174+
@intrinsic
175+
def _sparse_copy(typingctx, inst, data, indices, indptr, shape):
176+
def _construct(context, builder, sig, args):
177+
typ = sig.return_type
178+
struct = cgutils.create_struct_proxy(typ)(context, builder)
179+
_, data, indices, indptr, shape = args
180+
struct.data = data
181+
struct.indices = indices
182+
struct.indptr = indptr
183+
struct.shape = shape
184+
return impl_ret_borrowed(
185+
context,
186+
builder,
187+
sig.return_type,
188+
struct._getvalue(),
189+
)
190+
191+
sig = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape)
192+
193+
return sig, _construct
194+
195+
196+
@overload_method(CSMatrixType, "copy")
197+
def overload_sparse_copy(inst):
198+
if not isinstance(inst, CSMatrixType):
199+
return
200+
201+
def copy(inst):
202+
return _sparse_copy(
203+
inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape
204+
)
205+
206+
return copy

tests/link/numba/test_sparse.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,19 @@ def test_fn(x):
7171
assert res == 2
7272

7373

74+
def test_sparse_copy():
75+
@numba.njit
76+
def test_fn(x):
77+
y = x.copy()
78+
return (
79+
y is not x and np.all(x.data == y.data) and np.all(x.indices == y.indices)
80+
)
81+
82+
x_val = sp.sparse.csr_matrix(np.eye(100))
83+
84+
assert test_fn(x_val)
85+
86+
7487
def test_sparse_objmode():
7588
x = SparseTensorType("csc", dtype=config.floatX)()
7689
y = SparseTensorType("csc", dtype=config.floatX)()

0 commit comments

Comments
 (0)