Open
Description
`jax.jit` requires some function arguments to be static (for example, `shape` in `jnp.reshape` and `length` in `jax.lax.scan`). Currently, if these arguments are symbolic inputs, `jax.jit` breaks in
https://github.com/pymc-devs/Theano-PyMC/blob/a9275c3dcc998c8cca5719037e493809b23422ff/theano/sandbox/jax_linker.py#L80
I propose we add an `is_static_jax` flag, `False` by default, to `TensorVariable` in:
diff --git a/theano/tensor/var.py b/theano/tensor/var.py
index 4cda4e5e1..6f2aaf398 100644
--- a/theano/tensor/var.py
+++ b/theano/tensor/var.py
@@ -872,6 +872,9 @@ class TensorVariable(_tensor_py_operators, Variable):
                 import pdb
 
                 pdb.set_trace()
+
+    # Set to True on an input to have JAXLinker pass it to `jax.jit` as static.
+    is_static_jax = False
 
 
 TensorType.Variable = TensorVariable
and to `SharedVariable` in:
diff --git a/theano/compile/sharedvalue.py b/theano/compile/sharedvalue.py
index cc3dd3cce..ca3e7af3b 100644
--- a/theano/compile/sharedvalue.py
+++ b/theano/compile/sharedvalue.py
@@ -224,6 +224,9 @@ class SharedVariable(Variable):
     # We keep this just to raise an error
     value = property(_value_get, _value_set)
 
+    # Set to True on an input to have JAXLinker pass it to `jax.jit` as static.
+    is_static_jax = False
+
 
 def shared_constructor(ctor, remove=False):
     if remove:
Then `JAXLinker` can detect the additional `static_argnums` in:
diff --git a/theano/sandbox/jax_linker.py b/theano/sandbox/jax_linker.py
index 59b61caf3..0093c3fa7 100644
--- a/theano/sandbox/jax_linker.py
+++ b/theano/sandbox/jax_linker.py
@@ -62,7 +62,11 @@ class JAXLinker(PerformLinker):
-        # I suppose we can consider `Constant`s to be "static" according to
-        # JAX.
+        # `Constant`s are "static" according to JAX.  Other inputs can be
+        # marked static by setting `is_static_jax = True` on them; `getattr`
+        # covers `Variable` types that don't define the flag.  Static values
+        # must be hashable, and each new value triggers a recompilation.
         static_argnums = [
-            n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
+            n
+            for n, i in enumerate(self.fgraph.inputs)
+            if isinstance(i, Constant) or getattr(i, "is_static_jax", False)
         ]
 
         thunk_inputs = [storage_map[n] for n in self.fgraph.inputs]
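For now, users will need to mark these variables by hand. For example, the following change makes the tests pass. Note that `jax.jit` recompiles the function for every distinct value of a static argument, and those values must be hashable: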
diff --git a/tests/sandbox/test_jax.py b/tests/sandbox/test_jax.py
index 89c46ff9b..c3c3d7225 100644
--- a/tests/sandbox/test_jax.py
+++ b/tests/sandbox/test_jax.py
@@ -534,10 +534,10 @@ def test_jax_Reshape():
     compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)])
 
 
-@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
 def test_jax_Reshape_nonconcrete():
     a = tt.vector("a")
     b = tt.iscalar("b")
+    b.is_static_jax = True
     x = tt.basic.reshape(a, (b, b))
     x_fg = theano.gof.FunctionGraph([a, b], [x])
     compare_jax_and_py(
@@ -666,10 +666,10 @@ def test_tensor_basics():
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
 
 
-@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
 def test_arange_nonconcrete():
 
     a = tt.scalar("a")
+    a.is_static_jax = True
     a.tag.test_value = 10
 
     out = tt.arange(a)
@@ -677,7 +677,6 @@ def test_arange_nonconcrete():
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])