Skip to content

Commit 36d9949

Browse files
committed
Add experimental __array_module__ method
xref jax-ml#1565 `__array_module__` (see [NEP 37](https://numpy.org/neps/nep-0037-array-module.html)) is an experimental alternative to `__array_function__` and `__array_ufunc__` for "duck array" compatibility with NumPy that promises to be much less invasive. Example usage: ```python import numpy as np def duckarray_stack(arrays): """This "stack" function should work with any array library, including JAX.""" npx = np.get_array_module(*arrays) arrays = [npx.asarray(arr) for arr in arrays] shapes = {arr.shape for arr in arrays} if len(shapes) != 1: raise ValueError('all input arrays must have the same shape') expanded_arrays = [arr[npx.newaxis, ...] for arr in arrays] return npx.concatenate(expanded_arrays, axis=0) ``` Support for this protocol has *not* yet been implemented in NumPy, but it can be tested with https://github.com/seberg/numpy-dispatch. My reasoning for merging it into JAX (on an experimental basis with no guarantees, of course) is that: 1. It's not invasive -- the implementation is small and self-contained. 2. No backwards compatibility issues. Unlike `__array_function__` and `__array_ufunc__`, `__array_module__` will always require an explicit opt-in by libraries that use it by calling `get_array_module()`. 2. Other NumPy developers [want evidence](numpy/numpy#16935 (comment)) that this is actually feasible. 3. Scikit-Learn developers like @thomasjpfan are interested in exploring supporting scikit-learn on top of NumPy-like libraries like JAX, and experimental support for this protocol will make that easier. Note: this PR does add `numpy-dispatch` as a optional testing requirement in order to verify that this works. If desired, we could remove this from CI, but installing numpy-dispatch (and its build requirement Cython) appears to only add a few seconds of build time.
1 parent 1316562 commit 36d9949

File tree

4 files changed

+51
-1
lines changed

4 files changed

+51
-1
lines changed

.github/workflows/ci-build.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ jobs:
6060
os: ubuntu-latest
6161
enable-x64: 1
6262
enable-omnistaging: 0
63-
package-overrides: "none"
63+
# Test experimental NumPy dispatch
64+
# TODO(shoyer): remove cython after
65+
# https://github.com/seberg/numpy-dispatch/pull/5 is merged
66+
package-overrides: "cython git+https://github.com/seberg/numpy-dispatch.git"
6467
num_generated_cases: 25
6568
- python-version: 3.6
6669
os: ubuntu-latest

jax/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,9 @@ def __len__(self):
469469
def aval(self):
470470
raise NotImplementedError("must override")
471471

472+
# Python looks up special methods only on classes, not instances. This means
473+
# these methods needs to be defined explicitly rather than relying on
474+
# __getattr__ (short of using a metaclass).
472475
def __neg__(self): return self.aval._neg(self)
473476
def __pos__(self): return self.aval._pos(self)
474477
def __eq__(self, other): return self.aval._eq(self, other)
@@ -528,6 +531,9 @@ def __complex__(self):
528531
def __setitem__(self, idx, val):
529532
raise TypeError("JAX 'Tracer' objects do not support item assignment")
530533

534+
# NumPy also only looks up special methods on classes.
535+
def __array_module__(self, types): return self.aval._array_module(self, types)
536+
531537
def __getattr__(self, name):
532538
# if the aval property raises an AttributeError, gets caught here
533539
assert skip_checks or name != "aval"

jax/numpy/lax_numpy.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import numpy as np
3737
import opt_einsum
3838

39+
import jax
3940
from jax import jit, custom_jvp
4041
from .vectorize import vectorize
4142
from ._util import _wraps
@@ -4574,6 +4575,21 @@ def _operator_round(number, ndigits=None):
45744575
setattr(DeviceArray, "nbytes", property(_nbytes))
45754576

45764577

4578+
# Experimental support for NumPy's module dispatch with NEP-37.
4579+
# Currently requires https://github.com/seberg/numpy-dispatch
4580+
_JAX_ARRAY_TYPES = (UnshapedArray, DeviceArray, core.Tracer)
4581+
_HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,)
4582+
4583+
def __array_module__(self, types):
4584+
if builtins.all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types):
4585+
return jax.numpy
4586+
else:
4587+
return NotImplemented
4588+
4589+
setattr(ShapedArray, "_array_module", staticmethod(__array_module__))
4590+
setattr(DeviceArray, "__array_module__", __array_module__)
4591+
4592+
45774593
# Extra methods that are handy
45784594
setattr(ShapedArray, "broadcast", core.aval_method(lax.broadcast))
45794595
setattr(ShapedArray, "broadcast_in_dim", core.aval_method(lax.broadcast_in_dim))

tests/lax_numpy_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
from absl.testing import parameterized
2929

3030
import numpy as np
31+
try:
32+
import numpy_dispatch
33+
except ImportError:
34+
numpy_dispatch = None
3135

3236
import jax
3337
import jax.ops
@@ -585,6 +589,27 @@ def testBinaryOperatorDefers(self, op_name, rng_factory, dtype):
585589
with self.assertRaises(TypeError):
586590
op(arg, other)
587591

592+
def testArrayModule(self):
593+
if numpy_dispatch is None:
594+
raise SkipTest('requires https://github.com/seberg/numpy-dispatch')
595+
596+
jnp_array = jnp.array(1.0)
597+
np_array = np.array(1.0)
598+
599+
with numpy_dispatch.ensure_dispatching():
600+
module = numpy_dispatch.get_array_module(jnp_array)
601+
self.assertIs(module, jnp)
602+
603+
module = numpy_dispatch.get_array_module(jnp_array, np_array)
604+
self.assertIs(module, jnp)
605+
606+
def f(x):
607+
module = numpy_dispatch.get_array_module(x)
608+
self.assertIs(module, jnp)
609+
return x
610+
jax.jit(f)(jnp_array)
611+
jax.grad(f)(jnp_array)
612+
588613
@parameterized.named_parameters(itertools.chain.from_iterable(
589614
jtu.cases_from_list(
590615
{"testcase_name": jtu.format_test_name_suffix(

0 commit comments

Comments
 (0)