Skip to content

Commit ac54ffe

Browse files
committed
Handle sparse matrix-matrix products (spMspM) in a more specialized manner, reducing the number of format conversions
1 parent 6af5a1a commit ac54ffe

File tree

1 file changed

+52
-35
lines changed
  • pytensor/link/numba/dispatch/sparse

1 file changed

+52
-35
lines changed

pytensor/link/numba/dispatch/sparse/math.py

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,14 @@ def numba_funcify_SparseDot(op, node, **kwargs):
153153
if x_is_sparse and y_is_sparse:
154154
# General spmspm algorithm in CSR format
155155
@numba_basic.numba_njit
156-
def _spmspm(n_row, n_col, x_ptr, x_ind, x_data, y_ptr, y_ind, y_data):
156+
def _spmspm_csr(x, y, n_row, n_col):
157157
# Pass 1
158-
x_ind = x_ind.view(np.uint32)
159-
y_ind = y_ind.view(np.uint32)
160-
x_ptr = x_ptr.view(np.uint32)
161-
y_ptr = y_ptr.view(np.uint32)
158+
x_ind = x.indices.view(np.uint32)
159+
y_ind = y.indices.view(np.uint32)
160+
x_ptr = x.indptr.view(np.uint32)
161+
y_ptr = y.indptr.view(np.uint32)
162+
x_data = x.data
163+
y_data = y.data
162164

163165
output_nnz = 0
164166
mask = np.full(n_col, -1, dtype=np.int32)
@@ -217,39 +219,54 @@ def _spmspm(n_row, n_col, x_ptr, x_ind, x_data, y_ptr, y_ind, y_data):
217219

218220
return z_ptr.view(np.int32), z_ind.view(np.int32), z_data
219221

220-
@numba_basic.numba_njit
221-
def spmspm(x, y):
222-
if x_format == "csc" and y_format == "csc":
223-
# Compute the transpose dot, to avoid costly conversion tocsr()
224-
x, y = y.T, x.T
225-
elif x_format == "csc":
226-
x = x.tocsr()
227-
elif y_format == "csc":
228-
y = y.tocsr()
229-
230-
x_ptr, x_ind, x_data = x.indptr, x.indices, x.data
231-
y_ptr, y_ind, y_data = y.indptr, y.indices, y.data
232-
n_row, n_col = x.shape[0], y.shape[1]
233-
234-
z_ptr, z_ind, z_data = _spmspm(
235-
n_row, n_col, x_ptr, x_ind, x_data, y_ptr, y_ind, y_data
236-
)
222+
formats = (x_format, y_format)
223+
if formats == ("csc", "csc"):
224+
# In every format combination below, the result is converted to a dense array when the op is Dot (z_is_sparse is False).
225+
@numba_basic.numba_njit
226+
def spmspm(x, y):
227+
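# Swap inputs: compute the transposed product in CSR to avoid a costly tocsr(); reading that result as CSC gives x @ y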
228+
n_row, n_col = x.shape[0], y.shape[1]
229+
z_ptr, z_ind, z_data = _spmspm_csr(x=y, y=x, n_row=n_col, n_col=n_row)
230+
output = sp.csc_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
231+
if not z_is_sparse:
232+
return output.toarray()
233+
return output
234+
elif formats == ("csc", "csr"):
237235

238-
output = sp.csr_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
236+
@numba_basic.numba_njit
237+
def spmspm(x, y):
238+
# Convert y from csr to csc, then swap inputs (transposed product; result is read back as csc, the format of x)
239+
n_row, n_col = x.shape[0], y.shape[1]
240+
z_ptr, z_ind, z_data = _spmspm_csr(
241+
x=y.tocsc(), y=x, n_row=n_col, n_col=n_row
242+
)
243+
output = sp.csc_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
244+
if not z_is_sparse:
245+
return output.toarray()
246+
return output
247+
elif formats == ("csr", "csc"):
239248

240-
if x_format == "csc" and y_format == "csc":
241-
# We computed the transposed dot in csr, if we transpose the result we get csc
242-
output = output.T
249+
@numba_basic.numba_njit
250+
def spmspm(x, y):
251+
# Convert csc to csr, no swap
252+
n_row, n_col = x.shape[0], y.shape[1]
253+
z_ptr, z_ind, z_data = _spmspm_csr(
254+
x=x, y=y.tocsr(), n_row=n_row, n_col=n_col
255+
)
256+
output = sp.csr_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
257+
if not z_is_sparse:
258+
return output.toarray()
259+
return output
260+
else:
243261

244-
# Dot returns a dense result even in spMspM
245-
if not z_is_sparse:
246-
return output.toarray()
247-
# StructuredDot returns in the format of 'x'
248-
elif x_format == "csc" and y_format == "csr":
249-
# This is the only case we can't escape a `tocsc()` call
250-
return output.tocsc()
251-
else:
252-
# Output already in the desired format
262+
@numba_basic.numba_njit
263+
def spmspm(x, y):
264+
# No conversion, no swap
265+
n_row, n_col = x.shape[0], y.shape[1]
266+
z_ptr, z_ind, z_data = _spmspm_csr(x=x, y=y, n_row=n_row, n_col=n_col)
267+
output = sp.csr_matrix((z_data, z_ind, z_ptr), shape=(n_row, n_col))
268+
if not z_is_sparse:
269+
return output.toarray()
253270
return output
254271

255272
return spmspm, cache_key

0 commit comments

Comments
 (0)