-
Notifications
You must be signed in to change notification settings - Fork 129
Issue FututureWarnings
for deprecated test_value
machinery
#831
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Issue FututureWarnings
for deprecated test_value
machinery
#831
Conversation
pytensor/misc/pkl_utils.py
Outdated
warnings.warn( | ||
"compute_test_value is deprecated and will stop working in the future.", | ||
FutureWarning, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems too crude here?
pytensor/scan/basic.py
Outdated
warnings.warn( | ||
"test_value machinery is deprecated and will stop working in the future.", | ||
FutureWarning, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These warnings show up in too many places (not talking about this file/line specifically)? Should only be when users set the config flag or manually assign a test_value / call a function that is specific about test values?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The warning is raised whenever a user tries to do a x.tag.test_value
and whenever in the internal code the config.compute_test_value != "off"
. Are you suggesting that instead of raising the warning when internal code has config.compute_test_value != "off"
, I raise the warning when user explicitly sets config.compute_test_value != "off"
?
In this way, the warning will appear at less places within the code but will be appropriately raised.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we don't want it to become incredibly noisy. Just on the user entry points to tell them to stop using the machinery
6f9352c
to
8b740f6
Compare
FututreWarnings
to aid test_value
machinery deprecationFututureWarnings
for deprecated test_value
machinery
tests/compile/test_ops.py
Outdated
x.tag.test_value = np.zeros((2, 2)) | ||
y = dvector("y") | ||
y.tag.test_value = [0, 0, 0, 0] | ||
with pytest.warns(FutureWarning): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For tests that are not specific about test_values we should just stop using the machinery
tests/graph/test_op.py
Outdated
@@ -203,7 +212,7 @@ def test_get_test_values_success(): | |||
def test_get_test_values_exc(): | |||
"""Tests that `get_test_values` raises an exception when debugger is set to raise and a value is missing.""" | |||
|
|||
with pytest.raises(TestValueError): | |||
with pytest.raises(TestValueError) and pytest.warns(FutureWarning): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you and
contexts like this? I always nest them, but if it works fine. Did you test it/ see it working like this anywhere?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No I think this might be wrong. The tests passed so I assumed it works but turns out with
statements must be nested.
tests/link/jax/test_basic.py
Outdated
with pytest.warns(FutureWarning): | ||
a.tag.test_value = np.array(0.2, dtype=config.floatX) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These tests jax/numba don't need the tag stuff. I have no idea why they were written like this tbh
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So are you suggesting deleting the line a.tag.test_value = np.array(0.2, dtype=config.floatX)
from the test itself?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we are adding the value to the tag and then retrieving it when we call the function below. We can just pass it directly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about cases like this?
with pytest.warns(FutureWarning):
y = vector("y")
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
x = vector("x")
x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
A = matrix("A")
A.tag.test_value = np.empty((2, 2), dtype=config.floatX)
alpha = scalar("alpha")
alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
beta = scalar("beta")
beta.tag.test_value = np.array(5.0, dtype=config.floatX)
# This should be converted into a `Gemv` `Op` when the non-JAX compatible
# optimizations are turned on; however, when using JAX mode, it should
# leave the expression alone.
out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Here, setting the test_value
and then extracting the value from it actually comes handy when there are mutiple inputs to the function.
However we can also write this as:
with pytest.warns(FutureWarning):
y = vector("y")
x = vector("x")
A = matrix("A")
alpha = scalar("alpha")
beta = scalar("beta")
test_values = [np.r_[1.0, 2.0].astype(config.floatX), np.r_[3.0, 4.0].astype(config.floatX), np.empty((2, 2), dtype=config.floatX), np.array(3.0, dtype=config.floatX), np.array(5.0, dtype=config.floatX)]
# This should be converted into a `Gemv` `Op` when the non-JAX compatible
# optimizations are turned on; however, when using JAX mode, it should
# leave the expression alone.
out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
compare_jax_and_py(fgraph, [i for i in test_values])
which method should I go ahead with? @ricardoV94
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The one without test values is perfectly fine. If you want you can have a dictionary from variables to values, if that seems more readable
e944855
to
b45b107
Compare
HI @ricardoV94
Tests are failing in python3.12 with fast_compile==1. However, it works fine locally. |
Maybe a different numpy version? Did you turn FAST_COMPILE to True at the top of the test file when trying locally? |
Yep the FAST_COMPILE was not at the top of the code so I was not able to reproduce this locally. Now I can. pytensor/pytensor/graph/utils.py Lines 284 to 290 in 10f285a
However, this line is never reached in the FAST_COMPILE mode. |
I don't see why should test values be triggered in that test? They are not set explicitly? Is something inernally introducing the test value when FAST_COMPILE=False? Because we don't want the warnings to be triggered by internal use, only when users themselves start the test value machinery. We need to make sure we, ourselves, don't start it accidentally. |
Yes the
I understand that the warnings must be raised only when test_value is accessed explicitly.
in pytensor/pytensor/graph/utils.py Lines 288 to 289 in 223b739
or:
in pytensor/pytensor/graph/utils.py Lines 301 to 305 in 223b739
I also have added warnings in cases when the user wants to access the test_value explicitly using the helper
in: pytensor/pytensor/graph/utils.py Lines 284 to 286 in 223b739
However, there are internal codes which also use these functions ( in the case when FAST_COMPILE= False, compute_test_value is being called internally in the Op as provided in the trace). Hence the warnings appear everytime regardless of whether the user calls them or the internal code does. So should I drop the warning every time when the test_values are accessed (regardless of done by user or internal code) and only raise them only if someone is trying to set/modify the test_value ?
|
We should disable this
Should only be attempted if the config for |
This line should've taken care of exactly that: Lines 303 to 304 in ee4d4f7
However our |
We should make the default "off" instead of ignore perhaps |
The default is "off" already. Here's the problem: pytensor/pytensor/scalar/basic.py Lines 1225 to 1226 in ee4d4f7
|
Let's set that to |
This solves it. pytensor/tests/tensor/random/test_utils.py Lines 17 to 21 in ee4d4f7
Here, test_value is set to warn. This is in the |
6d43bf4
to
f519616
Compare
Just get rid of the warn in that fixture |
f519616
to
1080228
Compare
Just don't mention it in the change_flags, unless the test is specifically about test values |
debd572
to
716e365
Compare
Hi @ricardoV94 solved the errors and also some merge conflicts. |
716e365
to
1d76195
Compare
@Dhruvanshu-Joshi you seem to have accidentally pushed a coverage.xml file with many lines :) |
1d76195
to
5515580
Compare
My bad! Rectified it . |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #831 +/- ##
==========================================
- Coverage 81.59% 81.57% -0.03%
==========================================
Files 178 178
Lines 47236 47243 +7
Branches 11483 11484 +1
==========================================
- Hits 38544 38537 -7
- Misses 6504 6515 +11
- Partials 2188 2191 +3
|
5515580
to
b3f5d09
Compare
Added pytest Future Warning in relavant tests Removed and replaced usage of test_value in JAX/Numba tests
b3f5d09
to
c94b4de
Compare
c94b4de
to
7df44de
Compare
1897c25
to
92f53e2
Compare
g = rv_op(*dist_params, size=(10000, *base_size), rng=rng) | ||
g_fn = compile_random_function(dist_params, g, mode=jax_mode) | ||
samples = g_fn( | ||
*[ | ||
i.tag.test_value | ||
for i in g_fn.maker.fgraph.inputs | ||
test_values[i] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
value for key, value in test_values.items() if not isinstance(key, ...)
@@ -443,7 +446,8 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): | |||
mode=no_mode, | |||
) | |||
|
|||
arg_values = [p.get_test_value() for p in f_inputs] | |||
with pytest.warns(FutureWarning): | |||
arg_values = [p.get_test_value() for p in f_inputs] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pass the test values to the test function. Otherwise this test will stop working once we remove test values
@@ -627,7 +628,8 @@ def test_mvnormal_default_args(): | |||
@config.change_flags(compute_test_value="raise") | |||
def test_mvnormal_ShapeFeature(): | |||
M_pt = iscalar("M") | |||
M_pt.tag.test_value = 2 | |||
with pytest.warns(FutureWarning): | |||
M_pt.tag.test_value = 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, let's stop using test_values in all these random tests
@@ -151,8 +151,6 @@ | |||
) | |||
|
|||
|
|||
pytestmark = pytest.mark.filterwarnings("error") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do not remove these pytestmarks
@@ -63,7 +63,8 @@ | |||
|
|||
|
|||
def set_test_value(x, v): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove helper
Still here |
Description
Added
FutureWarnings
whenevertest_value
is accessed and also addedpytest.warnings
when tests use it.Related Issue
test_value
machinery #447Checklist
Type of change