@@ -153,12 +153,14 @@ def numba_funcify_SparseDot(op, node, **kwargs):
153153 if x_is_sparse and y_is_sparse :
154154 # General spmspm algorithm in CSR format
155155 @numba_basic .numba_njit
156- def _spmspm ( n_row , n_col , x_ptr , x_ind , x_data , y_ptr , y_ind , y_data ):
156+ def _spmspm_csr ( x , y , n_row , n_col ):
157157 # Pass 1
158- x_ind = x_ind .view (np .uint32 )
159- y_ind = y_ind .view (np .uint32 )
160- x_ptr = x_ptr .view (np .uint32 )
161- y_ptr = y_ptr .view (np .uint32 )
158+ x_ind = x .indices .view (np .uint32 )
159+ y_ind = y .indices .view (np .uint32 )
160+ x_ptr = x .indptr .view (np .uint32 )
161+ y_ptr = y .indptr .view (np .uint32 )
162+ x_data = x .data
163+ y_data = y .data
162164
163165 output_nnz = 0
164166 mask = np .full (n_col , - 1 , dtype = np .int32 )
@@ -217,39 +219,54 @@ def _spmspm(n_row, n_col, x_ptr, x_ind, x_data, y_ptr, y_ind, y_data):
217219
218220 return z_ptr .view (np .int32 ), z_ind .view (np .int32 ), z_data
219221
220- @numba_basic .numba_njit
221- def spmspm (x , y ):
222- if x_format == "csc" and y_format == "csc" :
223- # Compute the transpose dot, to avoid costly conversion tocsr()
224- x , y = y .T , x .T
225- elif x_format == "csc" :
226- x = x .tocsr ()
227- elif y_format == "csc" :
228- y = y .tocsr ()
229-
230- x_ptr , x_ind , x_data = x .indptr , x .indices , x .data
231- y_ptr , y_ind , y_data = y .indptr , y .indices , y .data
232- n_row , n_col = x .shape [0 ], y .shape [1 ]
233-
234- z_ptr , z_ind , z_data = _spmspm (
235- n_row , n_col , x_ptr , x_ind , x_data , y_ptr , y_ind , y_data
236- )
222+ formats = (x_format , y_format )
223+ if formats == ("csc" , "csc" ):
224+ # In all cases, the output is dense when the op is Dot.
225+ @numba_basic .numba_njit
226+ def spmspm (x , y ):
227+ # Swap inputs
228+ n_row , n_col = x .shape [0 ], y .shape [1 ]
229+ z_ptr , z_ind , z_data = _spmspm_csr (x = y , y = x , n_row = n_col , n_col = n_row )
230+ output = sp .csc_matrix ((z_data , z_ind , z_ptr ), shape = (n_row , n_col ))
231+ if not z_is_sparse :
232+ return output .toarray ()
233+ return output
234+ elif formats == ("csc" , "csr" ):
237235
238- output = sp .csr_matrix ((z_data , z_ind , z_ptr ), shape = (n_row , n_col ))
236+ @numba_basic .numba_njit
237+ def spmspm (x , y ):
238+ # Convert csr to csc and swap
239+ n_row , n_col = x .shape [0 ], y .shape [1 ]
240+ z_ptr , z_ind , z_data = _spmspm_csr (
241+ x = y .tocsc (), y = x , n_row = n_col , n_col = n_row
242+ )
243+ output = sp .csc_matrix ((z_data , z_ind , z_ptr ), shape = (n_row , n_col ))
244+ if not z_is_sparse :
245+ return output .toarray ()
246+ return output
247+ elif formats == ("csr" , "csc" ):
239248
240- if x_format == "csc" and y_format == "csc" :
241- # We computed the transposed dot in csr, if we transpose the result we get csc
242- output = output .T
249+ @numba_basic .numba_njit
250+ def spmspm (x , y ):
251+ # Convert csc to csr, no swap
252+ n_row , n_col = x .shape [0 ], y .shape [1 ]
253+ z_ptr , z_ind , z_data = _spmspm_csr (
254+ x = x , y = y .tocsr (), n_row = n_row , n_col = n_col
255+ )
256+ output = sp .csr_matrix ((z_data , z_ind , z_ptr ), shape = (n_row , n_col ))
257+ if not z_is_sparse :
258+ return output .toarray ()
259+ return output
260+ else :
243261
244- # Dot returns a dense result even in spMspM
245- if not z_is_sparse :
246- return output .toarray ()
247- # StructuredDot returns in the format of 'x'
248- elif x_format == "csc" and y_format == "csr" :
249- # This is the only case we can't escape a `tocsc()` call
250- return output .tocsc ()
251- else :
252- # Output already in the desired format
262+ @numba_basic .numba_njit
263+ def spmspm (x , y ):
264+ # No conversion, no swap
265+ n_row , n_col = x .shape [0 ], y .shape [1 ]
266+ z_ptr , z_ind , z_data = _spmspm_csr (x = x , y = y , n_row = n_row , n_col = n_col )
267+ output = sp .csr_matrix ((z_data , z_ind , z_ptr ), shape = (n_row , n_col ))
268+ if not z_is_sparse :
269+ return output .toarray ()
253270 return output
254271
255272 return spmspm , cache_key
0 commit comments