jax.dispatch decorator - for improved readability #12031
Comments
If I understand correctly, I think the mechanism you have in mind is already provided by the built-in functools.singledispatch:

import jax.numpy as jnp
from typing import Any, NamedTuple
from functools import singledispatch

class Batch_1D(NamedTuple):
    x: jnp.ndarray

class Batch_2D(NamedTuple):
    x: jnp.ndarray

@singledispatch
def flatten_batch_dim(batch: Any) -> Batch_1D:
    # Fallback for types without a registered implementation.
    raise NotImplementedError(f"flatten_batch_dim for type {type(batch)}")

@flatten_batch_dim.register
def _(batch: Batch_2D) -> Batch_1D:
    # Merge the two leading batch dimensions into one.
    x = batch.x.reshape((-1,) + batch.x.shape[2:])
    return Batch_1D(x)

@flatten_batch_dim.register
def _(batch: Batch_1D) -> Batch_1D:
    return batch

batch = Batch_2D(x=jnp.ones((16, 32, 5)))
print(flatten_batch_dim(batch).x.shape)
# (512, 5)

batch = Batch_1D(x=jnp.ones((16, 32, 5)))
print(flatten_batch_dim(batch).x.shape)
# (16, 32, 5)

Ah, poor example on my part. It's quite powerful, but unfortunately this implementation does not like jit ;)
Have you tried plum? I've not tested to see how it works under jit. (In passing: there's been some offline discussion about trying to get plum working with jaxtyping, which would be pretty neat, and may solve your use case.)
Can confirm that plum (including multiple dispatch) does indeed work perfectly under jit.
Unfortunately, both Example:

import functools as ft
import jax
import jax.numpy as jnp

@ft.singledispatch
def f(x):
    raise NotImplementedError()

@f.register
def _(x: jnp.ndarray):
    print("Array-operation")

@jax.vmap
def g(x: jnp.ndarray):
    return f(x)

g(jnp.ones((10, 1)))
# NotImplementedError

Edit: fixed typo

The reason this doesn't work is likely because
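One plausible reading (an assumption, since the comment is cut off here): functools.singledispatch selects an implementation from the class of its first argument, and inside jax.vmap that argument is a tracer rather than a concrete array. The sketch below only inspects the classes involved. The names record_type and seen_types are illustrative, and whether the tracer class resolves to a jnp.ndarray registration may depend on the installed JAX version, so no expected output is shown.

```python
import functools as ft

import jax
import jax.numpy as jnp
from jax.core import Tracer

seen_types = []  # classes observed for the argument while vmap traces


@jax.vmap
def record_type(x):
    # Runs once in Python during tracing; x is whatever vmap passes in.
    seen_types.append(type(x))
    return x


concrete = jnp.ones((10, 1))
record_type(concrete)

concrete_type = type(concrete)
traced_type = seen_types[0]
print(concrete_type.__name__, traced_type.__name__)

# Inside vmap the argument's class is a Tracer subclass, not the concrete class.
print(issubclass(traced_type, Tracer))


@ft.singledispatch
def f(x):
    raise NotImplementedError()


@f.register(jnp.ndarray)
def _(x):
    return "array"


# dispatch(cls) returns the implementation singledispatch would pick for cls.
# True here means the tracer class falls through to the default implementation.
print(f.dispatch(traced_type) is f.dispatch(object))
```

If the last line prints True, the tracer class is not covered by the jnp.ndarray registration, which would explain the NotImplementedError reported above.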
According to Jake, this should work:

import functools as ft
import jax
import jax.numpy as jnp

@ft.singledispatch
def f(x):
    raise NotImplementedError()

# Register the same implementation for concrete arrays and for tracers.
@f.register(jnp.ndarray)
@f.register(jax.core.Tracer)
def _(x):
    print("Array-operation")

@jax.vmap
def g(x: jnp.ndarray):
    return f(x)

g(jnp.ones((10, 1)))  # Array-operation

Also If you want multiple dispatch using only It seems that python

# Tested on mac m1 CPU
from multipledispatch import dispatch as dispatch_md
from plum import dispatch as dispatch_plum
from functools import singledispatch as dispatch_std

@dispatch_md(int)
def f_md(x):
    return x

@dispatch_plum
def f_plum(x: int):
    return x

def f_native(x):
    return x

@dispatch_std
def f_std(x): ...

@f_std.register(int)
def _(x):
    return x

f_md(1); f_plum(1);  # Run once to populate cache.

%timeit f_native(1)
# 39.6 ns ± 0.629 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
%timeit f_md(1)
# 281 ns ± 2.08 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
%timeit f_plum(1)
# 337 ns ± 5.64 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
%timeit f_std(1)
# 267 ns ± 1.97 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

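The %timeit lines need IPython. A rough standard-library equivalent for the singledispatch case could look like the sketch below. The loop count n is an arbitrary choice, and absolute timings will differ by machine and Python version, so none are quoted.

```python
import timeit
from functools import singledispatch


def f_native(x):
    return x


@singledispatch
def f_std(x):
    raise NotImplementedError(f"f_std is not implemented for {type(x)}")


@f_std.register(int)
def _(x):
    return x


f_std(1)  # warm up singledispatch's internal dispatch cache

n = 100_000  # number of calls per measurement (arbitrary)
t_native = timeit.timeit(lambda: f_native(1), number=n)
t_std = timeit.timeit(lambda: f_std(1), number=n)

# Both measurements include the small cost of the lambda wrapper.
print(f"native:         {t_native / n * 1e9:.0f} ns per call")
print(f"singledispatch: {t_std / n * 1e9:.0f} ns per call")
```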
Yes, I guess this does work. It just takes an unsightly number of decorators, which somewhat defeats the purpose (improved readability) in the first place.
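One way to reduce the decorator stacking might be to wrap the double registration in a small helper. The sketch below assumes a hypothetical register_array helper (it is not part of JAX or functools) that registers a single function for both jnp.ndarray and jax.core.Tracer, following the suggestion above.

```python
import functools as ft

import jax
import jax.numpy as jnp
from jax.core import Tracer


def register_array(dispatcher):
    """Hypothetical helper: register one function for arrays and tracers."""
    def decorator(fn):
        # jnp.ndarray targets concrete arrays, Tracer targets values seen
        # inside transformations such as vmap and jit.
        for cls in (jnp.ndarray, Tracer):
            dispatcher.register(cls, fn)
        return fn
    return decorator


@ft.singledispatch
def f(x):
    raise NotImplementedError(f"f is not implemented for {type(x)}")


@register_array(f)
def _f_array(x):
    return 2 * x


@jax.vmap
def g(x):
    return f(x)


@jax.jit
def h(x):
    return f(x)


doubled_vmap = g(jnp.ones((10, 1)))
doubled_jit = h(jnp.ones((3,)))
print(doubled_vmap.shape, doubled_jit.shape)
```

Each implementation then needs only one decorator line, at the cost of a project-specific helper that readers have to look up.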
I would like to request a jax.dispatch decorator that can be used to transform a function into one supporting single/multiple dispatch on its typed arguments. The main advantages of such a decorator are:
a) improved code readability
b) less boilerplate code
c) no runtime overhead, since dispatch happens at JIT-compile time
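Point (c) can be illustrated with tools that already exist. The sketch below is not the proposed jax.dispatch API. It uses functools.singledispatch on NamedTuple containers and logs every Python-level dispatch. The names Batch1D, Batch2D, total, and dispatch_log are illustrative assumptions. Because jax.jit reuses the compiled computation for inputs with the same structure, shapes, and dtypes, the dispatch should only run while tracing.

```python
from functools import singledispatch
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Batch1D(NamedTuple):
    x: jnp.ndarray


class Batch2D(NamedTuple):
    x: jnp.ndarray


dispatch_log = []  # one entry per Python-level dispatch


@singledispatch
def flatten_batch_dim(batch):
    raise NotImplementedError(f"flatten_batch_dim for type {type(batch)}")


@flatten_batch_dim.register(Batch2D)
def _(batch):
    dispatch_log.append("Batch2D")
    # Merge the two leading batch dimensions into one.
    return Batch1D(batch.x.reshape((-1,) + batch.x.shape[2:]))


@flatten_batch_dim.register(Batch1D)
def _(batch):
    dispatch_log.append("Batch1D")
    return batch


@jax.jit
def total(batch):
    # NamedTuples are pytrees, so the container class survives tracing and
    # singledispatch can select an implementation from it.
    return flatten_batch_dim(batch).x.sum()


batch = Batch2D(x=jnp.ones((16, 32, 5)))
first = total(batch)   # first call: traces, so the dispatch runs in Python
second = total(batch)  # same structure and shapes: compiled code is reused
print(float(first), float(second), dispatch_log)
```

Dispatching on the container type rather than on the array type sidesteps the tracer question, since the NamedTuple class is the same inside and outside jit.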
Consider e.g. the following scenario
Now, the type of my data determines the function behavior. E.g.
whereas
This could be more beautifully achieved using dispatch. E.g.