HalfNormal in JAX failing due to implicit downcasting of constant 0d TensorVariable to float #373

Closed
jessegrabowski opened this issue Jul 4, 2023 · 3 comments · Fixed by #374

jessegrabowski commented Jul 4, 2023

Describe the issue:

You can't forward-sample from a half-normal distribution in JAX mode.

Reproducible code example:

import pymc as pm
from pymc.pytensorf import get_mode

with pm.Model() as mod:
    x = pm.HalfNormal('x')
    prior = pm.sample_prior_predictive(compile_kwargs={'mode':get_mode('JAX')})

Error message:

AttributeError                            Traceback (most recent call last)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
    199 for thunk, node, old_storage in zip(
    200     thunks, order, post_thunk_old_storage
    201 ):
--> 202     thunk()
    203     for old_s in old_storage:

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    663 def thunk(
    664     fgraph=self.fgraph,
    665     fgraph_jit=fgraph_jit,
    666     thunk_inputs=thunk_inputs,
    667     thunk_outputs=thunk_outputs,
    668 ):
--> 669     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    671     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):

    [... skipping hidden 12 frame]

File /tmp/tmp3jdhf32h:3, in jax_funcified_fgraph(random_generator_shared_variable)
      1 def jax_funcified_fgraph(random_generator_shared_variable):
      2     # Second(0.0, 0.0)
----> 3     tensor_variable = second(tensor_constant, tensor_constant_1)
      4     # normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FD6A9F4D8C0>), [], 11, Second.0, 1.0)

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scalar.py:184, in jax_funcify_Second.<locals>.second(x, y)
    183 def second(x, y):
--> 184     return jnp.broadcast_to(y, x.shape)

AttributeError: 'float' object has no attribute 'shape'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
Cell In[54], line 3
      1 with pm.Model() as mod:
      2     x = pm.HalfNormal('x')
----> 3     prior = pm.sample_prior_predictive(compile_kwargs={'mode':get_mode('JAX')})

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/forward.py:425, in sample_prior_predictive(samples, model, var_names, random_seed, return_inferencedata, idata_kwargs, compile_kwargs)
    423 # All model variables have a name, but mypy does not know this
    424 _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}")  # type: ignore
--> 425 values = zip(*(sampler_fn() for i in range(samples)))
    427 data = {k: np.stack(v) for k, v in zip(names, values)}
    428 if data is None:

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/forward.py:425, in <genexpr>(.0)
    423 # All model variables have a name, but mypy does not know this
    424 _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}")  # type: ignore
--> 425 values = zip(*(sampler_fn() for i in range(samples)))
    427 data = {k: np.stack(v) for k, v in zip(names, values)}
    428 if data is None:

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    967 t0_fn = time.perf_counter()
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:
    975     restore_defaults()

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:206, in streamline.<locals>.streamline_default_f()
    204             old_s[0] = None
    205 except Exception:
--> 206     raise_with_op(fgraph, node, thunk)

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    530     warnings.warn(
    531         f"{exc_type} error does not allow us to add an extra error message"
    532     )
    533     # Some exception need extra parameter in inputs. So forget the
    534     # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
    198 try:
    199     for thunk, node, old_storage in zip(
    200         thunks, order, post_thunk_old_storage
    201     ):
--> 202         thunk()
    203         for old_s in old_storage:
    204             old_s[0] = None

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    663 def thunk(
    664     fgraph=self.fgraph,
    665     fgraph_jit=fgraph_jit,
    666     thunk_inputs=thunk_inputs,
    667     thunk_outputs=thunk_outputs,
    668 ):
--> 669     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    671     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
    672         compute_map[o_var][0] = True

    [... skipping hidden 12 frame]

File /tmp/tmp3jdhf32h:3, in jax_funcified_fgraph(random_generator_shared_variable)
      1 def jax_funcified_fgraph(random_generator_shared_variable):
      2     # Second(0.0, 0.0)
----> 3     tensor_variable = second(tensor_constant, tensor_constant_1)
      4     # normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FD6A9F4D8C0>), [], 11, Second.0, 1.0)
      5     variable, tensor_variable_1 = sample_fn(random_generator_shared_variable, tensor_constant_2, tensor_constant_3, tensor_variable, tensor_constant_4)

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scalar.py:184, in jax_funcify_Second.<locals>.second(x, y)
    183 def second(x, y):
--> 184     return jnp.broadcast_to(y, x.shape)

AttributeError: 'float' object has no attribute 'shape'
Apply node that caused the error: Add(Abs.0, 0.0)
Toposort index: 3
Inputs types: [TensorType(float64, shape=()), TensorType(float32, shape=())]
Inputs shapes: ['No shapes']
Inputs strides: ['No strides']
Inputs values: [{'bit_generator': 1, 'state': {'state': -4621532023338195650, 'inc': 8471148850022962065}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([3218933020, 1456842046], dtype=uint32)}]
Outputs clients: [['output']]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

PyMC version information:

PyMC version: 5.5.0
PyTensor version: 2.12.3

Context for the issue:

No response

@jessegrabowski added the bug label Jul 4, 2023
@ricardoV94 transferred this issue from pymc-devs/pymc Jul 7, 2023

ricardoV94 commented Jul 7, 2023

We're calling broadcast_to on a scalar in JAX, it seems.

This probably has to do with how, in the JAX backend, we don't respect the distinction between scalars and 0d arrays (does JAX even allow it?).
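
To make the distinction concrete, a minimal sketch (hypothetical illustration, not from the original report): jnp.broadcast_to is fine with a 0d array's shape, but a plain Python float, which is what the constant gets downcast to, has no .shape attribute at all.

import jax.numpy as jnp

x_float = 0.0            # plain Python float: no .shape attribute
x_0d = jnp.asarray(0.0)  # JAX 0d array: shape ()

print(x_0d.shape)                          # ()
print(jnp.broadcast_to(0.0, x_0d.shape))   # fine: returns a 0d Array

# The dispatched second(x, y) calls jnp.broadcast_to(y, x.shape), so when
# x arrives as a Python float it raises:
#   AttributeError: 'float' object has no attribute 'shape'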

@ricardoV94 added the jax label Jul 7, 2023
@ricardoV94

Something like this also showed up in #372

ricardoV94 commented Jul 7, 2023

The solution is to either revert 4235ccc or tweak the JAX Broadcast funcify to work with scalars.
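
A minimal sketch of the second option, assuming the fix goes into the second helper shown in the traceback (jnp.shape accepts Python scalars, returning (), unlike the .shape attribute):

import jax.numpy as jnp

def second(x, y):
    # jnp.shape works for Python scalars and 0d arrays alike,
    # whereas x.shape raises AttributeError on a plain float
    return jnp.broadcast_to(y, jnp.shape(x))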

The source of the problem is actually silly: we have a BroadcastTo(np.array(0), np.array(0).shape) in the graph that should be optimized away. I think this specific case may no longer be introduced after #361.
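
For reference, the offending pattern can be built directly (a hypothetical reproduction, assuming pytensor 2.12's broadcast_to as in the report):

import numpy as np
import pytensor.tensor as pt

x = pt.as_tensor(np.array(0))
# A no-op broadcast of a constant to its own (empty) shape; this is the
# BroadcastTo node that should be rewritten away before it reaches JAX
y = pt.broadcast_to(x, np.array(0).shape)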

The compiled graph looks like

Add [id A] <Scalar(float64, shape=())> 3
 ├─ Abs [id B] <Scalar(float64, shape=())> 2
 │  └─ normal_rv{0, (0, 0), floatX, False}.1 [id C] <Scalar(float64, shape=())> 1
 │     ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F582FC4A820>) [id D] <RandomGeneratorType>
 │     ├─ [] [id E] <Vector(int64, shape=(0,))>
 │     ├─ 11 [id F] <Scalar(int64, shape=())>
 │     ├─ Second [id G] <Scalar(int8, shape=())> 0
 │     │  ├─ 0 [id H] <Scalar(int8, shape=())>
 │     │  └─ 0 [id I] <Scalar(int8, shape=())>
 │     └─ 1 [id J] <Scalar(int8, shape=())>
 └─ 0 [id H] <Scalar(int8, shape=())>

Second is the broadcast operation.

I'm just not spinning up a PR immediately because we should probably discuss whether we want to implicitly downcast constant 0d arrays to floats/integers in JAX or keep the types consistent. We could always tweak the specific Op dispatch functions to handle constant scalar TensorVariables in a special way.

This touches on a more general question of how to handle scalars in our graphs, one that also applies to other backends. See #107 and #349.

@ricardoV94 changed the title from “BUG: Forward sampling fails for half-normal distribution when mode="JAX"” to “HalfNormal in JAX failing due to implicit downcasting of constant 0d TensorVariable to float” Jul 7, 2023