Skip to content

Commit b45b107

Browse files
Raise warnings when test_val is accessed
Added pytest Future Warning in relavant tests Removed and replaced usage of test_value in JAX/Numba tests
1 parent 31bf682 commit b45b107

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+881
-623
lines changed

pytensor/compile/sharedvalue.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Provide a simple user friendly API to PyTensor-managed memory."""
22

33
import copy
4+
import warnings
45
from contextlib import contextmanager
56
from functools import singledispatch
67
from typing import TYPE_CHECKING
@@ -134,6 +135,10 @@ def set_value(self, new_value, borrow=False):
134135
self.container.value = copy.deepcopy(new_value)
135136

136137
def get_test_value(self):
138+
warnings.warn(
139+
"test_value machinery is deprecated and will stop working in the future.",
140+
FutureWarning,
141+
)
137142
return self.get_value(borrow=True, return_internal_type=True)
138143

139144
def clone(self, **kwargs):

pytensor/configdefaults.py

+7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import socket
77
import sys
88
import textwrap
9+
import warnings
910

1011
import numpy as np
1112
from setuptools._distutils.spawn import find_executable
@@ -1447,6 +1448,12 @@ def add_caching_dir_configvars():
14471448
else:
14481449
gcc_version_str = "GCC_NOT_FOUND"
14491450

1451+
if config.compute_test_value != "off":
1452+
warnings.warn(
1453+
"test_value machinery is deprecated and will stop working in the future.",
1454+
FutureWarning,
1455+
)
1456+
14501457
# TODO: The caching dir resolution is a procedural mess of helper functions, local variables
14511458
# and config definitions. And the result is also not particularly pretty..
14521459
add_caching_dir_configvars()

pytensor/graph/basic.py

+4
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,10 @@ def __init__(self, type: _TypeType, data: Any, name: str | None = None):
784784
add_tag_trace(self)
785785

786786
def get_test_value(self):
787+
warnings.warn(
788+
"test_value machinery is deprecated and will stop working in the future.",
789+
FutureWarning,
790+
)
787791
return self.data
788792

789793
def signature(self):

pytensor/graph/op.py

+5
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,11 @@ def get_test_values(*args: Variable) -> Any | list[Any]:
708708
709709
"""
710710

711+
warnings.warn(
712+
"test_value machinery is deprecated and will stop working in the future.",
713+
FutureWarning,
714+
)
715+
711716
if config.compute_test_value == "off":
712717
return []
713718

pytensor/graph/utils.py

+16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import linecache
22
import sys
33
import traceback
4+
import warnings
45
from abc import ABCMeta
56
from collections.abc import Sequence
67
from io import StringIO
@@ -283,9 +284,19 @@ def info(self):
283284

284285
# These two methods have been added to help Mypy
285286
def __getattribute__(self, name):
287+
if name == "test_value":
288+
warnings.warn(
289+
"test_value machinery is deprecated and will stop working in the future.",
290+
FutureWarning,
291+
)
286292
return super().__getattribute__(name)
287293

288294
def __setattr__(self, name: str, value: Any) -> None:
295+
if name == "test_value":
296+
warnings.warn(
297+
"test_value machinery is deprecated and will stop working in the future.",
298+
FutureWarning,
299+
)
289300
self.__dict__[name] = value
290301

291302

@@ -300,6 +311,11 @@ def __init__(self, attr, attr_filter):
300311

301312
def __setattr__(self, attr, obj):
302313
if getattr(self, "attr", None) == attr:
314+
if attr == "test_value":
315+
warnings.warn(
316+
"test_value machinery is deprecated and will stop working in the future.",
317+
FutureWarning,
318+
)
303319
obj = self.attr_filter(obj)
304320

305321
return object.__setattr__(self, attr, obj)

tests/compile/test_builders.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -523,11 +523,12 @@ def test_infer_shape(self):
523523

524524
@config.change_flags(compute_test_value="raise")
525525
def test_compute_test_value(self):
526-
x = scalar("x")
527-
x.tag.test_value = np.array(1.0, dtype=config.floatX)
528-
op = OpFromGraph([x], [x**3])
529-
y = scalar("y")
530-
y.tag.test_value = np.array(1.0, dtype=config.floatX)
526+
with pytest.warns(FutureWarning):
527+
x = scalar("x")
528+
x.tag.test_value = np.array(1.0, dtype=config.floatX)
529+
op = OpFromGraph([x], [x**3])
530+
y = scalar("y")
531+
y.tag.test_value = np.array(1.0, dtype=config.floatX)
531532
f = op(y)
532533
grad_f = grad(f, y)
533534
assert grad_f.tag.test_value is not None

