
Deprecate test_value machinery #447

Open

ricardoV94 opened this issue Sep 21, 2023 · 11 comments

Labels: help wanted, maintenance


@ricardoV94 (Member)

Description

The test_value machinery adds a lot of complexity for little user benefit (most users don't know about this functionality in the first place).

@ricardoV94 (Member, Author) commented Dec 10, 2023

A first pass would be to just put FutureWarnings in the right places: when the compute_test_value flag is changed, or when test values are accessed from tags.

As a smaller-scope alternative, it would be nice to have a helper that computes intermediate values for all variables in a graph, so we can show them in dprint.

Something like:

def eval_intermediate_values(
    variables: Union[Sequence[Variable], FunctionGraph],
    vars_to_values: Mapping[Variable, Any],
) -> Mapping[Variable, Any]:
    ...

For instance:

x = pt.scalar("x")
y = x - 1
z = pt.log(y)

eval_intermediate_values(z, {x: 0.5})
# {x: 0.5, y: -0.5, z: nan}
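
A minimal sketch of how such a helper could work (purely illustrative; the graph walk and the single compiled function returning all intermediates are assumptions, not existing PyTensor API):

```python
# Purely illustrative sketch of the proposed helper; not part of PyTensor.
from typing import Any, Dict, Mapping, Sequence, Union

import pytensor
from pytensor.graph.basic import Variable
from pytensor.graph.fg import FunctionGraph


def eval_intermediate_values(
    variables: Union[Variable, Sequence[Variable], FunctionGraph],
    vars_to_values: Mapping[Variable, Any],
) -> Dict[Variable, Any]:
    """Evaluate every non-constant variable between the given inputs and outputs."""
    if isinstance(variables, FunctionGraph):
        outputs = list(variables.outputs)
    elif isinstance(variables, Variable):
        outputs = [variables]
    else:
        outputs = list(variables)

    inputs = list(vars_to_values)

    # Depth-first walk from the outputs toward the provided inputs,
    # collecting every variable encountered along the way.
    seen = []
    stack = list(outputs)
    while stack:
        var = stack.pop()
        if var in seen:
            continue
        seen.append(var)
        if var.owner is not None and var not in inputs:
            stack.extend(var.owner.inputs)

    # Keep the inputs and all derived variables; skip free constants.
    targets = [v for v in seen if v in inputs or v.owner is not None]

    # One compiled function that returns every intermediate value at once.
    fn = pytensor.function(inputs, targets, on_unused_input="ignore")
    results = fn(*(vars_to_values[v] for v in inputs))
    return dict(zip(targets, results))
```

Called as in the example above, this would return the values of x, y, and z keyed by the corresponding variables, so the first nan can be traced back to the op that produced it.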

@ferrine (Member) commented Dec 11, 2023

> As a smaller-scope alternative, it would be nice to have a helper that computes intermediate values for all variables in a graph, so we can show them in dprint.

Can this approach lead to out-of-memory problems in some scenarios?

@ferrine (Member) commented Dec 11, 2023

Moreover, eval_intermediate_values seems of little use: for small graphs you can evaluate things by hand, and for large graphs there is no way to tell which variable corresponds to what or which one leads to the nans.

@ricardoV94 (Member, Author)

> Moreover, eval_intermediate_values seems of little use: for small graphs you can evaluate things by hand, and for large graphs there is no way to tell which variable corresponds to what or which one leads to the nans.

The idea is that you could then see the values in dprint, which you can already do with test values. That's useful because it shows which operations produced the nans.

@ricardoV94 (Member, Author)

> Can this approach lead to out-of-memory problems in some scenarios?

This wouldn't take up more memory than the current test value approach, so I don't think it's an important concern.

@ricardoV94 (Member, Author)

> The idea is that you could then see the values in dprint, which you can already do with test values. That's useful because it shows which operations produced the nans.

Apparently I imagined that functionality. I still think it could be worth exploring, but that can be done in a separate issue.

@ricardoV94 (Member, Author)

Here is the kind of thing I had in mind (mocked-up dprint output, with each node's intermediate value shown next to its id):

import numpy as np
import pytensor
import pytensor.tensor as pt

pytensor.config.compute_test_value = "warn"
x = pt.vector("x")
x.tag.test_value = np.array([1.0, -2.0, 3.0])
y = pt.exp(pt.log(pt.tanh(x * 2)) + 3).sum()

pytensor.dprint(y)
# Sum{axes=None} [id A]nan
#  └─ Exp [id B][19.36301155         nan 20.08529011]
#     └─ Add [id C][2.96336463        nan 2.99998771]
#        ├─ Log [id D][-3.66353747e-02             nan -1.22884247e-05]
#        │  └─ Tanh [id E][ 0.96402758 -0.9993293   0.99998771]
#        │     └─ Mul [id F][ 2. -4.  6.]
#        │        ├─ x [id G][ 1. -2.  3.]
#        │        └─ ExpandDims{axis=0} [id H][2]
#        │           └─ 2 [id I]
#        └─ ExpandDims{axis=0} [id J][3]
#           └─ 3 [id K]
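
Until something like that exists in dprint itself, a rough approximation would be to reuse the hypothetical eval_intermediate_values helper sketched above and print each variable next to its computed value:

```python
# Illustrative only: inspect every intermediate value without touching dprint,
# using the hypothetical eval_intermediate_values helper sketched earlier.
values = eval_intermediate_values(y, {x: np.array([1.0, -2.0, 3.0])})
for var, val in values.items():
    # Unnamed intermediates print as e.g. Tanh.0, Log.0, Exp.0
    print(var, "->", val)
```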

@ricardoV94 (Member, Author)

Here is another idea for providing more useful test_value-like machinery that need not be so ingrained in the PyTensor codebase: https://gist.github.com/ricardoV94/e8902b4c35c26e87e189ab477f8d9288

@Dhruvanshu-Joshi (Member)

Hi @ricardoV94
So in the following lines:

if config.compute_test_value != "off":
    compute_test_value(node)

do we want to add a warning whenever config.compute_test_value != "off", or only in the cases where compute_test_value is actually called?

Going by the first logic:

action = config.compute_test_value
if action == "raise":
    raise TestValueError(msg)
elif action == "warn":
    warnings.warn(msg, stacklevel=2)
else:
    assert action in ("ignore", "off")

we would raise a warning here only when assert action in ("ignore", "off") fails, right? (I am not sure we should raise a warning when the assert itself fails.)

@Dhruvanshu-Joshi (Member) commented Jun 13, 2024

I think the warnings make sense when config.compute_test_value != "off" and when tag.test_value is accessed (which automatically includes the get_test_value cases). Is this correct?

@ricardoV94 (Member, Author)

Yes, that's probably a good start. Then the challenging part is making sure we don't use those anywhere internally, other than in direct tests of the test_value machinery (and in those we just add a with pytest.warns every time we expect the functionality to be used).
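
A rough sketch of what that first pass could look like (function names and placement here are assumptions for illustration, not the actual PyTensor code):

```python
# Hypothetical sketch: emit FutureWarnings (a) when compute_test_value is switched
# away from "off" and (b) when a test value is read from a variable's tag.
import warnings


def _warn_test_value_deprecated(what: str) -> None:
    warnings.warn(
        f"{what} is part of the test_value machinery, which is deprecated "
        "and will be removed in a future release of PyTensor.",
        FutureWarning,
        stacklevel=3,
    )


# (a) wherever the compute_test_value flag is consulted or changed
def check_compute_test_value_flag(value: str) -> None:
    if value != "off":
        _warn_test_value_deprecated("config.compute_test_value")


# (b) inside the helper that reads a test value off a variable's tag
def get_test_value(variable):
    _warn_test_value_deprecated("Variable.tag.test_value")
    return variable.tag.test_value
```

Internal tests that intentionally exercise the machinery would then wrap the relevant calls in with pytest.warns(FutureWarning), so the rest of the suite stays warning-clean.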
