-
Hi, not an answer, but I have a similar issue. I found (and was hoping) that jaxopt.Bisection solved my need for a way to find a slightly complicated root that is compatible with jax.vmap, jax.grad and, eventually, jax.hessian. I was told: "Most solvers in JAXopt don't work out of the box with scalars (maybe they should?) So typically, one needs to use an array of size 1 instead." So I need to find a workaround, or at least get a bit of clarity. jaxopt.Bisection produced the correct root individually for multiple tests, but when I later tried to vmap a broader piece of code, it turned out my problem lies with data types and jaxopt.Bisection. At this stage of my code almost everything is a JAX device array with dtype=jnp.float64, including single-valued arrays. What I thought would work did amazingly well on a single-datapoint basis:
It turns out .item() is not compatible with JAX tracing, so it breaks once jax.vmap or jax.grad is applied, even though it succeeds on a single case. On the other hand, skipping .item() and passing a jnp.array to Bisection causes boolean errors. For single use cases, all of the following are JAX device arrays with a single float64 entry. When applying vmap, I map over a JAX array of regime values to return an array of answers. I'm struggling to find a way past this apparent data-type problem with jaxopt.Bisection. Thanks,
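A minimal sketch of the .item() problem, using hypothetical toy functions (scale_with_item, scale_traceable) rather than the code from this thread: .item() asks for a concrete Python number, which the tracer used by jax.vmap or jax.grad cannot supply, whereas keeping the value as a 0-d JAX array stays traceable. This assumes a recent JAX in which the failure surfaces as jax.errors.ConcretizationTypeError.

```python
import jax
import jax.numpy as jnp

def scale_with_item(x):
    # .item() asks for a concrete Python float, which a tracer cannot supply.
    return 2.0 * x.item()

def scale_traceable(x):
    # Keeping x as a 0-d JAX array lets vmap/grad/jit trace through it.
    return 2.0 * x

xs = jnp.array([1.0, 2.0, 3.0])

# On one concrete value both versions agree.
print(scale_with_item(xs[0]), scale_traceable(xs[0]))

# Under vmap the .item() version fails because the input is now a tracer.
try:
    jax.vmap(scale_with_item)(xs)
    item_failed = False
except jax.errors.ConcretizationTypeError:
    item_failed = True

print(item_failed)                       # True
print(jax.vmap(scale_traceable)(xs))     # [2. 4. 6.]
print(jax.grad(scale_traceable)(xs[0]))  # 2.0
```

The takeaway is to keep single values as 0-d arrays all the way through the traced code and only call .item() (or float()) outside of vmap/grad/jit.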
-
Half my problem is fixed. For the single case, removing the dtype on t_initial works: alt_pytree now accepts 0-dim JAX arrays (scalars).
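Going the other way, if a solver does want size-1 arrays (as in the advice quoted in the first comment), the conversion between a 0-d scalar and a shape-(1,) array can stay inside traced code. A minimal sketch; scalar_fun and as_size1 are hypothetical names for illustration, not JAXopt API:

```python
import jax
import jax.numpy as jnp

def scalar_fun(x, theta):
    # Hypothetical scalar optimality function with root x = sqrt(theta).
    return x ** 2 - theta

def as_size1(fun):
    """Hypothetical helper: adapt a scalar function to shape-(1,) inputs/outputs."""
    def wrapped(x1, theta):
        x = x1[0]                                # (1,) -> 0-d, traceable indexing
        return jnp.reshape(fun(x, theta), (1,))  # 0-d -> (1,), also traceable
    return wrapped

fun1 = as_size1(scalar_fun)
x1 = jnp.atleast_1d(jnp.asarray(3.0))  # 0-d scalar -> shape (1,)

print(x1.shape, fun1(x1, 4.0))  # (1,) [5.]

# The conversions are ordinary traced ops, so vmap and grad still go through.
thetas = jnp.array([1.0, 4.0, 9.0])
batched = jax.vmap(fun1, in_axes=(None, 0))(x1, thetas)  # shape (3, 1)
slope = jax.grad(lambda v: fun1(v, 4.0)[0])(x1)          # d/dx (x**2 - 4) at x = 3
```

Indexing with [0] and jnp.reshape are safe under tracing, unlike .item(), because they return JAX arrays rather than Python numbers.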
-
Applying vmap returns an error:
    state = self.init_state(init_params, *args, **kwargs)
  File "/opt/anaconda3/lib/python3.8/site-packages/jaxopt/_src/bisection.py", line 105, in init_state
  File "/opt/anaconda3/lib/python3.8/site-packages/jax/core.py", line 634, in __bool__
  File "/opt/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1267, in error
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<BatchTrace(level=1/1)> with
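My reading of this traceback (an interpretation, not confirmed against the JAXopt source): the __bool__ frame means a Python if is being evaluated on a traced boolean inside init_state, presumably a check that the two bracket endpoints have opposite signs. Under vmap that comparison is a BatchTracer with no concrete value, hence the ConcretizationTypeError. The sketch below reproduces the pattern with hypothetical functions (bracket_ok_python, bracket_ok_traced) and shows the traceable alternative.

```python
import jax
import jax.numpy as jnp

def bracket_ok_python(f_lower, f_upper):
    # A Python `if` needs a concrete bool; under vmap the comparison is a tracer.
    if jnp.sign(f_lower) == jnp.sign(f_upper):
        return jnp.asarray(False)
    return jnp.asarray(True)

def bracket_ok_traced(f_lower, f_upper):
    # The same test as an array op, so it stays traceable under vmap/jit.
    return jnp.sign(f_lower) != jnp.sign(f_upper)

f_lo = jnp.array([-1.0, 2.0, -3.0])
f_hi = jnp.array([1.0, 5.0, 4.0])

# A single concrete case works, as in the thread.
print(bracket_ok_python(f_lo[0], f_hi[0]))  # True

try:
    jax.vmap(bracket_ok_python)(f_lo, f_hi)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    python_if_failed = True

valid = jax.vmap(bracket_ok_traced)(f_lo, f_hi)  # values: True, False, True
```

If I remember the JAXopt documentation correctly, jaxopt.Bisection has a check_bracket argument (default True), and the solver cannot be jitted while it is enabled; passing check_bracket=False may be enough to get past this error under vmap, but check the signature of the installed version first.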
-
I restructured and rewrote the entire code block so that jax.vmap is applied more directly to the code that calls jaxopt.Bisection. I still get the same concretization (boolean) error from inside jaxopt.Bisection. The Bisection solver works for a single input set but does not appear to work with vmap.
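If the solver itself cannot be traced, one workaround is a bisection loop with no Python control flow on traced values: a fixed number of iterations with lax.fori_loop, and jnp.where to pick the half-interval. This is a minimal sketch, not JAXopt's implementation; optimality_fun, the bracket [0, 10] and num_iters=60 are assumptions for illustration, and the code assumes f(lower) and f(upper) have opposite signs without checking it.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)  # the thread works in float64

def optimality_fun(x, theta):
    # Hypothetical stand-in for the real function: root at x = sqrt(theta).
    return x ** 2 - theta

def bisect(theta, lower=0.0, upper=10.0, num_iters=60):
    """Fixed-iteration bisection with no Python branching on traced values."""
    lo = jnp.full_like(theta, lower)
    hi = jnp.full_like(theta, upper)

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # Move the lower end if f(mid) has the same sign as f(lo),
        # otherwise move the upper end; jnp.where keeps this traceable.
        move_lo = jnp.sign(optimality_fun(mid, theta)) == jnp.sign(optimality_fun(lo, theta))
        return jnp.where(move_lo, mid, lo), jnp.where(move_lo, hi, mid)

    lo, hi = lax.fori_loop(0, num_iters, body, (lo, hi))
    return 0.5 * (lo + hi)

thetas = jnp.array([2.0, 3.0, 4.0])
roots = jax.vmap(bisect)(thetas)  # approx. sqrt(2), sqrt(3), 2
```

The fixed iteration count trades early stopping for traceability: every vmapped element runs the same number of steps, which is what lets vmap and jit batch the loop.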
-
@yyang97 Indeed, second-order derivatives via implicit diff are not supported yet in JAXopt :( I'm curious, what is your use case?
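For reference, this is what second-order implicit differentiation has to compute for a scalar root $x^*(\theta)$ defined by $F(x^*(\theta), \theta) = 0$, where $F$ is a generic optimality function (a placeholder, not a JAXopt name) and subscripts denote partial derivatives evaluated at $(x^*(\theta), \theta)$. Differentiating the identity once, and then once more, gives:

$$
F_x\,x' + F_\theta = 0 \quad\Longrightarrow\quad x' = -\frac{F_\theta}{F_x}
$$

$$
F_{xx}\,(x')^2 + 2F_{x\theta}\,x' + F_{\theta\theta} + F_x\,x'' = 0 \quad\Longrightarrow\quad x'' = -\frac{F_{xx}\,(x')^2 + 2F_{x\theta}\,x' + F_{\theta\theta}}{F_x}
$$

So the second derivative needs the second partials of $F$ on top of the first-order solve, and both formulas assume $F_x \neq 0$ at the root. For example, $F(x, \theta) = x^2 - \theta$ gives $x' = 1/(2x^*)$ and $x'' = -1/(4{x^*}^3)$, i.e. $1/4$ and $-1/32$ at $\theta = 4$.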
-
My project involves computing the Hessian of a root-finding result, and I just found that jaxopt.Bisection() does not support that... Am I implementing it incorrectly, or does it just not support the Hessian?
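Per the maintainer's reply, JAXopt's implicit differentiation does not yet handle second-order derivatives. One workaround sketch, outside JAXopt, is to wrap a forward-only bisection in jax.custom_jvp and give it the first-order implicit-function rule x' = -F_theta / F_x. Because that rule is written with differentiable JAX ops and calls root itself, JAX can differentiate it again, so jax.hessian goes through. F, _bisect, root and the bracket [0, 10] are hypothetical stand-ins; the example root is sqrt(theta), so the expected derivatives at theta = 4 are 1/4 and -1/32.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)

def F(x, theta):
    # Hypothetical optimality function; its root is x*(theta) = sqrt(theta).
    return x ** 2 - theta

def _bisect(theta, lower=0.0, upper=10.0, num_iters=60):
    # Plain fixed-iteration bisection: forward pass only, never differentiated.
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        move_lo = jnp.sign(F(mid, theta)) == jnp.sign(F(lo, theta))
        return jnp.where(move_lo, mid, lo), jnp.where(move_lo, hi, mid)
    init = (jnp.full_like(theta, lower), jnp.full_like(theta, upper))
    lo, hi = lax.fori_loop(0, num_iters, body, init)
    return 0.5 * (lo + hi)

@jax.custom_jvp
def root(theta):
    return _bisect(theta)

@root.defjvp
def root_jvp(primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    # Calling root (not _bisect) keeps this rule differentiable a second time.
    x = root(theta)
    # Implicit function theorem: x' = -F_theta / F_x at (x, theta).
    F_x = jax.grad(F, argnums=0)(x, theta)
    F_theta = jax.grad(F, argnums=1)(x, theta)
    return x, -F_theta / F_x * theta_dot

theta = jnp.array(4.0)
x_star = root(theta)                     # approx. 2.0
dx = jax.grad(root)(theta)               # approx. 0.25
d2x = jax.hessian(root)(theta)           # approx. -0.03125
batched_d2x = jax.vmap(jax.hessian(root))(jnp.array([1.0, 4.0, 9.0]))
```

The returned tangent is linear in theta_dot, which is what reverse mode (jax.grad, and the jacrev inside jax.hessian) requires of a custom_jvp rule.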