jax.dispatch decorator - for improved readability #12031

Closed
simon-bachhuber opened this issue Aug 20, 2022 · 8 comments
Labels
enhancement New feature or request

Comments

@simon-bachhuber

I would like to request a jax.dispatch decorator that can be used to transform a function into a function supporting single/multiple dispatch on its typed arguments.

The main advantages of such a decorator are:
a) improved code readability
b) less boilerplate code
c) no runtime overhead, since dispatch happens at JIT-compile time

Consider, e.g., the following scenario:

import jax
import jax.numpy as jnp
from typing import NamedTuple, Union

class Batch_1D(NamedTuple):
    x: jnp.ndarray

class Batch_2D(NamedTuple):
    x: jnp.ndarray

def _flatten_batch_dim(batch: Batch_2D) -> Batch_1D:
    x = batch.x.reshape((-1,)+batch.x.shape[2:])
    return Batch_1D(x)

@jax.jit
def flatten_batch_dim(batch: Union[Batch_1D, Batch_2D]) -> Batch_1D:
    if isinstance(batch, Batch_2D):
        batch = _flatten_batch_dim(batch) 
    return batch 

Now, the type of my data determines the function behavior. E.g.

batch = Batch_1D(x = jnp.ones((16,32,5)))
print(flatten_batch_dim(batch).x.shape)
# (16, 32, 5)

whereas

batch = Batch_2D(x = jnp.ones((16,32,5)))
print(flatten_batch_dim(batch).x.shape)
# (512, 5)

This could be achieved more elegantly using dispatch, e.g.:

import jax
import jax.numpy as jnp
from typing import NamedTuple

class Batch_1D(NamedTuple):
    x: jnp.ndarray

class Batch_2D(NamedTuple):
    x: jnp.ndarray

@jax.dispatch 
def flatten_batch_dim(batch: Batch_2D) -> Batch_1D:
    x = batch.x.reshape((-1,)+batch.x.shape[2:])
    return Batch_1D(x)

@jax.dispatch
def flatten_batch_dim(batch: Batch_1D) -> Batch_1D:
    return batch 

simon-bachhuber added the enhancement (New feature or request) label on Aug 20, 2022
@jakevdp
Collaborator

jakevdp commented Aug 20, 2022

If I understand correctly, I think the mechanism you have in mind is already provided by the built-in functools.singledispatch decorator. Here's an example:

import jax.numpy as jnp
from typing import Any, NamedTuple
from functools import singledispatch

class Batch_1D(NamedTuple):
    x: jnp.ndarray

class Batch_2D(NamedTuple):
    x: jnp.ndarray

@singledispatch
def flatten_batch_dim(batch: Any) -> Batch_1D:
    raise NotImplementedError(f"flatten_batch_dim for type {type(batch)}")

@flatten_batch_dim.register
def _(batch: Batch_2D) -> Batch_1D:
    x = batch.x.reshape((-1,)+batch.x.shape[2:])
    return Batch_1D(x)

@flatten_batch_dim.register
def _(batch: Batch_1D) -> Batch_1D:
    return batch 

batch = Batch_2D(x = jnp.ones((16,32,5)))
print(flatten_batch_dim(batch).x.shape)
# (512, 5)

batch = Batch_1D(x = jnp.ones((16,32,5)))
print(flatten_batch_dim(batch).x.shape)
# (16, 32, 5)

@simon-bachhuber
Author

Ah, poor example on my part.
I was hoping more for a jit-able version of, e.g., fastcore-like dispatch; see https://fastcore.fast.ai/dispatch.html#typedispatch-decorator

It's quite powerful, but unfortunately that implementation does not play well with jit ;)
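
For reference, a rough, untested sketch of what that fastcore-style API looks like (assuming fastcore's typedispatch decorator, and re-defining the Batch classes from above):

import jax.numpy as jnp
from typing import NamedTuple
from fastcore.dispatch import typedispatch

class Batch_1D(NamedTuple):
    x: jnp.ndarray

class Batch_2D(NamedTuple):
    x: jnp.ndarray

@typedispatch
def flatten_batch_dim(batch: Batch_2D):
    # 2D batch: merge the two leading batch dimensions.
    return Batch_1D(batch.x.reshape((-1,) + batch.x.shape[2:]))

@typedispatch
def flatten_batch_dim(batch: Batch_1D):
    # 1D batch: nothing to do.
    return batch

print(flatten_batch_dim(Batch_2D(x=jnp.ones((16, 32, 5)))).x.shape)  # (512, 5)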

@patrick-kidger
Collaborator

Have you tried plum?

I've not tested to see how it works under jax.jit. But I think the argument to be made here is that function dispatching is really a Python thing, not a JAX thing.

(In passing -- there's been some offline discussion about trying to get plum working with jaxtyping, which would be pretty neat, and may solve your use case.)

@simon-bachhuber
Author

Can confirm that plum (including multiple dispatch) does indeed work perfectly under jit.
Very cool!
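
For reference, a minimal sketch of what that looks like (untested in this exact form; it assumes plum's dispatch decorator and re-defines the Batch classes from the original post):

import jax
import jax.numpy as jnp
from typing import NamedTuple
from plum import dispatch

class Batch_1D(NamedTuple):
    x: jnp.ndarray

class Batch_2D(NamedTuple):
    x: jnp.ndarray

@dispatch
def _flatten_batch_dim(batch: Batch_1D) -> Batch_1D:
    # 1D batch: nothing to do.
    return batch

@dispatch
def _flatten_batch_dim(batch: Batch_2D) -> Batch_1D:
    # 2D batch: merge the two leading batch dimensions.
    return Batch_1D(batch.x.reshape((-1,) + batch.x.shape[2:]))

@jax.jit
def flatten_batch_dim(batch):
    # Dispatch resolves on the NamedTuple container type, which survives tracing.
    return _flatten_batch_dim(batch)

print(flatten_batch_dim(Batch_2D(x=jnp.ones((16, 32, 5)))).x.shape)  # (512, 5)
print(flatten_batch_dim(Batch_1D(x=jnp.ones((16, 32, 5)))).x.shape)  # (16, 32, 5)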

@simon-bachhuber
Author

simon-bachhuber commented Sep 7, 2022

Unfortunately, both plum and functools cannot handle single dispatch (or multiple dispatch) under a vmap transformation.

Example:

import functools as ft
import jax
import jax.numpy as jnp

@ft.singledispatch
def f(x):
    raise NotImplementedError()

@f.register 
def _(x: jnp.ndarray):
    print("Array-operation")

@jax.vmap 
def g(x: jnp.ndarray):
    return f(x)

g(jnp.ones((10,1)))

# raises NotImplementedError

Edit: *fixed typo

@jakevdp
Collaborator

jakevdp commented Sep 7, 2022

The reason this doesn't work is likely that jnp.ndarray is not actually in the class hierarchy of either arrays or tracers; you'd probably have to annotate with jax.DeviceArray and jax.core.Tracer directly if you want this dispatch method to work. We're exploring changing that in #11859, but regardless of the outcome there, I'm still of the opinion that a multiple dispatch decorator is not a good fit to include in JAX.

@ASEM000

ASEM000 commented Sep 11, 2022

According to Jake, this should work:

import functools as ft
import jax
import jax.numpy as jnp

@ft.singledispatch
def f(x):
    raise NotImplementedError()


@f.register(jnp.ndarray)
@f.register(jax.core.Tracer)
def _(x):
    print("Array-operation")

@jax.vmap 
def g(x: jnp.ndarray):
    return f(x)

g(jnp.ones((10,1))) # Array-operation

Also, if you want multiple dispatch using only functools, try the following code snippet:
https://github.com/ASEM000/PyTreeClass/blob/main/pytreeclass/_src/dispatch.py
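
For illustration only, here is a rough sketch of the general idea (this is not the linked PyTreeClass implementation; it simply keys a registry on the tuple of argument types and uses only the standard library):

from functools import wraps

def multidispatch(fallback):
    # Minimal sketch: exact-type lookup on the tuple of argument types,
    # falling back to the decorated function when nothing is registered.
    registry = {}

    @wraps(fallback)
    def wrapper(*args):
        impl = registry.get(tuple(type(a) for a in args), fallback)
        return impl(*args)

    def register(*types):
        def decorator(impl):
            registry[types] = impl
            return impl
        return decorator

    wrapper.register = register
    return wrapper

@multidispatch
def add(x, y):
    raise NotImplementedError(f"add for {type(x)}, {type(y)}")

@add.register(int, int)
def _(x, y):
    return x + y

print(add(1, 2))  # 3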

It seems that the Python functools implementation is a bit faster:

# Tested on mac m1 CPU

from multipledispatch import dispatch as dispatch_md
from plum import dispatch as dispatch_plum
from functools import singledispatch as dispatch_std

@dispatch_md(int)
def f_md(x):
    return x


@dispatch_plum
def f_plum(x: int):
    return x


def f_native(x):
    return x

@dispatch_std
def f_std(x): ...

@f_std.register(int)
def _(x):
    return x

f_md(1); f_plum(1);  # Run once to populate cache.

%timeit f_native(1)
# 39.6 ns ± 0.629 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

%timeit f_md(1)
# 281 ns ± 2.08 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

%timeit f_plum(1)
# 337 ns ± 5.64 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

%timeit f_std(1)
# 267 ns ± 1.97 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

@simon-bachhuber
Author

Yes, I guess this does work. It's just a rather ugly stack of decorators, which somewhat defeats the original purpose (improved readability).
Either way, thank you all. I will close this.