tests/compile/test_ops.py

-4
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ def cumprod(x):
3333

3434
def test_2arg(self):
3535
x = dmatrix("x")
36-
x.tag.test_value = np.zeros((2, 2))
3736
y = dvector("y")
38-
y.tag.test_value = [0, 0, 0, 0]
3937

4038
@as_op([dmatrix, dvector], dvector)
4139
def cumprod_plus(x, y):
@@ -49,9 +47,7 @@ def cumprod_plus(x, y):
4947

5048
def test_infer_shape(self):
5149
x = dmatrix("x")
52-
x.tag.test_value = np.zeros((2, 2))
5350
y = dvector("y")
54-
y.tag.test_value = [0, 0, 0, 0]
5551

5652
def infer_shape(fgraph, node, shapes):
5753
x, y = shapes

tests/graph/test_basic.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -371,14 +371,15 @@ def test_eval_kwargs(self):
371371
def test_eval_unashable_kwargs(self):
372372
y_repl = constant(2.0, dtype="floatX")
373373

374-
assert self.w.eval({self.x: 1.0}, givens=((self.y, y_repl),)) == 6.0
375-
376-
with pytest.warns(
377-
UserWarning,
378-
match="Keyword arguments could not be used to create a cache key",
379-
):
380-
# givens dict is not hashable
381-
assert self.w.eval({self.x: 1.0}, givens={self.y: y_repl}) == 6.0
374+
with pytest.warns(FutureWarning):
375+
assert self.w.eval({self.x: 1.0}, givens=((self.y, y_repl),)) == 6.0
376+
377+
with pytest.warns(
378+
UserWarning,
379+
match="Keyword arguments could not be used to create a cache key",
380+
):
381+
# givens dict is not hashable
382+
assert self.w.eval({self.x: 1.0}, givens={self.y: y_repl}) == 6.0
382383

383384

384385
class TestAutoName:

tests/graph/test_compute_test_value.py

+37-26
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def perform(self, node, inputs, outputs):
6767

6868
test_input = SomeType()()
6969
orig_object = object()
70-
test_input.tag.test_value = orig_object
70+
with pytest.warns(FutureWarning):
71+
test_input.tag.test_value = orig_object
7172

7273
res = InplaceOp(False)(test_input)
7374
assert res.tag.test_value is orig_object
@@ -76,10 +77,11 @@ def perform(self, node, inputs, outputs):
7677
assert res.tag.test_value is not orig_object
7778

7879
def test_variable_only(self):
79-
x = matrix("x")
80-
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
81-
y = matrix("y")
82-
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
80+
with pytest.warns(FutureWarning):
81+
x = matrix("x")
82+
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
83+
y = matrix("y")
84+
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
8385

8486
# should work
8587
z = dot(x, y)
@@ -88,14 +90,16 @@ def test_variable_only(self):
8890
assert _allclose(f(x.tag.test_value, y.tag.test_value), z.tag.test_value)
8991

9092
# this test should fail
91-
y.tag.test_value = np.random.random((6, 5)).astype(config.floatX)
93+
with pytest.warns(FutureWarning):
94+
y.tag.test_value = np.random.random((6, 5)).astype(config.floatX)
9295
with pytest.raises(ValueError):
9396
dot(x, y)
9497

9598
def test_compute_flag(self):
9699
x = matrix("x")
97100
y = matrix("y")
98-
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
101+
with pytest.warns(FutureWarning):
102+
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
99103

100104
# should skip computation of test value
101105
with config.change_flags(compute_test_value="off"):
@@ -111,10 +115,11 @@ def test_compute_flag(self):
111115
dot(x, y)
112116

113117
def test_string_var(self):
114-
x = matrix("x")
115-
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
116-
y = matrix("y")
117-
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
118+
with pytest.warns(FutureWarning):
119+
x = matrix("x")
120+
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
121+
y = matrix("y")
122+
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
118123

119124
z = pytensor.shared(np.random.random((5, 6)).astype(config.floatX))
120125

@@ -134,7 +139,8 @@ def f(x, y, z):
134139

135140
def test_shared(self):
136141
x = matrix("x")
137-
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
142+
with pytest.warns(FutureWarning):
143+
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
138144
y = pytensor.shared(np.random.random((4, 6)).astype(config.floatX), "y")
139145

