
Commit 750d334

brandonwillard and ricardoV94 authored and committed
Use pytest-benchmark
Co-authored-by: Brandon T. Willard <[email protected]>
1 parent 81ebcca commit 750d334

File tree: 9 files changed (+64, -31 lines)

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
@@ -115,7 +115,7 @@ jobs:
       - name: Install dependencies
         shell: bash -l {0}
         run: |
-          mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov sympy
+          mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
           if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.55" numba-scipy; fi
           mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib
           pip install -e ./
@@ -132,7 +132,7 @@ jobs:
           if [[ $FAST_COMPILE == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi
           if [[ $FLOAT32 == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,floatX=float32; fi
           export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
-          python -m pytest -x -r A --verbose --runslow --cov=pytensor/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART
+          python -m pytest -x -r A --verbose --runslow --cov=pytensor/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART --benchmark-skip
         env:
           MATRIX_ID: ${{ steps.matrix-id.outputs.id }}
           MKL_THREADING_LAYER: GNU
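Note on the new flag: `--benchmark-skip` makes pytest skip every test that uses the `benchmark` fixture, so CI stays fast while the benchmarks remain runnable on demand, e.g. `python -m pytest tests/ --benchmark-skip` in CI versus `python -m pytest tests/ --benchmark-only` locally to collect the timing table (both are standard pytest-benchmark flags).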

environment.yml

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ dependencies:
   - pytest
   - pytest-cov
   - pytest-xdist
+  - pytest-benchmark
   # For building docs
   - sphinx>=1.3
   - sphinx_rtd_theme

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ tests = [
     "pre-commit",
     "pytest-cov>=2.6.1",
     "coverage>=5.1",
+    "pytest-benchmark",
 ]
 rtd = [
     "sphinx>=1.3.0",

tests/link/jax/test_elemwise.py

Lines changed: 24 additions & 0 deletions
@@ -1,6 +1,9 @@
 import numpy as np
 import pytest
+import scipy.special
 
+import pytensor
+import pytensor.tensor as at
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import get_test_value
@@ -98,3 +101,24 @@ def test_softmax_grad(axis):
     out = SoftmaxGrad(axis=axis)(dy, sm)
     fgraph = FunctionGraph([dy, sm], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+
+@pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)])
+@pytest.mark.parametrize("axis", [0, 1])
+def test_logsumexp_benchmark(size, axis, benchmark):
+    X = at.matrix("X")
+    X_max = at.max(X, axis=axis, keepdims=True)
+    X_max = at.switch(at.isinf(X_max), 0, X_max)
+    X_lse = at.log(at.sum(at.exp(X - X_max), axis=axis, keepdims=True)) + X_max
+
+    X_val = np.random.normal(size=size)
+
+    X_lse_fn = pytensor.function([X], X_lse, mode="JAX")
+
+    # JIT compile first
+    _ = X_lse_fn(X_val)
+
+    res = benchmark(X_lse_fn, X_val)
+
+    exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
+    np.testing.assert_array_almost_equal(res, exp_res)
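For reference, pytest-benchmark works through an injected `benchmark` fixture: calling `benchmark(fn, *args)` runs `fn` repeatedly, records timing statistics, and returns the function's result, so a correctness assertion can follow the measurement. A minimal self-contained sketch of the pattern (the test name and workload are illustrative, not part of this commit):

    import numpy as np


    def test_dot_benchmark(benchmark):
        # hypothetical workload, used only to show the fixture's call pattern
        x = np.random.normal(size=(256, 256))
        # `benchmark` times the callable over several rounds and returns its result
        res = benchmark(np.dot, x, x.T)
        # the returned value allows a correctness check after the timing
        np.testing.assert_allclose(res, x @ x.T)

Note also that `test_logsumexp_benchmark` above calls the compiled function once before handing it to `benchmark`, so JAX's JIT compilation is excluded from the measured time.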

tests/link/numba/test_basic.py

Lines changed: 7 additions & 3 deletions
@@ -1,6 +1,6 @@
 import contextlib
 import inspect
-from typing import TYPE_CHECKING, Callable, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union
 from unittest import mock
 
 import numba
@@ -190,7 +190,7 @@ def compare_numba_and_py(
     numba_mode=numba_mode,
     py_mode=py_mode,
     updates=None,
-):
+) -> Tuple[Callable, Any]:
     """Function to compare python graph output and Numba compiled output for testing equality
 
     In the tests below computational graphs are defined in PyTensor. These graphs are then passed to
@@ -209,6 +209,10 @@ def compare_numba_and_py(
     updates
         Updates to be passed to `pytensor.function`.
 
+    Returns
+    -------
+    The compiled PyTensor function and its last computed result.
+
     """
     if assert_fn is None:
@@ -248,7 +252,7 @@ def assert_fn(x, y):
     else:
         assert_fn(numba_res, py_res)
 
-    return numba_res
+    return pytensor_numba_fn, numba_res
 
 
 @pytest.mark.parametrize(
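With the new return value, a benchmark can reuse the compiled function that the correctness check already produced. A hypothetical sketch of that pattern, assuming it sits in this module next to `compare_numba_and_py` (the graph, values, and test name are illustrative):

    import numpy as np

    import pytensor.tensor as at
    from pytensor.graph.fg import FunctionGraph


    def test_add_benchmark(benchmark):
        x = at.vector("x")
        g_fg = FunctionGraph(outputs=[x + x])
        x_val = np.arange(3.0)
        # validate the Numba output against the Python backend once...
        fn, _ = compare_numba_and_py(g_fg, [x_val])
        # ...then time only the already-compiled function
        benchmark(fn, x_val)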

tests/link/numba/test_scan.py

Lines changed: 4 additions & 2 deletions
@@ -159,7 +159,7 @@ def test_xit_xot_types(
     assert np.allclose(res_val, output_vals)
 
 
-def test_scan_multiple_output():
+def test_scan_multiple_output(benchmark):
     """Test a scan implementation of a SEIR model.
 
     SEIR model definition:
@@ -244,7 +244,9 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
         gamma_val,
         delta_val,
     ]
-    compare_numba_and_py(out_fg, test_input_vals)
+    scan_fn, _ = compare_numba_and_py(out_fg, test_input_vals)
+
+    benchmark(scan_fn, *test_input_vals)
 
 
 @config.change_flags(compute_test_value="raise")

tests/link/numba/test_tensor_basic.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def test_Alloc(v, shape):
     g = at.alloc(v, *shape)
     g_fg = FunctionGraph(outputs=[g])
 
-    (numba_res,) = compare_numba_and_py(
+    _, (numba_res,) = compare_numba_and_py(
         g_fg,
         [
             i.tag.test_value

tests/scan/test_basic.py

Lines changed: 17 additions & 19 deletions
@@ -13,7 +13,6 @@
 import pickle
 import shutil
 import sys
-import timeit
 from collections import OrderedDict
 from tempfile import mkdtemp
 
@@ -2179,15 +2178,13 @@ def scan_fn():
 @pytest.mark.skipif(
     not config.cxx, reason="G++ not available, so we need to skip this test."
 )
-def test_cython_performance():
+def test_cython_performance(benchmark):
 
     # This implicitly confirms that the Cython version is being used
     from pytensor.scan import scan_perform_ext  # noqa: F401
 
     # Python usually out-performs PyTensor below 100 iterations
     N = 200
-    n_timeit = 50
-
     M = -1 / np.arange(1, 11).astype(config.floatX)
     r = np.arange(N * 10).astype(config.floatX).reshape(N, 10)
 
@@ -2216,17 +2213,11 @@ def f_py():
     # Make sure we're actually computing a `Scan`
     assert any(isinstance(node.op, Scan) for node in f_cvm.maker.fgraph.apply_nodes)
 
-    cvm_res = f_cvm()
+    cvm_res = benchmark(f_cvm)
 
     # Make sure the results are the same between the two implementations
     assert np.allclose(cvm_res, py_res)
 
-    python_duration = timeit.timeit(lambda: f_py(), number=n_timeit)
-    cvm_duration = timeit.timeit(lambda: f_cvm(), number=n_timeit)
-    print(f"python={python_duration}, cvm={cvm_duration}")
-
-    assert cvm_duration <= python_duration
-
 
 @config.change_flags(mode="FAST_COMPILE", compute_test_value="raise")
 def test_compute_test_values():
@@ -2662,7 +2653,7 @@ def numpy_implementation(vsample):
         n_result = numpy_implementation(v_vsample)
         utt.assert_allclose(t_result, n_result)
 
-    def test_reordering(self):
+    def test_reordering(self, benchmark):
         """Test re-ordering of inputs.
 
         some rnn with multiple outputs and multiple inputs; other
@@ -2722,14 +2713,14 @@ def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, y_tm3, W_in1):
             v_x[i] = np.dot(v_u1[i], vW_in1) + v_u2[i] * vW_in2 + np.dot(v_x[i - 1], vW)
             v_y[i] = np.dot(v_x[i - 1], vWout) + v_y[i - 1]
 
-        (pytensor_dump1, pytensor_dump2, pytensor_x, pytensor_y) = f4(
-            v_u1, v_u2, v_x0, v_y0, vW_in1
+        (pytensor_dump1, pytensor_dump2, pytensor_x, pytensor_y) = benchmark(
+            f4, v_u1, v_u2, v_x0, v_y0, vW_in1
         )
 
         utt.assert_allclose(pytensor_x, v_x)
         utt.assert_allclose(pytensor_y, v_y)
 
-    def test_scan_as_tensor_on_gradients(self):
+    def test_scan_as_tensor_on_gradients(self, benchmark):
         to_scan = dvector("to_scan")
         seq = dmatrix("seq")
         f1 = dscalar("f1")
@@ -2743,7 +2734,12 @@ def scanStep(prev, seq, f1):
         function(inputs=[to_scan, seq, f1], outputs=scanned, allow_input_downcast=True)
 
         t_grad = grad(scanned.sum(), wrt=[to_scan, f1], consider_constant=[seq])
-        function(inputs=[to_scan, seq, f1], outputs=t_grad, allow_input_downcast=True)
+        benchmark(
+            function,
+            inputs=[to_scan, seq, f1],
+            outputs=t_grad,
+            allow_input_downcast=True,
+        )
 
     def caching_nsteps_by_scan_op(self):
         W = matrix("weights")
@@ -3060,7 +3056,7 @@ def inner_fn(tap_m3, tap_m2, tap_m1):
         utt.assert_allclose(outputs, expected_outputs)
 
     @pytest.mark.slow
-    def test_hessian_bug_grad_grad_two_scans(self):
+    def test_hessian_bug_grad_grad_two_scans(self, benchmark):
         # Bug reported by Bitton Tenessi
         # NOTE : The test to reproduce the bug reported by Bitton Tenessi
         # was modified from its original version to be faster to run.
@@ -3094,7 +3090,7 @@ def loss_inner(sum_inner, W):
         H = hessian(cost, W)
         print(".", file=sys.stderr)
         f = function([W, n_steps], H)
-        f(np.ones((8,), dtype="float32"), 1)
+        benchmark(f, np.ones((8,), dtype="float32"), 1)
 
     def test_grad_connectivity_matrix(self):
         def inner_fn(x_tm1, y_tm1, z_tm1):
@@ -3710,7 +3706,7 @@ def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, W_in1):
         utt.assert_allclose(pytensor_x, v_x)
         utt.assert_allclose(pytensor_y, v_y)
 
-    def test_multiple_outs_taps(self):
+    def test_multiple_outs_taps(self, benchmark):
         l = 5
         rng = np.random.default_rng(utt.fetch_seed())
 
@@ -3805,6 +3801,8 @@ def f_rnn_cmpl(u1_t, u2_tm1, u2_t, u2_tp1, x_tm1, y_tm1, y_tm3, W_in1):
         np.testing.assert_almost_equal(res[1], ny1)
         np.testing.assert_almost_equal(res[2], ny2)
 
+        benchmark(f, v_u1, v_u2, v_x0, v_y0, vW_in1)
+
     def _grad_mout_helper(self, n_iters, mode):
         rng = np.random.default_rng(utt.fetch_seed())
         n_hid = 3
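The removed `timeit` block in `test_cython_performance` ran each implementation a fixed 50 times and asserted that the Cython VM was at least as fast as pure Python; with the fixture, pytest-benchmark picks rounds automatically and reports statistics instead, so the relative-speed assertion is dropped. If `timeit`-style control over repetitions were ever wanted, the fixture's pedantic mode could supply it. A self-contained sketch, with the workload hypothetical and `rounds=50` mirroring the old `n_timeit`:

    import numpy as np


    def test_pedantic_timing(benchmark):
        x = np.random.normal(size=(1000,))
        # `pedantic` fixes rounds/iterations explicitly, analogous to
        # timeit.timeit(..., number=n_timeit) in the removed code
        res = benchmark.pedantic(np.sort, args=(x,), rounds=50, iterations=1)
        assert np.all(np.diff(res) >= 0)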

tests/scan/test_rewriting.py

Lines changed: 7 additions & 4 deletions
@@ -620,7 +620,7 @@ def test_sum_dot(self):
         vB = rng.uniform(size=(5, 5)).astype(config.floatX)
         utt.assert_allclose(f(vA, vB), np.dot(vA.T, vB))
 
-    def test_pregreedy_optimizer(self):
+    def test_pregreedy_optimizer(self, benchmark):
         W = at.zeros((5, 4))
         bv = at.zeros((5,))
         bh = at.zeros((4,))
@@ -634,7 +634,9 @@ def test_pregreedy_optimizer(self):
             n_steps=2,
         )
         # TODO FIXME: Make this a real test and assert something.
-        function([v], chain)(np.zeros((3, 5), dtype=config.floatX))
+        chain_fn = function([v], chain)
+
+        benchmark(chain_fn, np.zeros((3, 5), dtype=config.floatX))
 
     def test_machine_translation(self):
         """
@@ -1291,15 +1293,16 @@ def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
         ]
         assert len(scan_nodes) == 1
 
-    def test_savemem_opt(self):
+    def test_savemem_opt(self, benchmark):
         y0 = shared(np.ones((2, 10)))
         [y1, y2], updates = scan(
             lambda y: [y, y],
             outputs_info=[dict(initial=y0, taps=[-2]), None],
             n_steps=5,
         )
         # TODO FIXME: Make this a real test and assert something.
-        function([], y2.sum(), mode=self.mode)()
+        fn = function([], y2.sum(), mode=self.mode)
+        benchmark(fn)
 
     def test_savemem_opt_0_step(self):
         """
