Implement StructuredDotGradCSR and StructuredDotGradCSC in numba backend #1860
Conversation
The test that fails is: …, which is unrelated to this PR.
```python
# Pre-allocate internal containers
data = np.empty(nnz, dtype=matrix.dtype)
indices = np.empty(nnz, dtype=np.uint32)
```
Should be int32, no?
We used uint32 in other places. I thought that since they're never negative, we could use uint32.
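For context: scipy itself defaults to signed int32 for CSR/CSC index arrays (upcasting to int64 only when sizes require it), which is presumably where the int32 suggestion comes from. A quick check, relying only on scipy defaults:

```python
import scipy.sparse as sp

m = sp.random(4, 4, density=0.5, format="csr")
print(m.indices.dtype, m.indptr.dtype)  # int32 int32 with scipy's defaults
```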
```python
for col_idx in range(size):
    for value_idx in range(x_ptr[col_idx], x_ptr[col_idx + 1]):
        output[value_idx] = np.dot(
```
We have to be careful with np.dot here; IIRC numba's overload doesn't support integer/mixed dtypes well.
Argh, I'm using it because np.sum(x * y) was slower. There are a bunch of tests that pass different data types, and they have all passed. That's probably OK?
It's probably fine as long as we're upcasting the inputs to a common dtype in the make_node of Dot?
In the medium term we should consider re-implementing the BLAS calls ourselves.
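For anyone landing here later: the concern is that numba's np.dot overload is (last I checked) only typed for float and complex arrays, so integer inputs fail at compile time rather than being upcast. A minimal sketch of the failure mode:

```python
import numpy as np
from numba import njit

@njit
def rowdot(x, y):
    return np.dot(x, y)

rowdot(np.ones(3), np.ones(3))  # fine: float64 inputs

# Integer inputs raise a TypingError at compile time instead of upcasting:
# rowdot(np.arange(3), np.arange(3))
```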
ricardoV94 left a comment
This looks great; I just left some minor comments.
```python
axis = op.axis

@numba_basic.numba_njit
def perform(x):
```
Does mypy freak out if you type-hint this as SparseArray -> SparseArray? It would make the function clearer. Not required if it causes a headache (type-hinting these overloads often does).
I have not tried it, but the SpSum op returns a dense array (see this).
What happens here is that this calls the function I implemented in overload_sum in variable.py.
Maybe we should add a note somewhere (per Op, or at the top of the file) saying that many (if not all) Ops use overloads written in a separate Python file?
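A rough sketch of the relationship being described, with signatures assumed from the snippets above, in case the indirection isn't obvious to later readers:

```python
# In the dispatch module: the Op's compiled function just calls the
# method (`axis` is closed over from the funcify); like SpSum itself,
# it returns a dense array.
@numba_basic.numba_njit
def perform(x):
    return x.sum(axis)

# In variable.py: numba resolves that method call via this overload.
@overload_method(CSMatrixType, "sum")
def overload_sum(matrix, axis):
    ...
```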
```diff
 # General spmspm algorithm in CSR format
 @numba_basic.numba_njit
-def _spmspm(n_row, n_col, x_ptr, x_ind, x_data, y_ptr, y_ind, y_data):
+def _spmspm_csr(x, y, n_row, n_col):
```
I think it's worth considering a bit of reorganization here for future extensibility. We can make a new sparse/math sub-module and have a sum.py file with each of these inner njit functions defined independently. numba_funcify_SparseDenseMultiply can still live here, but it would be just an input checker and routing to the correct function. I'm thinking about what it will look like in the future to add support for a new sparse type.
The pattern I'm thinking about is what we are doing with linalg, for example QZ: each case is defined separately here, then the actual dispatch is defined here.
It sounds good to me. I thought a bit about it before starting to work on this, but I saw the other Ops in this module were implemented this way, so I assumed it was for a reason. Maybe I just overthought it and it was simply convenience.
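A rough sketch of what the proposed split might look like (module layout and names are hypothetical, modeled on the linalg pattern mentioned above):

```python
# pytensor/link/numba/dispatch/sparse/math/dot.py: one kernel per case
@numba_basic.numba_njit
def _spmspm_csr_csr(x, y, n_row, n_col):
    ...

@numba_basic.numba_njit
def _spmspm_csc_csc(x, y, n_row, n_col):
    ...


# dispatch module: a thin router that only checks inputs and picks a kernel
def numba_funcify_SparseDot(op, node, **kwargs):
    # how the formats are read off the inputs is an assumption here
    formats = tuple(getattr(i.type, "format", None) for i in node.inputs)
    if formats == ("csr", "csr"):
        return _spmspm_csr_csr
    if formats == ("csc", "csc"):
        return _spmspm_csc_csc
    raise NotImplementedError(f"No sparse dot kernel for formats {formats}")
```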
```python
if formats == ("csc", "csc"):
    # In all cases, the output is dense when the op is Dot.
    @numba_basic.numba_njit
    def spmspm(x, y):
```
To my point above, it would be great if each of these functions were defined with a name that clarifies the case it handles. It would make this format routing much clearer.
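One hypothetical way to make the routing read at a glance is a lookup table keyed on the input formats (names here are illustrative, not from the PR):

```python
_SPMSPM_KERNELS = {
    ("csr", "csr"): _spmspm_csr_csr,
    ("csc", "csc"): _spmspm_csc_csc,
}
spmspm = _SPMSPM_KERNELS[formats]
```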
```python
@register_funcify_and_cache_key(StructuredDotGradCSR)
@register_funcify_and_cache_key(StructuredDotGradCSC)
def numba_funcify_StructuredDotGrad(op, node, **kwargs):
    # Let:
```
Make this a docstring
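i.e., something along these lines (the derivation text itself is elided in this view, so the body is a placeholder):

```python
@register_funcify_and_cache_key(StructuredDotGradCSR)
@register_funcify_and_cache_key(StructuredDotGradCSC)
def numba_funcify_StructuredDotGrad(op, node, **kwargs):
    """Move the derivation from the `# Let:` comment block here."""
```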
```python
# Pass 1: Count non-zeros to pre-allocate
nnz = 0
for i in range(n_rows):
    for j in range(n_cols):
        if arg1[i, j] != 0:
            nnz += 1
```
There are some helpers in common between the CSC and CSR cases that you could consider extracting (though I recognize we haven't hit the rule of 3 yet).
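For instance, the nnz-counting pass above could become a shared helper (the name and the njit decoration are assumptions):

```python
@numba_basic.numba_njit
def _count_nnz(a):
    # Count the non-zero entries of a dense 2-d array so the CSR/CSC
    # buffers can be pre-allocated in a single pass.
    nnz = 0
    n_rows, n_cols = a.shape
    for i in range(n_rows):
        for j in range(n_cols):
            if a[i, j] != 0:
                nnz += 1
    return nnz
```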
```python
@overload_method(CSMatrixType, "sum")
def overload_sum(matrix, axis):
    # 'axis' can be either None, 0, or 1.
    if axis is types.none:
```
Doesn't `axis is None` work here?
Nope, this was actually something I discovered by trial and error. We're dealing with numba types as inputs here, so None does not work. The same idea applies to isinstance(matrix, np.ndarray): one has to use isinstance(matrix, types.Array).
I could have read numba's docs on extending more thoroughly, of course hehe
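For reference, the usual idiom in @overload implementations, since the arguments there are numba types rather than runtime values (a sketch against a hypothetical stub function):

```python
from numba import types
from numba.extending import overload


def axis_sum(x, axis=None):  # hypothetical pure-Python stub
    pass


@overload(axis_sum)
def ov_axis_sum(x, axis=None):
    # At typing time an explicit None arrives as types.none (an instance
    # of types.NoneType); plain `axis is None` only matches the omitted
    # default, so the two checks are usually combined.
    if axis is None or isinstance(axis, types.NoneType):
        def impl(x, axis=None):
            return x.sum()
        return impl
```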
```python
out[0] = np.asarray(variable, str(variable.dtype))

def grad(self, inputs, gout):
    # FIXME: It's not always true that b and g_out are dense.
```
Did you already open an issue for this?
Just created it: #1871. Thanks for the nudge.
```python
@pytest.mark.parametrize("x_format", ["csr", "csc"])
@pytest.mark.parametrize("y_format", ["csr", "csc", None])
```

Suggested change:

```diff
-@pytest.mark.parametrize("y_format", ["csr", "csc", None])
+@pytest.mark.parametrize("y_format", ["csr", "csc", "dense"])
```
```python
@pytest.mark.parametrize("x_shape, y_shape", DOT_SHAPES)
def test_structured_dot_grad(x_format, y_format, x_shape, y_shape):
    rng = np.random.default_rng(sum(map(ord, x_format)) + sum(x_shape) + sum(y_shape))
```
I personally don't love seeded tests; I'd rather know over time if an implementation is flaky. I'm not sure everyone agrees on this, though.
I don't think randomness should be a default. Floating-point precision is a weird thing, and from experience, flakiness from floating-point issues is orders of magnitude more common than flakiness from actual bugs (god knows how much electricity/time the whole float32 CI cost us in the long run).
Also from experience, real bugs are unlikely to be masked by a single seed.
I would restrict randomness to stochastic/statistical applications, where it's a meaningful construct, or where there's no risk of floating-point shenanigans.
In this specific case, I don't believe a seed could matter for bug catching.
I just followed the pattern I saw. I don't like thinking about how I create a seed, but at the same time I hate my tests failing because of weird float stuff.
Are we OK with me removing the seeds in these tests, then?
Description

The main contribution of this PR is the implementation of StructuredDotGradCSR and StructuredDotGradCSC in the numba backend.

While I was working on it, I noticed the Ops SpSum and SparseFromDense were running in object mode, so I also implemented them.