Get rid of floatX calls inside distributions #6675

Closed
Tracked by #7053
ricardoV94 opened this issue Apr 14, 2023 · 2 comments · Fixed by #7114

Comments

@ricardoV94
Member

Description

These floatX calls cause some valid tensor_like inputs, such as a plain list of tensor variables, to fail:

```python
import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x")
    y = pm.Beta("y", [x+1, x-1], 1)  # Fails
    y = pm.Beta("y", pm.math.stack([x+1, x-1]), 1)  # works
```
```
TypeError: float() argument must be a string or a real number, not 'TensorVariable'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ricardo/.conda/envs/pymc-experimental/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3505, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-8-9eb4432272c4>", line 5, in <module>
    y = pm.Beta("y", [x+1, x-1], 1)
  File "/home/ricardo/.conda/envs/pymc-experimental/lib/python3.10/site-packages/pymc/distributions/distribution.py", line 310, in __new__
    rv_out = cls.dist(*args, **kwargs)
  File "/home/ricardo/.conda/envs/pymc-experimental/lib/python3.10/site-packages/pymc/distributions/continuous.py", line 1142, in dist
    alpha = pt.as_tensor_variable(floatX(alpha))
  File "/home/ricardo/.conda/envs/pymc-experimental/lib/python3.10/site-packages/pymc/pytensorf.py", line 439, in floatX
    return np.asarray(X, dtype=pytensor.config.floatX)
ValueError: setting an array element with a sequence.
```

PyTensor has a built-in mechanism for controlling how inputs are promoted to float, and users should be able to rely on it instead of having every distribution parameter forced through floatX.
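The failure mode can be reproduced with NumPy alone. The sketch below re-creates floatX from the line shown in the traceback (assuming float64 for pytensor.config.floatX) and uses a hypothetical SymbolicStandIn class in place of a TensorVariable. Like a symbolic variable, it has no concrete numeric value, so NumPy cannot coerce it eagerly.

```python
import numpy as np

# Assumption: stands in for pytensor.config.floatX
FLOATX = "float64"


def floatX(X):
    # Mirrors the traceback line: np.asarray(X, dtype=pytensor.config.floatX)
    return np.asarray(X, dtype=FLOATX)


class SymbolicStandIn:
    """Hypothetical placeholder for a symbolic TensorVariable.

    It has no __float__ method, so float() (and therefore NumPy's
    float coercion) cannot convert it.
    """

    def __repr__(self):
        return "SymbolicStandIn()"


# Concrete numbers are converted without problems.
ok = floatX([2.0, 0.0])

# A list holding symbolic entries cannot be turned into a float array.
try:
    floatX([SymbolicStandIn(), SymbolicStandIn()])
    failed = False
except (TypeError, ValueError) as err:
    failed = True
    print(type(err).__name__, err)
```

The exact exception type and message depend on the NumPy and Python versions, which is why both TypeError and ValueError are caught.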

@jaharvey8
Contributor

jaharvey8 commented Apr 15, 2023

I could help out here. But to make sure I understand what's needed: each distribution has a number of

as_tensor_variable(floatX(something))

calls, which need to be changed to

as_tensor_variable(something)

correct?

@ricardoV94
Member Author

ricardoV94 commented Apr 17, 2023

Yes. It may break some tests, but we can assess those case by case when that happens.

2 participants