140146
# should work
@@ -190,30 +196,31 @@ def test_constant(self):
190196
def test_incorrect_type(self):
191197
x = vector("x")
192198
with pytest.raises(TypeError):
193-
# Incorrect shape for test value
194-
x.tag.test_value = np.empty((2, 2))
199+
with pytest.warns(FutureWarning):
200+
# Incorrect shape for test value
201+
x.tag.test_value = np.empty((2, 2))
195202

196203
x = fmatrix("x")
197204
with pytest.raises(TypeError):
198-
# Incorrect dtype (float64) for test value
199-
x.tag.test_value = np.random.random((3, 4))
205+
with pytest.warns(FutureWarning):
206+
# Incorrect dtype (float64) for test value
207+
x.tag.test_value = np.random.random((3, 4))
200208

201209
def test_overided_function(self):
202210
# We need to test those as they mess with Exception
203211
# And we don't want the exception to be changed.
204212
x = matrix()
205-
x.tag.test_value = np.zeros((2, 3), dtype=config.floatX)
206213
y = matrix()
207-
y.tag.test_value = np.zeros((2, 2), dtype=config.floatX)
208214
with pytest.raises(ValueError):
209215
x.__mul__(y)
210216

211217
def test_scan(self):
212218
# Test the compute_test_value mechanism Scan.
213219
k = iscalar("k")
214220
A = vector("A")
215-
k.tag.test_value = 3
216-
A.tag.test_value = np.random.random((5,)).astype(config.floatX)
221+
with pytest.warns(FutureWarning):
222+
k.tag.test_value = 3
223+
A.tag.test_value = np.random.random((5,)).astype(config.floatX)
217224

218225
def fx(prior_result, A):
219226
return prior_result * A
@@ -233,8 +240,9 @@ def test_scan_err1(self):
233240
# This test should fail when building fx for the first time
234241
k = iscalar("k")
235242
A = matrix("A")
236-
k.tag.test_value = 3
237-
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)
243+
with pytest.warns(FutureWarning):
244+
k.tag.test_value = 3
245+
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)
238246

239247
def fx(prior_result, A):
240248
return dot(prior_result, A)
@@ -253,8 +261,9 @@ def test_scan_err2(self):
253261
# but when calling the scan's perform()
254262
k = iscalar("k")
255263
A = matrix("A")
256-
k.tag.test_value = 3
257-
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)
264+
with pytest.warns(FutureWarning):
265+
k.tag.test_value = 3
266+
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)
258267

259268
def fx(prior_result, A):
260269
return dot(prior_result, A)
@@ -288,7 +297,8 @@ def perform(self, node, inputs, outputs):
288297
output[0] = input + 1
289298

290299
i = ps.int32("i")
291-
i.tag.test_value = 3
300+
with pytest.warns(FutureWarning):
301+
i.tag.test_value = 3
292302

293303
o = IncOnePython()(i)
294304

@@ -304,7 +314,8 @@ def perform(self, node, inputs, outputs):
304314
)
305315
def test_no_perform(self):
306316
i = ps.int32("i")
307-
i.tag.test_value = 3
317+
with pytest.warns(FutureWarning):
318+
i.tag.test_value = 3
308319

309320
# Class IncOneC is defined outside of the TestComputeTestValue
310321
# so it can be pickled and unpickled

tests/graph/test_fg.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -233,18 +233,19 @@ def test_change_input(self):
233233

234234
@config.change_flags(compute_test_value="raise")
235235
def test_replace_test_value(self):
236-
var1 = MyVariable("var1")
237-
var1.tag.test_value = 1
238-
var2 = MyVariable("var2")
239-
var2.tag.test_value = 2
240-
var3 = op1(var2, var1)
241-
var4 = op2(var3, var2)
242-
var4.tag.test_value = np.array([1, 2])
243-
var5 = op3(var4, var2, var2)
244-
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
245-
246-
var6 = op3()
247-
var6.tag.test_value = np.array(0)
236+
with pytest.warns(FutureWarning):
237+
var1 = MyVariable("var1")
238+
var1.tag.test_value = 1
239+
var2 = MyVariable("var2")
240+
var2.tag.test_value = 2
241+
var3 = op1(var2, var1)
242+
var4 = op2(var3, var2)
243+
var4.tag.test_value = np.array([1, 2])
244+
var5 = op3(var4, var2, var2)
245+
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
246+
247+
var6 = op3()
248+
var6.tag.test_value = np.array(0)
248249

249250
assert var6.tag.test_value.shape != var4.tag.test_value.shape
250251

0 commit comments

Comments
 (0)