Issue FutureWarnings for deprecated test_value machinery #831

Open
wants to merge 3 commits into main from deprecate_test_val

Conversation

Dhruvanshu-Joshi
Member

Description

Added FutureWarnings whenever test_value is accessed, and wrapped the tests that use it in pytest.warns.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

Comment on lines 65 to 68
warnings.warn(
    "compute_test_value is deprecated and will stop working in the future.",
    FutureWarning,
)
Member

This seems too crude here?

Comment on lines 601 to 604
warnings.warn(
    "test_value machinery is deprecated and will stop working in the future.",
    FutureWarning,
)
Member

These warnings show up in too many places (not talking about this file/line specifically). Shouldn't they only appear when users set the config flag, manually assign a test_value, or call a function that is specifically about test values?

Member Author

The warning is raised whenever a user accesses x.tag.test_value and whenever the internal code has config.compute_test_value != "off". Are you suggesting that instead of raising the warning when the internal code has config.compute_test_value != "off", I should raise it when the user explicitly sets config.compute_test_value != "off"?
That way, the warning will appear in fewer places within the code but will still be raised where appropriate.

Member

Yes, we don't want it to become incredibly noisy. Just warn at the user entry points to tell them to stop using the machinery.

@ricardoV94 ricardoV94 changed the title Raised FutureWarnings to aid test_value machinery deprecation Issue FutureWarnings for deprecated test_value machinery Jun 25, 2024
x.tag.test_value = np.zeros((2, 2))
y = dvector("y")
y.tag.test_value = [0, 0, 0, 0]
with pytest.warns(FutureWarning):
Member

For tests that are not specific about test_values we should just stop using the machinery

@@ -203,7 +212,7 @@ def test_get_test_values_success():
def test_get_test_values_exc():
    """Tests that `get_test_values` raises an exception when debugger is set to raise and a value is missing."""

    with pytest.raises(TestValueError):
    with pytest.raises(TestValueError) and pytest.warns(FutureWarning):
Member

Can you `and` context managers like this? I always nest them, but if it works, fine. Did you test it / see it working like this anywhere?

Member Author

No, I think this might be wrong. The tests passed so I assumed it works, but it turns out the with statements must be nested.
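For reference, a minimal sketch of the nested form (the helper function is a stand-in, and the import path for TestValueError is an assumption; the warning message mirrors the one added in this PR):

import warnings

import pytest

# Import path assumed; TestValueError is the exception referenced in the diff above.
from pytensor.graph.utils import TestValueError


def warn_then_raise():
    # Stand-in for the code under test: it warns about the deprecated
    # machinery and then raises because a required test value is missing.
    warnings.warn(
        "test_value machinery is deprecated and will stop working in the future.",
        FutureWarning,
    )
    raise TestValueError("missing test value")


def test_get_test_values_exc():
    # Nest the context managers instead of `and`-ing them: `A and B`
    # evaluates to B when A is truthy, so only the last context is entered.
    with pytest.warns(FutureWarning):
        with pytest.raises(TestValueError):
            warn_then_raise()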

Comment on lines 192 to 193
with pytest.warns(FutureWarning):
    a.tag.test_value = np.array(0.2, dtype=config.floatX)
Member

These jax/numba tests don't need the tag stuff. I have no idea why they were written like this, tbh.

Member Author

So are you suggesting deleting the line a.tag.test_value = np.array(0.2, dtype=config.floatX) from the test itself?

Member

Yes, we are adding the value to the tag and then retrieving it when we call the function below. We can just pass it directly

Member Author
@Dhruvanshu-Joshi Dhruvanshu-Joshi Jul 3, 2024

What about cases like this?

with pytest.warns(FutureWarning):
    y = vector("y")
    y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
    x = vector("x")
    x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
    A = matrix("A")
    A.tag.test_value = np.empty((2, 2), dtype=config.floatX)
    alpha = scalar("alpha")
    alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
    beta = scalar("beta")
    beta.tag.test_value = np.array(5.0, dtype=config.floatX)

# This should be converted into a `Gemv` `Op` when the non-JAX compatible
# optimizations are turned on; however, when using JAX mode, it should
# leave the expression alone.
out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

Here, setting the test_value and then extracting the value from it actually comes in handy when there are multiple inputs to the function.

However, we can also write this as:

with pytest.warns(FutureWarning):
    y = vector("y")
    x = vector("x")
    A = matrix("A")
    alpha = scalar("alpha")
    beta = scalar("beta")
    test_values = [
        np.r_[1.0, 2.0].astype(config.floatX),
        np.r_[3.0, 4.0].astype(config.floatX),
        np.empty((2, 2), dtype=config.floatX),
        np.array(3.0, dtype=config.floatX),
        np.array(5.0, dtype=config.floatX),
    ]

# This should be converted into a `Gemv` `Op` when the non-JAX compatible
# optimizations are turned on; however, when using JAX mode, it should
# leave the expression alone.
out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
compare_jax_and_py(fgraph, [i for i in test_values])

Which method should I go ahead with? @ricardoV94

Member

The one without test values is perfectly fine. If you want, you can have a dictionary from variables to values, if that seems more readable.
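For illustration, a minimal sketch of the dictionary variant, reusing the values from the snippet above (the compare_jax_and_py import path is an assumption about where the JAX test helper lives):

import numpy as np

from pytensor import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.type import matrix, scalar, vector

from tests.link.jax.test_basic import compare_jax_and_py  # path assumed

y = vector("y")
x = vector("x")
A = matrix("A")
alpha = scalar("alpha")
beta = scalar("beta")

# Map each input variable to its test input instead of stashing it on .tag
test_values = {
    y: np.r_[1.0, 2.0].astype(config.floatX),
    x: np.r_[3.0, 4.0].astype(config.floatX),
    A: np.empty((2, 2), dtype=config.floatX),
    alpha: np.array(3.0, dtype=config.floatX),
    beta: np.array(5.0, dtype=config.floatX),
}

out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
compare_jax_and_py(fgraph, [test_values[i] for i in fgraph.inputs])

With the dictionary, the ordering of fgraph.inputs no longer matters, which is the readability benefit being suggested.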

@Dhruvanshu-Joshi Dhruvanshu-Joshi force-pushed the deprecate_test_val branch 3 times, most recently from e944855 to b45b107 Compare July 6, 2024 08:18
@Dhruvanshu-Joshi
Member Author

Hi @ricardoV94

FAILED tests/tensor/test_variable.py::test_numpy_method[arccos-0.5] - Failed: DID NOT WARN. No warnings of type (<class 'FutureWarning'>,) were emitted.
 Emitted warnings: []

Tests are failing on Python 3.12 with fast_compile==1. However, it works fine locally.

@ricardoV94
Member

ricardoV94 commented Jul 9, 2024

Hi @ricardoV94

FAILED tests/tensor/test_variable.py::test_numpy_method[arccos-0.5] - Failed: DID NOT WARN. No warnings of type (<class 'FutureWarning'>,) were emitted.
 Emitted warnings: []

Tests are failing on Python 3.12 with fast_compile==1. However, it works fine locally.

Maybe a different numpy version? Did you turn FAST_COMPILE to True at the top of the test file when trying locally?

@Dhruvanshu-Joshi
Member Author

FAST_COMPILE

Yep, FAST_COMPILE was not set at the top of the code, so I was not able to reproduce this locally. Now I can.
It seems the `__getattribute__` where I raise the warning is not being called in the FAST_COMPILE case.
I have added the warnings in these functions when `name == "test_value"`:

# These two methods have been added to help Mypy
def __getattribute__(self, name):
    return super().__getattribute__(name)

def __setattr__(self, name: str, value: Any) -> None:
    self.__dict__[name] = value

However, this line is never reached in FAST_COMPILE mode.
I did try to debug this using breakpoints, but there are too many function calls and I could not figure out where to add the warnings in FAST_COMPILE mode. However, to the best of my knowledge, the test_value is not computed at all in FAST_COMPILE mode, and hence we do not get any warning.
How should I handle this?

@ricardoV94
Member

I don't see why test values should be triggered in that test? They are not set explicitly? Is something internally introducing the test value when FAST_COMPILE=False?

Because we don't want the warnings to be triggered by internal use, only when users themselves start the test value machinery. We need to make sure we, ourselves, don't start it accidentally.

@Dhruvanshu-Joshi
Member Author

I don't see why test values should be triggered in that test? They are not set explicitly? Is something internally introducing the test value when FAST_COMPILE=False?
Because we don't want the warnings to be triggered by internal use, only when users themselves start the test value machinery. We need to make sure we, ourselves, don't start it accidentally.

Yes, the test_values are set internally in the failing test, by compute_test_value in graph/op.py.
This is the exact trace that leads to compute_test_value when FAST_COMPILE=False.

tests\tensor\test_variable.py:85:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pytensor\compile\function\__init__.py:307: in function
    fn = pfunc(
pytensor\compile\function\pfunc.py:465: in pfunc
    return orig_function(
pytensor\compile\function\types.py:1753: in orig_function
    m = Maker(
pytensor\compile\function\types.py:1526: in __init__
    self.prepare_fgraph(inputs, outputs, found_updates, fgraph, mode, profile)
pytensor\compile\function\types.py:1414: in prepare_fgraph
    rewriter_profile = rewriter(fgraph)
pytensor\graph\rewriting\basic.py:127: in __call__
    return self.rewrite(fgraph)
pytensor\graph\rewriting\basic.py:123: in rewrite
    return self.apply(fgraph, *args, **kwargs)
pytensor\graph\rewriting\basic.py:293: in apply
    sub_prof = rewriter.apply(fgraph)
pytensor\graph\rewriting\basic.py:293: in apply
    sub_prof = rewriter.apply(fgraph)
pytensor\tensor\rewriting\elemwise.py:1026: in apply
    for inputs, outputs in find_next_fuseable_subgraph(fgraph):
pytensor\tensor\rewriting\elemwise.py:997: in find_next_fuseable_subgraph
    fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
pytensor\tensor\rewriting\elemwise.py:725: in initialize_fuseable_mappings
    and elemwise_scalar_op_has_c_code(out.owner)
pytensor\tensor\rewriting\elemwise.py:699: in elemwise_scalar_op_has_c_code
    if node.op.scalar_op.supports_c_code(node.inputs, node.outputs):
pytensor\scalar\basic.py:1225: in supports_c_code
    s_op = self(*tmp_s_input, return_list=True)
pytensor\graph\op.py:305: in __call__
    compute_test_value(node)

I understand that the warnings must be raised only when test_value is accessed explicitly.
Here is what I am doing:
Whenever users set test_value themselves using x.tag.test_value = val, a FutureWarning will be raised.
This happens using either:

def __setattr__(self, name: str, value: Any) -> None:
        if name == "test_value":
            warnings.warn(
                "test_value machinery is deprecated and will stop working in the future.",
                FutureWarning,
            )
        self.__dict__[name] = value

in

def __setattr__(self, name: str, value: Any) -> None:
    self.__dict__[name] = value

or:

    def __setattr__(self, attr, obj):
        if getattr(self, "attr", None) == attr:
            if attr == "test_value":
                warnings.warn(
                    "test_value machinery is deprecated and will stop working in the future.",
                    FutureWarning,
                )
            obj = self.attr_filter(obj)

        return object.__setattr__(self, attr, obj)

in

def __setattr__(self, attr, obj):
    if getattr(self, "attr", None) == attr:
        obj = self.attr_filter(obj)
    return object.__setattr__(self, attr, obj)

I have also added warnings for the cases where the user accesses the test_value explicitly using the helpers get_test_value or compute_test_value. These helpers internally call the __getattribute__(self, name) function in graph/utils. Hence the warning is raised using:

def __getattribute__(self, name):
        if name == "test_value":
            warnings.warn(
                "test_value machinery is deprecated and will stop working in the future.",
                FutureWarning,
            )
        return super().__getattribute__(name)

in:

# These two methods have been added to help Mypy
def __getattribute__(self, name):
    return super().__getattribute__(name)

However, there is internal code which also uses these functions (when FAST_COMPILE=False, compute_test_value is called internally in the Op, as shown in the trace). Hence the warnings appear every time, regardless of whether the user or the internal code calls them. So should I drop the warning when test_values are accessed (regardless of whether that is done by the user or by internal code) and only raise it when someone tries to set/modify the test_value?

@ricardoV94
Member

ricardoV94 commented Jul 10, 2024

We should disable this compute_test_value by default in the Op:

pytensor\graph\op.py:305: in __call__
    compute_test_value(node)

Should only be attempted if the config for test_value is enabled

@Dhruvanshu-Joshi
Member Author

We should disable this compute_test_value by default in the Op:

pytensor\graph\op.py:305: in __call__
    compute_test_value(node)

Should only be attempted if the config for test_value is enabled

This line should've taken care of exactly that:

if config.compute_test_value != "off":
    compute_test_value(node)

However, our config.compute_test_value is "ignore" here, so compute_test_value runs anyway. Should I treat config.compute_test_value == "ignore" and config.compute_test_value == "off" the same everywhere in the code?

@ricardoV94
Member

We should make the default "off" instead of ignore perhaps

@Dhruvanshu-Joshi
Member Author

We should make the default "off" instead of ignore perhaps

The default is "off" already. Here's the problem:
This line is being called internally, which changes config.compute_test_value to "ignore".

with config.change_flags(compute_test_value="ignore"):
    s_op = self(*tmp_s_input, return_list=True)

@ricardoV94
Member

Let's set that to off

@Dhruvanshu-Joshi
Member Author

Let's set that to off

This solves it.
However, there are other places (these are in the tests) where test_value is set internally.
One such example is:

def set_pytensor_flags():
    rewrites_query = RewriteDatabaseQuery(include=[None], exclude=[])
    py_mode = Mode("py", rewrites_query)
    with config.change_flags(mode=py_mode, compute_test_value="warn"):
        yield

Here, compute_test_value is set to "warn". This is in the test utils, so it is not used anywhere apart from the tests.
So should I keep this, or replace the "warn" with "off" here as well? I don't think the test_value set here has any significance on the user side, since it only affects the tests themselves.

@Dhruvanshu-Joshi Dhruvanshu-Joshi force-pushed the deprecate_test_val branch 2 times, most recently from 6d43bf4 to f519616 Compare July 10, 2024 14:02
@ricardoV94
Member

Just get rid of the warn in that fixture

@Dhruvanshu-Joshi
Member Author

Just get rid of the warn in that fixture

Can you elaborate on this? This is a helper function, I guess, so do you want me to replace

with config.change_flags(mode=py_mode, compute_test_value="warn"):
    yield

to

with config.change_flags(mode=py_mode, compute_test_value="off"):
    yield

This appears in other places as well (screenshot omitted):
Do you want me to set every instance of config.compute_test_value to "off"?

@ricardoV94
Member

Just don't mention it in the change_flags, unless the test is specifically about test values
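For illustration, a minimal sketch of that fixture with the flag simply dropped (the fixture decorator and scope are assumptions, since they are not shown in the excerpt):

import pytest

from pytensor import config
from pytensor.compile.mode import Mode
from pytensor.graph.rewriting.db import RewriteDatabaseQuery


@pytest.fixture(scope="module", autouse=True)  # decorator/scope assumed
def set_pytensor_flags():
    rewrites_query = RewriteDatabaseQuery(include=[None], exclude=[])
    py_mode = Mode("py", rewrites_query)
    # compute_test_value is no longer mentioned here; only tests that are
    # specifically about test values should opt into the deprecated machinery.
    with config.change_flags(mode=py_mode):
        yield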

@Dhruvanshu-Joshi
Member Author

Hi @ricardoV94, I have solved the errors and also some merge conflicts.

@ricardoV94
Member

@Dhruvanshu-Joshi you seem to have accidentally pushed a coverage.xml file with many lines :)

@Dhruvanshu-Joshi
Member Author

@Dhruvanshu-Joshi you seem to have accidentally pushed a coverage.xml file with many lines :)

My bad! Rectified it.


codecov bot commented Jul 22, 2024

Codecov Report

Attention: Patch coverage is 75.00000% with 2 lines in your changes missing coverage. Please review.

Project coverage is 81.57%. Comparing base (981688c) to head (5515580).

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #831      +/-   ##
==========================================
- Coverage   81.59%   81.57%   -0.03%     
==========================================
  Files         178      178              
  Lines       47236    47243       +7     
  Branches    11483    11484       +1     
==========================================
- Hits        38544    38537       -7     
- Misses       6504     6515      +11     
- Partials     2188     2191       +3     
Files Coverage Δ
pytensor/compile/sharedvalue.py 93.93% <100.00%> (+0.18%) ⬆️
pytensor/graph/basic.py 88.67% <100.00%> (+0.01%) ⬆️
pytensor/graph/op.py 87.95% <100.00%> (+0.06%) ⬆️
pytensor/scalar/basic.py 80.38% <100.00%> (ø)
pytensor/configdefaults.py 72.30% <33.33%> (-0.37%) ⬇️

... and 7 files with indirect coverage changes

Added pytest FutureWarning in relevant tests

Removed and replaced usage of test_value in JAX/Numba tests
g = rv_op(*dist_params, size=(10000, *base_size), rng=rng)
g_fn = compile_random_function(dist_params, g, mode=jax_mode)
samples = g_fn(
*[
i.tag.test_value
for i in g_fn.maker.fgraph.inputs
test_values[i]
Member

value for key, value in test_values.items() if not isinstance(key, ...)

@@ -443,7 +446,8 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
mode=no_mode,
)

arg_values = [p.get_test_value() for p in f_inputs]
with pytest.warns(FutureWarning):
arg_values = [p.get_test_value() for p in f_inputs]
Member

Pass the test values to the test function. Otherwise this test will stop working once we remove test values

@@ -627,7 +628,8 @@ def test_mvnormal_default_args():
@config.change_flags(compute_test_value="raise")
def test_mvnormal_ShapeFeature():
    M_pt = iscalar("M")
    M_pt.tag.test_value = 2
    with pytest.warns(FutureWarning):
        M_pt.tag.test_value = 2
Member

In general, let's stop using test_values in all these random tests

@@ -151,8 +151,6 @@
)


pytestmark = pytest.mark.filterwarnings("error")
Member

Do not remove these pytestmarks
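As a side note on why the mark can stay: pytest.warns installs its own warning filters inside its block, so a test that expects the FutureWarning keeps working even with the module-level error mark. A minimal sketch (variable names are illustrative, and the warning itself assumes this PR's changes are in place):

import numpy as np
import pytest

from pytensor.tensor.type import dvector

# Module-level mark stays: unexpected warnings still fail the tests.
pytestmark = pytest.mark.filterwarnings("error")


def test_setting_test_value_warns():
    x = dvector("x")
    # pytest.warns overrides the "error" filter inside this block, so the
    # expected FutureWarning is recorded instead of being raised as an error.
    with pytest.warns(FutureWarning):
        x.tag.test_value = np.zeros(4)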

@@ -63,7 +63,8 @@


def set_test_value(x, v):
Member

Remove helper

@ricardoV94
Member

@Dhruvanshu-Joshi you seem to have accidentally pushed a coverage.xml file with many lines :)

My bad! Rectified it.

Still here
