
Add an is_static_jax property to TensorVariable's tag #182

Open
@junpenglao

Description


`jax.jit` requires some function arguments to be static (for example, the `shape` argument of `jnp.reshape` and the `length` argument of `jax.lax.scan`). Currently, if these arguments are symbolic inputs, `jax.jit` fails in
https://github.com/pymc-devs/Theano-PyMC/blob/a9275c3dcc998c8cca5719037e493809b23422ff/theano/sandbox/jax_linker.py#L80
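To illustrate the underlying JAX behavior outside Theano, here is a minimal sketch (the function names are hypothetical). A value that ends up inside a shape tuple must be concrete at trace time, so it only works under `jax.jit` when it is declared via `static_argnums`:

```python
import jax
import jax.numpy as jnp


def reshape_square(x, n):
    # `n` ends up inside a shape tuple, so JAX must know its value at trace time.
    return jnp.reshape(x, (n, n))


# Argument 1 is static: `n` is passed as a plain Python int and becomes part of
# the compilation cache key instead of being traced.
reshape_square_jit = jax.jit(reshape_square, static_argnums=1)

# Without static_argnums, `n` is traced as an abstract value and the reshape fails.
reshape_square_traced = jax.jit(reshape_square)


def fails_without_static(x, n):
    try:
        reshape_square_traced(x, n)
    except TypeError:
        # JAX's concretization/shape errors are subclasses of TypeError.
        return True
    return False


x = jnp.arange(4.0)
print(reshape_square_jit(x, 2).shape)  # (2, 2)
print(fails_without_static(x, 2))  # True
```

The same applies to `length` in `jax.lax.scan`, which must be a concrete Python int.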

I propose adding an `is_static_jax` flag (default `False`) to `TensorVariable`:

diff --git a/theano/tensor/var.py b/theano/tensor/var.py
index 4cda4e5e1..6f2aaf398 100644
--- a/theano/tensor/var.py
+++ b/theano/tensor/var.py
@@ -872,6 +872,8 @@ class TensorVariable(_tensor_py_operators, Variable):
 
                 pdb.set_trace()
 
+    # Set to True on an input to have JAXLinker treat it as static
+    is_static_jax = False
 
 TensorType.Variable = TensorVariable

and to `SharedVariable`:

diff --git a/theano/compile/sharedvalue.py b/theano/compile/sharedvalue.py
index cc3dd3cce..ca3e7af3b 100644
--- a/theano/compile/sharedvalue.py
+++ b/theano/compile/sharedvalue.py
@@ -224,6 +224,9 @@ class SharedVariable(Variable):
     # We keep this just to raise an error
     value = property(_value_get, _value_set)
 
+    # Set to True on an input to have JAXLinker treat it as static
+    is_static_jax = False
+
 
 def shared_constructor(ctor, remove=False):
     if remove:
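Declaring the flag as a plain class attribute rather than a method matters, because the linker reads `i.is_static_jax` without calling it, and a bound method object is always truthy. A minimal sketch with hypothetical stand-in classes:

```python
class FlagAsMethod:
    """Hypothetical class mirroring a `def is_static_jax(self)` definition."""

    def is_static_jax(self):
        return False


class FlagAsAttribute:
    """Hypothetical class mirroring an `is_static_jax = False` class attribute."""

    is_static_jax = False


# Read without calling it, a method is a bound-method object, which is truthy,
# so every variable would look static to the linker.
print(bool(FlagAsMethod().is_static_jax))  # True

# A class attribute gives a real False default ...
plain = FlagAsAttribute()
print(plain.is_static_jax)  # False

# ... and `var.is_static_jax = True` sets an instance attribute that shadows
# the default for that one variable only.
marked = FlagAsAttribute()
marked.is_static_jax = True
print(marked.is_static_jax, plain.is_static_jax)  # True False
```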

`JAXLinker` can then add the flagged inputs to `static_argnums`:

diff --git a/theano/sandbox/jax_linker.py b/theano/sandbox/jax_linker.py
index 59b61caf3..0093c3fa7 100644
--- a/theano/sandbox/jax_linker.py
+++ b/theano/sandbox/jax_linker.py
@@ -62,7 +62,9 @@ class JAXLinker(PerformLinker):
         # I suppose we can consider `Constant`s to be "static" according to
         # JAX.
         static_argnums = [
-            n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
+            n
+            for n, i in enumerate(self.fgraph.inputs)
+            if isinstance(i, Constant) or getattr(i, "is_static_jax", False)
         ]
 
         thunk_inputs = [storage_map[n] for n in self.fgraph.inputs]
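The selection logic can be checked in isolation with a self-contained sketch. `Variable`, `Constant`, and `TensorVariable` below are hypothetical stand-ins for the Theano classes, and `compute_static_argnums` is a hypothetical helper. Reading the flag with `getattr(..., False)` treats input types that never define it as non-static (we assume such inputs can reach the linker).

```python
class Variable:
    """Hypothetical stand-in for theano.gof.Variable (no is_static_jax flag)."""


class Constant(Variable):
    """Hypothetical stand-in for theano.gof.Constant."""


class TensorVariable(Variable):
    """Hypothetical stand-in for TensorVariable with the proposed flag."""

    is_static_jax = False


def compute_static_argnums(inputs):
    # Constants are always static; other inputs only if explicitly flagged.
    return [
        n
        for n, i in enumerate(inputs)
        if isinstance(i, Constant) or getattr(i, "is_static_jax", False)
    ]


a = TensorVariable()
b = TensorVariable()
b.is_static_jax = True  # marked by hand, as in the tests below
c = Constant()
d = Variable()  # defines no flag at all
print(compute_static_argnums([a, b, c, d]))  # [1, 2]
```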

For now, users will need to mark these variables by hand. For example, the following changes make the currently xfail-marked tests pass:

diff --git a/tests/sandbox/test_jax.py b/tests/sandbox/test_jax.py
index 89c46ff9b..c3c3d7225 100644
--- a/tests/sandbox/test_jax.py
+++ b/tests/sandbox/test_jax.py
@@ -534,10 +534,10 @@ def test_jax_Reshape():
     compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)])
 
 
-@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
 def test_jax_Reshape_nonconcrete():
     a = tt.vector("a")
     b = tt.iscalar("b")
+    b.is_static_jax = True
     x = tt.basic.reshape(a, (b, b))
     x_fg = theano.gof.FunctionGraph([a, b], [x])
     compare_jax_and_py(
@@ -666,10 +666,10 @@ def test_tensor_basics():
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
 
 
-@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
 def test_arange_nonconcrete():
 
     a = tt.scalar("a")
+    a.is_static_jax = True
     a.tag.test_value = 10
 
     out = tt.arange(a)
@@ -677,7 +677,6 @@ def test_arange_nonconcrete():
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
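In plain JAX, `test_arange_nonconcrete` corresponds roughly to the sketch below (`arange_fn` is hypothetical). One caveat worth checking, and an assumption here since it depends on what values the linker actually passes, is that JAX static arguments must be hashable. NumPy arrays, including 0-d test values, are not hashable, so they would need converting to Python scalars first.

```python
import numpy as np
import jax
import jax.numpy as jnp


def arange_fn(n):
    # The output length depends on `n`, so `n` must be concrete at trace time.
    return jnp.arange(n)


arange_jit = jax.jit(arange_fn, static_argnums=0)

# A Theano test value for a scalar is typically stored as a 0-d NumPy array ...
test_value = np.array(10.0, dtype=np.float32)

# ... but NumPy arrays are unhashable, so they cannot be static arguments as-is.
try:
    hash(test_value)
except TypeError:
    print("0-d ndarray is unhashable")

# Converting to a plain Python scalar makes it usable as a static argument.
out = arange_jit(test_value.item())
print(out.shape)  # (10,)
```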

Labels: JAX (Involves JAX transpilation), enhancement (New feature or request)
