
Conversation

@tomicapretto (Contributor) commented on Jan 29, 2026:

Description

The main contribution of this PR is the implementation of StructuredDotGradCSR and StructuredDotGradCSC in the numba backend.

While working on it, I noticed that the SpSum and SparseFromDense Ops were running in object mode, so I implemented those as well.

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@tomicapretto force-pushed the sparse_gradients_numba branch 2 times, most recently from 190c587 to 4690cde on January 29, 2026 at 13:34
@tomicapretto force-pushed the sparse_gradients_numba branch 2 times, most recently from c96ae8c to 2025883 on January 29, 2026 at 14:56
@tomicapretto force-pushed the sparse_gradients_numba branch 2 times, most recently from 512fb59 to 32099f1 on January 30, 2026 at 03:42
@tomicapretto force-pushed the sparse_gradients_numba branch 2 times, most recently from e7b1261 to a054b5d on January 31, 2026 at 17:51
@tomicapretto force-pushed the sparse_gradients_numba branch from a054b5d to 6af5a1a on January 31, 2026 at 17:54
@tomicapretto marked this pull request as ready for review on January 31, 2026 at 17:56
@tomicapretto (Contributor, Author) commented:

The test that fails is:

FAILED tests/tensor/test_slinalg.py::TestSchur::test_schur_empty - ValueError: negative dimensions not allowed

which is unrelated to this PR.

@jessegrabowski


# Pre-allocate internal containers
data = np.empty(nnz, dtype=matrix.dtype)
indices = np.empty(nnz, dtype=np.uint32)

@jessegrabowski (Member):
Should be int32 no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We used uint32 in other places. I thought that since they're never negative, we could use uint32.
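
For reference (not part of the PR discussion): scipy.sparse itself stores CSR/CSC index arrays with a signed integer dtype, which is part of why int32 comes up. A purely illustrative check:

```python
import scipy.sparse as sp

# Illustrative only: scipy's own CSR matrices use a signed index dtype
# (int32 for small matrices, int64 for very large ones), never uint32.
m = sp.random(5, 5, density=0.4, format="csr", random_state=0)
print(m.indices.dtype, m.indptr.dtype)  # typically: int32 int32
```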


for col_idx in range(size):
    for value_idx in range(x_ptr[col_idx], x_ptr[col_idx + 1]):
        output[value_idx] = np.dot(

Member:
We have to be careful with np.dot. IIRC numba's overload doesn't support integer / mixed dtypes well.

@tomicapretto (Contributor, Author):
Argh, I'm using it because np.sum(x * y) was slower. There are a bunch of tests that pass different data types, and they have all passed, so that's probably OK?

Member:
It's probably fine as long as we're upcasting the inputs to a common dtype in the make_node of Dot?

In the medium term we should consider re-implementing the BLAS calls ourselves.
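
As a hedged sketch of what a hand-rolled replacement could look like for the 1-D slices used above (the helper name `_dot_1d` is hypothetical, not from the PR), a plain loop avoids numba's BLAS-backed np.dot overload and its dtype restrictions:

```python
import numpy as np
from numba import njit

@njit
def _dot_1d(x, y):
    # Naive inner product over 1-D arrays. Unlike numba's np.dot overload,
    # which dispatches to BLAS and expects float/complex inputs, this works
    # for any dtype numba can multiply (the float accumulator upcasts ints).
    acc = 0.0
    for i in range(x.shape[0]):
        acc += x[i] * y[i]
    return acc

# Example with mixed integer/float inputs
print(_dot_1d(np.arange(4), np.ones(4)))  # 6.0
```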

@ricardoV94 (Member) left a comment:

This looks great, I just left some minor comments

axis = op.axis

@numba_basic.numba_njit
def perform(x):

Member:
Does mypy freak out if you type-hint this as SparseArray -> SparseArray? It would make the function clearer. Not required if it causes a headache (type-hinting these overloads often does).

@tomicapretto (Contributor, Author):
I have not tried it, but the SpSum op returns a dense array (see this).

What happens here is that this calls the function I implemented in overload_sum in variable.py. Maybe we could add a note somewhere (per op, or at the top of the file) saying that many (if not all) Ops use overloads written in a separate Python file?

# General spmspm algorithm in CSR format
@numba_basic.numba_njit
def _spmspm(n_row, n_col, x_ptr, x_ind, x_data, y_ptr, y_ind, y_data):
def _spmspm_csr(x, y, n_row, n_col):

Member:
I think it's worth considering a bit of reorganization here for future extensibility. We can make a new sparse/math sub-module and have a sum.py file with each of these inner njit functions defined independently. numba_funcify_SparseDenseMultiply can still live here, but it would be just an input checker and routing to the correct function. I'm thinking about what it will look like in the future to add support for a new sparse type.

The pattern I'm thinking about is what we are doing with linalg, for example QZ: each case is defined separately here, then the actual dispatch is defined here.

@tomicapretto (Contributor, Author):
It sounds good to me. I thought about it a bit before starting to work on this, but I saw the other ops in this module were implemented this way, so I assumed it was for a reason. Maybe I just overthought it and it was simply convenience.
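
A minimal sketch of the routing pattern being proposed (function and module names are illustrative, not taken from the PR or from the linalg code): per-format kernels are defined on their own, and the dispatch layer only inspects the formats and picks a kernel.

```python
# Hypothetical sketch; the real kernels would be njit-compiled and live in
# a dedicated sparse/math sub-module, as suggested above.
def _spmspm_csr_csr(x, y):
    """Sparse @ sparse product, both operands CSR."""
    raise NotImplementedError

def _spmspm_csc_csc(x, y):
    """Sparse @ sparse product, both operands CSC."""
    raise NotImplementedError

_SPMSPM_KERNELS = {
    ("csr", "csr"): _spmspm_csr_csr,
    ("csc", "csc"): _spmspm_csc_csc,
}

def pick_spmspm_kernel(x_format, y_format):
    # The dispatcher stays small: it only validates formats and routes.
    try:
        return _SPMSPM_KERNELS[(x_format, y_format)]
    except KeyError:
        raise NotImplementedError(f"Unsupported format pair: {(x_format, y_format)}")
```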

if formats == ("csc", "csc"):
    # In all cases, the output is dense when the op is Dot.
    @numba_basic.numba_njit
    def spmspm(x, y):

Member:
To my point above, it would be great if each of these functions were defined with a name that clarified the case it handles. It would make this format routing much clearer.

@register_funcify_and_cache_key(StructuredDotGradCSR)
@register_funcify_and_cache_key(StructuredDotGradCSC)
def numba_funcify_StructuredDotGrad(op, node, **kwargs):
    # Let:

Member:
Make this a docstring

Comment on lines +209 to +214:

# Pass 1: Count non-zeros to pre-allocate
nnz = 0
for i in range(n_rows):
    for j in range(n_cols):
        if arg1[i, j] != 0:
            nnz += 1

Member:
There are some helpers in common between the CSC and CSR cases that you could consider extracting (though I recognize we haven't hit the rule of three yet).
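
As a sketch of the kind of shared helper being suggested (the name `_count_dense_nnz` is hypothetical, not from the PR), the counting pass could be factored out so both the CSC and CSR branches reuse it:

```python
import numpy as np
from numba import njit

@njit
def _count_dense_nnz(arg1):
    # Pass 1 from the snippet above, extracted: count the non-zeros of a
    # dense 2-D array so the caller can pre-allocate data/indices buffers.
    n_rows, n_cols = arg1.shape
    nnz = 0
    for i in range(n_rows):
        for j in range(n_cols):
            if arg1[i, j] != 0:
                nnz += 1
    return nnz

print(_count_dense_nnz(np.array([[0.0, 1.0], [2.0, 0.0]])))  # 2
```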

@overload_method(CSMatrixType, "sum")
def overload_sum(matrix, axis):
    # 'axis' can be either None, 0, or 1.
    if axis is types.none:

Member:
if axis is None doesn't work here?

@tomicapretto (Contributor, Author) commented on Feb 1, 2026:

Nope, this was actually something I discovered by trial and error. We're dealing with numba types as inputs here, so None doesn't work. The same idea applies to isinstance(matrix, np.ndarray): one has to write isinstance(matrix, types.Array) instead.

I could have read numba's docs on extending it more thoroughly, of course, hehe.
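
A small illustration of the point above (the function is a made-up example, but the numba API names are real): inside an @overload / @overload_method body, the arguments are numba type objects, so checks are written against types.none and types.Array rather than None and np.ndarray.

```python
from numba.core import types

def _select_sum_impl(matrix, axis):
    # 'matrix' and 'axis' here are numba *types*, not runtime values.
    if axis is types.none:                 # not: axis is None
        return "sum over all elements"
    if isinstance(matrix, types.Array):    # not: isinstance(matrix, np.ndarray)
        return "dense implementation"
    return "sparse implementation"

print(_select_sum_impl(types.float64[:, :], types.none))  # sum over all elements
```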

out[0] = np.asarray(variable, str(variable.dtype))

def grad(self, inputs, gout):
    # FIXME: It's not always true that b and g_out are dense.

Member:
Did you already open an issue for this?

@tomicapretto (Contributor, Author):
Just created it: #1871. Thanks for the nudge.



@pytest.mark.parametrize("x_format", ["csr", "csc"])
@pytest.mark.parametrize("y_format", ["csr", "csc", None])

Member:
Suggested change:
- @pytest.mark.parametrize("y_format", ["csr", "csc", None])
+ @pytest.mark.parametrize("y_format", ["csr", "csc", "dense"])

@pytest.mark.parametrize("y_format", ["csr", "csc", None])
@pytest.mark.parametrize("x_shape, y_shape", DOT_SHAPES)
def test_structured_dot_grad(x_format, y_format, x_shape, y_shape):
    rng = np.random.default_rng(sum(map(ord, x_format)) + sum(x_shape) + sum(y_shape))

Member:
I personally don't love seeded tests; I'd rather learn over time whether an implementation is flaky. I'm not sure everyone agrees on this, though.

@ricardoV94 (Member) commented on Feb 1, 2026:

I don't think randomness should be a default. Floating-point precision is a weird thing, and from experience, flakiness from floating-point issues is orders of magnitude more common than flakiness from actual bugs (god knows how much electricity/time the whole float32 CI cost us in the long run).

Also from experience, real bugs are unlikely to be masked by a single seed.

I would restrict randomness to stochastic/statistical applications, where it's a meaningful construct, or where there's no risk of floating-point shenanigans.

In this specific case, I don't believe a seed could matter for bug catching.

@tomicapretto (Contributor, Author):
I just followed the pattern I saw. I don't like thinking about how I construct a seed, but at the same time I hate my tests failing because of weird float stuff.

Are we OK with me removing the seeds in these tests, then?
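
For reference, and purely as an illustration of the two styles being debated (not a suggestion of which one the tests should end up with):

```python
import numpy as np

# Deterministic: the same data on every CI run, so no floating-point flakiness,
# but a data-dependent bug could in principle hide behind one fixed draw.
rng_seeded = np.random.default_rng(1234)

# Unseeded: different data every run; a flaky implementation surfaces over time,
# at the cost of occasional precision-related failures.
rng_unseeded = np.random.default_rng()
```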
