Conversation

@mdhaber (Owner) commented Mar 30, 2025

Closes gh-99.
Closes gh-82.

When this is ready:

  • I'll wrap the imports in the test file in try/except; if an import fails, the backend won't be added to the xps list. Done.
  • To get PyTorch running in CI, I think I'll just need to add pytorch (and I guess array-api-compat, if it's not already there) to the test dependencies in pyproject.toml, then pixi update. Does that sound right @lucascolley? Done.
  • I don't think I want to spend money running on the GPU. I can do CuPy locally. How do we indicate that CuPy is an optional test dependency for users but that Pixi doesn't need to try installing it in CI? Follow-up PR.
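The try-import idea in the first bullet can be sketched roughly like this; `xps` is the name used in the test file, but the loop body below is illustrative, not the actual test code:

```python
import importlib

import numpy as np

# Build the list of array namespaces to parameterize the tests over,
# silently omitting any optional backend that is not installed.
xps = [np]  # NumPy is always available

for name in ["torch", "cupy"]:
    try:
        xps.append(importlib.import_module(name))
    except ImportError:
        pass  # optional backend missing; leave it out of the test matrix
```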

@mdhaber marked this pull request as ready for review on March 31, 2025, 06:24
@lucascolley (Collaborator) commented:

Here's how we do it in array-api-extra (full file at https://github.com/data-apis/array-api-extra/blob/main/pyproject.toml):

[tool.pixi.environments]
tests = { features = ["py313", "tests"], solve-group = "py313" }
tests-backends = { features = ["py310", "tests", "backends"], solve-group = "backends" }
tests-cuda = { features = ["py310", "tests", "backends", "cuda-backends"], solve-group = "cuda" }
tests-numpy1 = ["py310", "tests", "numpy1"]
[tool.pixi.feature.tests.dependencies]
pytest = ">=6"
pytest-cov = ">=3"
hypothesis = "*"
array-api-strict = "*"
numpy = "*"
# Backends that can run on CPU-only hosts
# Note: JAX and PyTorch will install CPU variants.
[tool.pixi.feature.backends.dependencies]
pytorch = "*"
dask = "*"
numba = "*"  # sparse dependency
llvmlite = "*"  # sparse dependency

[tool.pixi.feature.backends.pypi-dependencies]
sparse = { version = ">= 0.16.0b3" }

[tool.pixi.feature.backends.target.linux-64.dependencies]
jax = "*"

[tool.pixi.feature.backends.target.osx-64.dependencies]
jax = "*"

[tool.pixi.feature.backends.target.osx-arm64.dependencies]
jax = "*"

[tool.pixi.feature.backends.target.win-64.dependencies]
# jax = "*"  # unavailable
# Backends that require a GPU host and a CUDA driver.
# Note that JAX and PyTorch automatically prefer CUDA variants
# thanks to the `system-requirements` below, *if available*.
# We request them explicitly below to ensure that we don't
# quietly revert to CPU-only in the future, e.g. when CUDA 13
# is released and CUDA 12 builds are dropped upstream.
[tool.pixi.feature.cuda-backends]
system-requirements = { cuda = "12" }

[tool.pixi.feature.cuda-backends.target.linux-64.dependencies]
cupy = "*"
jaxlib = { version = "*", build = "cuda12*" }
pytorch = { version = "*", build = "cuda12*" }

[tool.pixi.feature.cuda-backends.target.osx-64.dependencies]
# cupy = "*"  # unavailable
# jaxlib = { version = "*", build = "cuda12*" }  # unavailable
# pytorch = { version = "*", build = "cuda12*" }  # unavailable

[tool.pixi.feature.cuda-backends.target.osx-arm64.dependencies]
# cupy = "*"  # unavailable
# jaxlib = { version = "*", build = "cuda12*" }  # unavailable
# pytorch = { version = "*", build = "cuda12*" }  # unavailable

[tool.pixi.feature.cuda-backends.target.win-64.dependencies]
cupy = "*"
# jaxlib = { version = "*", build = "cuda12*" }  # unavailable
pytorch = { version = "*", build = "cuda12*" }

@lucascolley (Collaborator) commented:

Then in CI we have https://github.com/data-apis/array-api-extra/blob/66e9282df165df219797f65f1cc061665fe7e95f/.github/workflows/ci.yml#L53-L54:

matrix:
    environment: [tests-py310, tests-py313, tests-numpy1, tests-backends]

notably, no tests-cuda in CI.

@mdhaber changed the title from "TST: run tests with PyTorch/CuPy backends" to "TST: run tests with PyTorch" on Mar 31, 2025
@mdhaber mentioned this pull request on Mar 31, 2025
@mdhaber (Owner, Author) left a comment:

Some self-review. I think this is ready for a look. Despite the line count, hopefully it's pretty straightforward.

The main sticking point might be the precise manner in which we allow tests to pass that would otherwise fail. Should they not be part of the parameterization? Should they be skipped? XFailed? A related concern might be how liberal the approach for passing tests looks. In this PR, I'm not concerned with this stuff. We are not trying to test the array API compliance of the backends; we're just supposed to be testing whether MArray is doing what it intends to do. We can fine-tune this stuff later, but it's better to have a fix for gh-99 and some sort of PyTorch testing than to not have them.

@lucascolley Re: #110 (comment), I'm not going to add CuPy in this PR after all. This took a good fraction of a day to get working because of all the creative ways in which PyTorch is non-compliant. CuPy will be much easier, but I'd rather not hold this up for it - and besides, it should be easier to do when CuPy 14 gets released, since the arrays themselves are supposed to become more array API compliant. The need to do this is tracked in gh-112. At that time, can I get CuPy working locally and ask you to configure pyproject.toml and such as you see fit?

Please feel free to hit resolve on any of my self-review comments - the intent was just to explain a few things that might not be obvious. Please also LMK whether you think I should report #110 (comment).

Comment on lines +45 to +46
def pass_backend(*, xp, pass_xp, fun=None, pass_funs=None, dtype=None,
pass_dtypes=None, pass_using=pytest.skip, reason="Debug later."):
@mdhaber (Owner, Author):

Useful for skipping or xfailing a test for a particular function, dtype, and backend. There are a handful of cases where this sort of thing is needed, such as when there are arithmetic differences between backends that have nothing to do with MArray.
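A minimal sketch of how such a helper might behave; only the signature comes from the diff, and the matching logic below is an assumption:

```python
import pytest

def pass_backend(*, xp, pass_xp, fun=None, pass_funs=None, dtype=None,
                 pass_dtypes=None, pass_using=pytest.skip, reason="Debug later."):
    # Let the test proceed unless the backend, function, and dtype all match.
    if pass_xp not in xp.__name__:
        return
    if pass_funs is not None and fun not in pass_funs:
        return
    if pass_dtypes is not None and str(dtype).split(".")[-1] not in pass_dtypes:
        return
    pass_using(reason)  # pytest.skip or pytest.xfail ends the test here
```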

Comment on lines +119 to +121
low_precision = xp.isdtype(res.dtype, (xp.float32, xp.complex64))
tol = {'rtol': 1e-5, 'atol': 1e-16} if low_precision else {'rtol': 1e-10, 'atol': 1e-32}
kwargs = tol | kwargs
@mdhaber (Owner, Author):

Adjust default tolerances depending on precision.

Comment on lines +134 to +135
if re.escape(message) in re.escape(str(e)):
pytest.xfail(str(e))
@mdhaber (Owner, Author):

Some error messages have parentheses.
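To see why both sides are escaped, consider a small example (the strings below are made up, not actual PyTorch messages):

```python
import re

message = "argument 'dim' (position 2)"
error_text = "squeeze(): argument 'dim' (position 2) must be int, not tuple"

# Unescaped, the parentheses act as a regex group, so the literal text
# in the error no longer matches the pattern.
assert re.search(message, error_text) is None
# Escaping the message makes the regex match the parentheses literally.
assert re.search(re.escape(message), error_text) is not None
# Escaping both strings preserves plain substring containment, which is
# what the `re.escape(message) in re.escape(str(e))` check relies on.
assert re.escape(message) in re.escape(error_text)
```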

pass_backend(xp=xp, dtype=dtype, fun=f_name, pass_xp='torch',
pass_funs=['copysign', 'atan2'],
pass_dtypes=['bool', 'uint8', 'uint16', 'int8', 'int16'],
reason="Unexpected dtype", pass_using=pytest.xfail)
@mdhaber (Owner, Author):

Rather than xfailing, we could probably predict Torch's dtype rules or just ignore the dtype comparison. We're talking about two functions... I'm not terribly concerned about missing an MArray bug by ignoring these.

cond = rng.random(marrays[0].shape) > 0.5
x[xp.asarray(cond)] = 0
y[xp.asarray(cond)] = 0
@mdhaber (Owner, Author):

This was just wrong before.

np.testing.assert_equal(res[i].mask, np.full(ref[i].shape, False))
actual = mxp.asarray(res[i].data, mask=res[i].mask)
desired = np.ma.masked_array(ref[i], np.full(ref[i].shape, False))
assert_equal(actual, desired, xp=xp, seed=seed)
@mdhaber (Owner, Author):

This reorganization is just to support CuPy, so not strictly needed here, but will be good to have done when I address gh-112.


data[mask] = sentinel
ref = np.unique_all(np.asarray(data))
@mdhaber (Owner, Author):

Conceptually, the point is that we want to use NumPy's behavior as the reference, since Torch doesn't have unique_all.

Won't work as-written when CuPy is the backend, but that can be addressed in gh-112.

@lucascolley (Collaborator) left a comment:

thanks Matt, all looks reasonable to me!



def as_numpy(x):
# Use `cupy.testing` when `assert_allclose` and `assert_equal` support `strict`
@lucascolley (Collaborator):

you may be interested in helping to push data-apis/array-api-extra#17 along. Long-term I think testing infra belongs in xpx.testing.
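For context, a hedged sketch of what an `as_numpy` helper like the one quoted above might contain; the conversion details here are assumptions, not the PR's actual code:

```python
import numpy as np

def as_numpy(x):
    # Normalize any backend's array to NumPy so np.testing assertions
    # can be shared across backends.
    if hasattr(x, "get"):         # CuPy arrays: copy device memory to host
        return x.get()
    if hasattr(x, "detach"):      # PyTorch tensors: drop autograd, move to CPU
        return x.detach().cpu().numpy()
    return np.asarray(x)          # NumPy arrays and anything array-like
```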

@mdhaber (Owner, Author):

OK, I'll trade you.
Do you want me to submit a PR with the testing functions or review one?

@lucascolley (Collaborator):

I'm not sure it is ready for a PR yet, I think there is still a bit of investigation to do. We have at https://github.com/data-apis/array-api-extra/blob/main/src/array_api_extra/_lib/_testing.py the private implementations added by Guido which have been sufficient for xpx internally so far. I think the main investigative work is:

That might already be quite a lot of work. We may want to judge some things (like SciPy's 'regular' vs 'no-0d' stuff) as out of scope for the xpx versions, but I'm not sure that helps avoid the questions of what should be the defaults.

@mdhaber (Owner, Author):

I agree we can leave the 0d stuff out. I'll get caught up on the discussion.

@mdhaber (Owner, Author) commented Apr 14, 2025:

BTW something I noticed is that CuPy actually converts to NumPy for testing. So for all libraries, rather than using their functions when available, might be easiest to always convert to NumPy (if that's not already what we do, I forget).

@mdhaber (Owner, Author):

can we incorporate all features from the scipy._lib (and perhaps MArray) versions such that we can drop the xpx versions in

I took a look. I don't see why not? The way I'd investigate further is to try it. Do you want me to do that, and if it seems OK, submit the PR?

can we incorporate everything scikit-learn would need...

For device and dtype testing infrastructure - like automatically providing test functions with types, devices, dtypes - that is out of my area. I'd only be able to help with the assertions, which I think can be separate.

@lucascolley (Collaborator):

Do you want me to do that, and if it seems OK, submit the PR?

If you are willing, sure!

@mdhaber (Owner, Author):

The differences I'm seeing between the xpx functions and what we have in SciPy are:

  • xpx does not distinguish between 0d arrays and scalars (for NumPy)
  • xpx does not have options check_namespace, check_dtype, check_shape, check_0d (maybe should be called check_type because it distinguishes between arrays and scalars?)
  • xpx does not have xp_assert_less

(Otherwise, it looks like it works just like SciPy. It even has the "smart" default rtol, rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4.)

Should I add these features?

Something else that might be worth changing for simplicity is to use NumPy testing functions for all backends. CuPy already does this, actually, and the only backend that really seems to have its own testing functions is PyTorch. So I would just add conversion of these array types to NumPy arrays to _check_ns_shape_dtype and remove the use of backend-specific testing functions. How does that sound?
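The simplification described above could look roughly like this; the name mirrors the xpx/SciPy helpers, but the body is an illustrative assumption:

```python
import numpy as np

def xp_assert_close(actual, desired, rtol=None):
    # Convert backend arrays to NumPy up front (CuPy's `.get()` or Torch's
    # `.numpy()` would slot in here), then compare with one code path.
    actual, desired = np.asarray(actual), np.asarray(desired)
    if rtol is None:
        if np.issubdtype(actual.dtype, np.inexact):
            # the "smart" default discussed above
            rtol = np.finfo(actual.dtype).eps ** 0.5 * 4
        else:
            rtol = 0
    np.testing.assert_allclose(actual, desired, rtol=rtol)
```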

@lucascolley (Collaborator):

Should I add these features?

Yes sounds good, thanks!

How does that sound?

That sounds reasonable to me 👍

@mdhaber merged commit a8ca4a9 into main on Apr 14, 2025. 11 checks passed.

Successfully merging this pull request may close these issues.

  • any() fails when wrapping JAX and PyTorch
  • Test against PyTorch?

3 participants