-
Notifications
You must be signed in to change notification settings - Fork 6
Open
Description
from typing import Any

import genjax
import jax
import jax.numpy as jnp
from genjax import ChoiceMapBuilder as C
@genjax.Pytree.dataclass
class Wrapped(genjax.Pytree):
    """Pytree container holding a single payload; the sample type of ``mydist``."""

    # Annotated with ``typing.Any`` rather than the builtin function ``any``,
    # which is not a type. The payload is normally an array, but in the
    # reported failure it arrives as a genjax ``Mask``.
    value: Any
def my_logpdf(obs: Wrapped, arg):
    """Log-density used by ``mydist``.

    Yields ``arg`` wherever the wrapped payload equals ``arg`` and ``-inf``
    everywhere else.
    """
    # Diagnostic print of the payload dtype. In the reported failure this is
    # the line that raises AttributeError, because ``obs.value`` is a Mask.
    print(obs.value.dtype)
    # NOTE(review): a point mass would normally score 0.0 (log 1) on a match;
    # returning ``arg`` itself looks unintended — confirm.
    hit = obs.value == arg
    return jnp.where(hit, arg, -jnp.inf)
def _sample_wrapped(key, val):
    """Sampler for ``mydist``: ignore ``key`` and wrap ``val`` in ``Wrapped``."""
    return Wrapped(val)


# Distribution whose samples are ``Wrapped`` pytrees, scored by ``my_logpdf``.
mydist = genjax.exact_density(_sample_wrapped, my_logpdf)
@genjax.gen
def foo(args):
    """Draw ``mydist`` once per entry along axis 0 of ``args``.

    The batched draw is recorded at address ``"vals"``; nothing is returned.
    """
    batched_dist = mydist.vmap(in_axes=(0,))
    batched_dist(args) @ "vals"
# Constrain only entries 0 and 1 of the vmapped "vals" address while ``args``
# has 5 entries. This partial constraint is presumably why genjax hands
# ``my_logpdf`` a Mask-wrapped value (see the traceback below).
tr, wt = foo.importance(
    jax.random.key(1),
    C["vals", jnp.arange(2)].set(Wrapped(jnp.arange(2))),
    (jnp.arange(5),),
)
# This results in:
Cell In[52], line 8
      7 def my_logpdf(obs: Wrapped, arg):
----> 8     print(obs.value.dtype)
      9     return jnp.where(obs.value == arg, arg, -jnp.inf)
AttributeError: 'Mask' object has no attribute 'dtype'

If I remove the `print(obs.value.dtype)` line, other error messages occur instead.
Metadata
Metadata
Assignees
Labels
No labels