Skip to content

Issue FututureWarnings for deprecated test_value machinery #831

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48,682 changes: 48,682 additions & 0 deletions coverage/coverage-.xml

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions pytensor/compile/sharedvalue.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Provide a simple user friendly API to PyTensor-managed memory."""

import copy
import warnings
from contextlib import contextmanager
from functools import singledispatch
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -134,6 +135,10 @@ def set_value(self, new_value, borrow=False):
self.container.value = copy.deepcopy(new_value)

def get_test_value(self):
warnings.warn(
"test_value machinery is deprecated and will stop working in the future.",
FutureWarning,
)
return self.get_value(borrow=True, return_internal_type=True)

def clone(self, **kwargs):
Expand Down
7 changes: 7 additions & 0 deletions pytensor/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import socket
import sys
import textwrap
import warnings
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -1282,6 +1283,12 @@ def add_caching_dir_configvars():
else:
gcc_version_str = "GCC_NOT_FOUND"

if config.compute_test_value != "off":
warnings.warn(
"test_value machinery is deprecated and will stop working in the future.",
FutureWarning,
)

# TODO: The caching dir resolution is a procedural mess of helper functions, local variables
# and config definitions. And the result is also not particularly pretty..
add_caching_dir_configvars()
4 changes: 4 additions & 0 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,10 @@ def __init__(self, type: _TypeType, data: Any, name: str | None = None):
add_tag_trace(self)

def get_test_value(self):
warnings.warn(
"test_value machinery is deprecated and will stop working in the future.",
FutureWarning,
)
return self.data

def signature(self):
Expand Down
5 changes: 5 additions & 0 deletions pytensor/graph/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,11 @@ def get_test_values(*args: Variable) -> Any | list[Any]:
if config.compute_test_value == "off":
return []

warnings.warn(
"test_value machinery is deprecated and will stop working in the future.",
FutureWarning,
)

rval = []

for i, arg in enumerate(args):
Expand Down
16 changes: 16 additions & 0 deletions pytensor/graph/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import linecache
import sys
import traceback
import warnings
from abc import ABCMeta
from collections.abc import Sequence
from io import StringIO
Expand Down Expand Up @@ -282,9 +283,19 @@ def info(self):

# These two methods have been added to help Mypy
def __getattribute__(self, name):
if name == "test_value":
warnings.warn(
"test_value machinery is deprecated and will stop working in the future.",
FutureWarning,
)
return super().__getattribute__(name)

def __setattr__(self, name: str, value: Any) -> None:
if name == "test_value":
warnings.warn(
"test_value machinery is deprecated and will stop working in the future.",
FutureWarning,
)
self.__dict__[name] = value


Expand All @@ -299,6 +310,11 @@ def __init__(self, attr, attr_filter):

def __setattr__(self, attr, obj):
if getattr(self, "attr", None) == attr:
if attr == "test_value":
warnings.warn(
"test_value machinery is deprecated and will stop working in the future.",
FutureWarning,
)
obj = self.attr_filter(obj)

return object.__setattr__(self, attr, obj)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,7 +1231,7 @@ def supports_c_code(self, inputs, outputs):
tmp_s_input.append(tmp)
mapping[ii] = tmp_s_input[-1]

with config.change_flags(compute_test_value="ignore"):
with config.change_flags(compute_test_value="off"):
s_op = self(*tmp_s_input, return_list=True)

# if the scalar_op don't have a c implementation,
Expand Down
11 changes: 6 additions & 5 deletions tests/compile/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,11 +523,12 @@ def test_infer_shape(self):

@config.change_flags(compute_test_value="raise")
def test_compute_test_value(self):
x = scalar("x")
x.tag.test_value = np.array(1.0, dtype=config.floatX)
op = OpFromGraph([x], [x**3])
y = scalar("y")
y.tag.test_value = np.array(1.0, dtype=config.floatX)
with pytest.warns(FutureWarning):
x = scalar("x")
x.tag.test_value = np.array(1.0, dtype=config.floatX)
op = OpFromGraph([x], [x**3])
y = scalar("y")
y.tag.test_value = np.array(1.0, dtype=config.floatX)
f = op(y)
grad_f = grad(f, y)
assert grad_f.tag.test_value is not None
Expand Down
4 changes: 0 additions & 4 deletions tests/compile/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ def cumprod(x):

def test_2arg(self):
x = dmatrix("x")
x.tag.test_value = np.zeros((2, 2))
y = dvector("y")
y.tag.test_value = [0, 0, 0, 0]

@as_op([dmatrix, dvector], dvector)
def cumprod_plus(x, y):
Expand All @@ -49,9 +47,7 @@ def cumprod_plus(x, y):

def test_infer_shape(self):
x = dmatrix("x")
x.tag.test_value = np.zeros((2, 2))
y = dvector("y")
y.tag.test_value = [0, 0, 0, 0]

def infer_shape(fgraph, node, shapes):
x, y = shapes
Expand Down
63 changes: 37 additions & 26 deletions tests/graph/test_compute_test_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def perform(self, node, inputs, outputs):

test_input = SomeType()()
orig_object = object()
test_input.tag.test_value = orig_object
with pytest.warns(FutureWarning):
test_input.tag.test_value = orig_object

res = InplaceOp(False)(test_input)
assert res.tag.test_value is orig_object
Expand All @@ -76,10 +77,11 @@ def perform(self, node, inputs, outputs):
assert res.tag.test_value is not orig_object

def test_variable_only(self):
x = matrix("x")
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
y = matrix("y")
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
with pytest.warns(FutureWarning):
x = matrix("x")
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
y = matrix("y")
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)

# should work
z = dot(x, y)
Expand All @@ -88,14 +90,16 @@ def test_variable_only(self):
assert _allclose(f(x.tag.test_value, y.tag.test_value), z.tag.test_value)

# this test should fail
y.tag.test_value = np.random.random((6, 5)).astype(config.floatX)
with pytest.warns(FutureWarning):
y.tag.test_value = np.random.random((6, 5)).astype(config.floatX)
with pytest.raises(ValueError):
dot(x, y)

def test_compute_flag(self):
x = matrix("x")
y = matrix("y")
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
with pytest.warns(FutureWarning):
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)

# should skip computation of test value
with config.change_flags(compute_test_value="off"):
Expand All @@ -111,10 +115,11 @@ def test_compute_flag(self):
dot(x, y)

def test_string_var(self):
x = matrix("x")
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
y = matrix("y")
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
with pytest.warns(FutureWarning):
x = matrix("x")
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
y = matrix("y")
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)

z = pytensor.shared(np.random.random((5, 6)).astype(config.floatX))

Expand All @@ -134,7 +139,8 @@ def f(x, y, z):

def test_shared(self):
x = matrix("x")
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
with pytest.warns(FutureWarning):
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
y = pytensor.shared(np.random.random((4, 6)).astype(config.floatX), "y")

# should work
Expand Down Expand Up @@ -190,30 +196,31 @@ def test_constant(self):
def test_incorrect_type(self):
x = vector("x")
with pytest.raises(TypeError):
# Incorrect shape for test value
x.tag.test_value = np.empty((2, 2))
with pytest.warns(FutureWarning):
# Incorrect shape for test value
x.tag.test_value = np.empty((2, 2))

x = fmatrix("x")
with pytest.raises(TypeError):
# Incorrect dtype (float64) for test value
x.tag.test_value = np.random.random((3, 4))
with pytest.warns(FutureWarning):
# Incorrect dtype (float64) for test value
x.tag.test_value = np.random.random((3, 4))

def test_overided_function(self):
# We need to test those as they mess with Exception
# And we don't want the exception to be changed.
x = matrix()
x.tag.test_value = np.zeros((2, 3), dtype=config.floatX)
y = matrix()
y.tag.test_value = np.zeros((2, 2), dtype=config.floatX)
with pytest.raises(ValueError):
x.__mul__(y)

def test_scan(self):
# Test the compute_test_value mechanism Scan.
k = iscalar("k")
A = vector("A")
k.tag.test_value = 3
A.tag.test_value = np.random.random((5,)).astype(config.floatX)
with pytest.warns(FutureWarning):
k.tag.test_value = 3
A.tag.test_value = np.random.random((5,)).astype(config.floatX)

def fx(prior_result, A):
return prior_result * A
Expand All @@ -233,8 +240,9 @@ def test_scan_err1(self):
# This test should fail when building fx for the first time
k = iscalar("k")
A = matrix("A")
k.tag.test_value = 3
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)
with pytest.warns(FutureWarning):
k.tag.test_value = 3
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)

def fx(prior_result, A):
return dot(prior_result, A)
Expand All @@ -253,8 +261,9 @@ def test_scan_err2(self):
# but when calling the scan's perform()
k = iscalar("k")
A = matrix("A")
k.tag.test_value = 3
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)
with pytest.warns(FutureWarning):
k.tag.test_value = 3
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)

def fx(prior_result, A):
return dot(prior_result, A)
Expand Down Expand Up @@ -288,7 +297,8 @@ def perform(self, node, inputs, outputs):
output[0] = input + 1

i = ps.int32("i")
i.tag.test_value = 3
with pytest.warns(FutureWarning):
i.tag.test_value = 3

o = IncOnePython()(i)

Expand All @@ -304,7 +314,8 @@ def perform(self, node, inputs, outputs):
)
def test_no_perform(self):
i = ps.int32("i")
i.tag.test_value = 3
with pytest.warns(FutureWarning):
i.tag.test_value = 3

# Class IncOneC is defined outside of the TestComputeTestValue
# so it can be pickled and unpickled
Expand Down
2 changes: 0 additions & 2 deletions tests/graph/test_destroyhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest

from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, Variable, clone
from pytensor.graph.destroyhandler import DestroyHandler
from pytensor.graph.features import ReplaceValidate
Expand Down Expand Up @@ -408,7 +407,6 @@ def test_value_repl():
assert g.consistent()


@config.change_flags(compute_test_value="off")
def test_value_repl_2():
x, y, z = inputs()
sy = sigmoid(y)
Expand Down
25 changes: 13 additions & 12 deletions tests/graph/test_fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,18 +241,19 @@ def test_change_input(self):

@config.change_flags(compute_test_value="raise")
def test_replace_test_value(self):
var1 = MyVariable("var1")
var1.tag.test_value = 1
var2 = MyVariable("var2")
var2.tag.test_value = 2
var3 = op1(var2, var1)
var4 = op2(var3, var2)
var4.tag.test_value = np.array([1, 2])
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

var6 = op3()
var6.tag.test_value = np.array(0)
with pytest.warns(FutureWarning):
var1 = MyVariable("var1")
var1.tag.test_value = 1
var2 = MyVariable("var2")
var2.tag.test_value = 2
var3 = op1(var2, var1)
var4 = op2(var3, var2)
var4.tag.test_value = np.array([1, 2])
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)

var6 = op3()
var6.tag.test_value = np.array(0)

assert var6.tag.test_value.shape != var4.tag.test_value.shape

Expand Down
Loading
Loading