
🔄 From Aesara: #1347 and #1365: "Add CI support for benchmarking" #139


Merged · 3 commits · Dec 16, 2022
59 changes: 56 additions & 3 deletions .github/workflows/test.yml
@@ -4,7 +4,6 @@ on:
push:
branches:
- main
- checks
pull_request:
branches:
- main
@@ -115,7 +114,7 @@ jobs:
- name: Install dependencies
shell: bash -l {0}
run: |
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov sympy
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.55" numba-scipy; fi
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib
pip install -e ./
@@ -132,7 +131,7 @@ jobs:
if [[ $FAST_COMPILE == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi
if [[ $FLOAT32 == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,floatX=float32; fi
export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
python -m pytest -x -r A --verbose --runslow --cov=pytensor/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART
python -m pytest -x -r A --verbose --runslow --cov=pytensor/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART --benchmark-skip
env:
MATRIX_ID: ${{ steps.matrix-id.outputs.id }}
MKL_THREADING_LAYER: GNU
@@ -148,6 +147,60 @@ jobs:
name: coverage
path: coverage/coverage-${{ steps.matrix-id.outputs.id }}.xml

benchmarks:
name: "Benchmarks"
needs:
- changes
- style
runs-on: ubuntu-latest
if: ${{ needs.changes.outputs.changes == 'true' && needs.style.result == 'success' }}
strategy:
fail-fast: true
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up Python 3.9
uses: conda-incubator/setup-miniconda@v2
with:
mamba-version: "*"
channels: conda-forge,defaults
channel-priority: true
python-version: 3.9
auto-update-conda: true
- name: Install dependencies
shell: bash -l {0}
run: |
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.55" numba-scipy jax jaxlib pytest-benchmark
pip install -e ./
mamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
python -c 'import pytensor; assert(pytensor.config.blas__ldflags != "")'
env:
PYTHON_VERSION: 3.9
- name: Download previous benchmark data
uses: actions/cache@v1
with:
path: ./cache
key: ${{ runner.os }}-benchmark
- name: Run benchmarks
shell: bash -l {0}
run: |
export PYTENSOR_FLAGS=mode=FAST_COMPILE,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
python -m pytest --runslow --benchmark-only --benchmark-json output.json
- name: Store benchmark result
uses: benchmark-action/github-action-benchmark@v1
with:
name: Python Benchmark with pytest-benchmark
tool: 'pytest'
output-file-path: output.json
external-data-json-path: ./cache/benchmark-data.json
alert-threshold: '200%'
github-token: ${{ secrets.GITHUB_TOKEN }}
comment-on-alert: ${{ github.event_name == 'push' }}
fail-on-alert: true
auto-push: false

all-checks:
if: ${{ always() }}
runs-on: ubuntu-latest
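Note how the pieces fit together: the regular test jobs now pass `--benchmark-skip` while the dedicated `benchmarks` job passes `--benchmark-only`, so each test executes exactly once per CI run, and `alert-threshold: '200%'` makes `github-action-benchmark` raise an alert (and fail the job, given `fail-on-alert: true`) when a benchmark takes twice as long as the cached reference. Writing a benchmark only requires requesting pytest-benchmark's `benchmark` fixture; a minimal sketch, not taken from any file in this PR:

```python
import numpy as np


def test_sum_benchmark(benchmark):
    # benchmark(fn, *args) calls fn repeatedly, records timing
    # statistics, and returns the value of one call, so the
    # usual correctness assertions still apply.
    x = np.arange(1_000_000, dtype="float64")
    result = benchmark(np.sum, x)
    np.testing.assert_allclose(result, x.sum())
```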
1 change: 1 addition & 0 deletions environment.yml
@@ -29,6 +29,7 @@ dependencies:
- pytest
- pytest-cov
- pytest-xdist
- pytest-benchmark
# For building docs
- sphinx>=1.3
- sphinx_rtd_theme
1 change: 1 addition & 0 deletions pyproject.toml
@@ -86,6 +86,7 @@ tests = [
"pre-commit",
"pytest-cov>=2.6.1",
"coverage>=5.1",
"pytest-benchmark",
]
rtd = [
"sphinx>=1.3.0",
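Since `pytest-benchmark` is now part of the `tests` extra, a local environment matching CI can be set up with an editable install (a sketch, run from a checkout of the repository):

```bash
# Editable install pulling in the test extras, now including pytest-benchmark
pip install -e ".[tests]"
```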
24 changes: 24 additions & 0 deletions tests/link/jax/test_elemwise.py
@@ -1,6 +1,9 @@
import numpy as np
import pytest
import scipy.special

import pytensor
import pytensor.tensor as at
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
@@ -98,3 +101,24 @@ def test_softmax_grad(axis):
out = SoftmaxGrad(axis=axis)(dy, sm)
fgraph = FunctionGraph([dy, sm], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])


@pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)])
@pytest.mark.parametrize("axis", [0, 1])
def test_logsumexp_benchmark(size, axis, benchmark):
X = at.matrix("X")
X_max = at.max(X, axis=axis, keepdims=True)
X_max = at.switch(at.isinf(X_max), 0, X_max)
X_lse = at.log(at.sum(at.exp(X - X_max), axis=axis, keepdims=True)) + X_max

X_val = np.random.normal(size=size)

X_lse_fn = pytensor.function([X], X_lse, mode="JAX")

# JIT compile first
_ = X_lse_fn(X_val)

res = benchmark(X_lse_fn, X_val)

exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
np.testing.assert_array_almost_equal(res, exp_res)
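The new test exercises the numerically stable form of log-sum-exp: subtracting the row maximum before exponentiating prevents overflow, and adding it back after the log restores the value. A plain NumPy illustration of the identity the test relies on (values chosen so the naive `log(sum(exp(x)))` would overflow to `inf`):

```python
import numpy as np
import scipy.special

x = np.array([1000.0, 1000.5, 999.0])

# log(sum(exp(x))) == max(x) + log(sum(exp(x - max(x))));
# np.exp(1000.0) alone would overflow float64.
x_max = x.max()
stable = x_max + np.log(np.exp(x - x_max).sum())

np.testing.assert_allclose(stable, scipy.special.logsumexp(x))
```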
10 changes: 7 additions & 3 deletions tests/link/numba/test_basic.py
@@ -1,6 +1,6 @@
import contextlib
import inspect
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union
from unittest import mock

import numba
@@ -190,7 +190,7 @@ def compare_numba_and_py(
numba_mode=numba_mode,
py_mode=py_mode,
updates=None,
):
) -> Tuple[Callable, Any]:
"""Function to compare python graph output and Numba compiled output for testing equality

In the tests below computational graphs are defined in PyTensor. These graphs are then passed to
@@ -209,6 +209,10 @@ def compare_numba_and_py(
updates
Updates to be passed to `pytensor.function`.

Returns
-------
The compiled PyTensor function and its last computed result.

"""
if assert_fn is None:

@@ -248,7 +252,7 @@ def assert_fn(x, y):
else:
assert_fn(numba_res, py_res)

return numba_res
return pytensor_numba_fn, numba_res


@pytest.mark.parametrize(
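Returning the compiled function alongside the result lets benchmark tests time a callable that has already been checked against the Python backend, which is how `tests/link/numba/test_scan.py` uses it below. A hedged sketch of the pattern with a toy graph (the graph, test name, and values are illustrative, not from the PR):

```python
import numpy as np

import pytensor.tensor as at
from pytensor.graph.fg import FunctionGraph
from tests.link.numba.test_basic import compare_numba_and_py


def test_add_one_benchmark(benchmark):
    x = at.vector("x")
    out_fg = FunctionGraph([x], [x + 1.0])
    x_val = np.ones(10, dtype="float64")

    # Verify Numba output against the Python backend first,
    # then benchmark the compiled Numba function.
    numba_fn, _ = compare_numba_and_py(out_fg, [x_val])
    benchmark(numba_fn, x_val)
```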
6 changes: 4 additions & 2 deletions tests/link/numba/test_scan.py
@@ -159,7 +159,7 @@ def test_xit_xot_types(
assert np.allclose(res_val, output_vals)


def test_scan_multiple_output():
def test_scan_multiple_output(benchmark):
"""Test a scan implementation of a SEIR model.

SEIR model definition:
@@ -244,7 +244,9 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
gamma_val,
delta_val,
]
compare_numba_and_py(out_fg, test_input_vals)
scan_fn, _ = compare_numba_and_py(out_fg, test_input_vals)

benchmark(scan_fn, *test_input_vals)


@config.change_flags(compute_test_value="raise")
2 changes: 1 addition & 1 deletion tests/link/numba/test_tensor_basic.py
@@ -32,7 +32,7 @@ def test_Alloc(v, shape):
g = at.alloc(v, *shape)
g_fg = FunctionGraph(outputs=[g])

(numba_res,) = compare_numba_and_py(
_, (numba_res,) = compare_numba_and_py(
g_fg,
[
i.tag.test_value
43 changes: 21 additions & 22 deletions tests/scan/test_basic.py
@@ -13,7 +13,6 @@
import pickle
import shutil
import sys
import timeit
from collections import OrderedDict
from tempfile import mkdtemp

@@ -2179,15 +2178,13 @@ def scan_fn():
@pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_cython_performance():
def test_cython_performance(benchmark):

# This implicitly confirms that the Cython version is being used
from pytensor.scan import scan_perform_ext # noqa: F401

# Python usually out-performs PyTensor below 100 iterations
N = 200
n_timeit = 50

M = -1 / np.arange(1, 11).astype(config.floatX)
r = np.arange(N * 10).astype(config.floatX).reshape(N, 10)

@@ -2216,17 +2213,11 @@ def f_py():
# Make sure we're actually computing a `Scan`
assert any(isinstance(node.op, Scan) for node in f_cvm.maker.fgraph.apply_nodes)

cvm_res = f_cvm()
cvm_res = benchmark(f_cvm)

# Make sure the results are the same between the two implementations
assert np.allclose(cvm_res, py_res)

python_duration = timeit.timeit(lambda: f_py(), number=n_timeit)
cvm_duration = timeit.timeit(lambda: f_cvm(), number=n_timeit)
print(f"python={python_duration}, cvm={cvm_duration}")

assert cvm_duration <= python_duration


@config.change_flags(mode="FAST_COMPILE", compute_test_value="raise")
def test_compute_test_values():
@@ -2662,7 +2653,7 @@ def numpy_implementation(vsample):
n_result = numpy_implementation(v_vsample)
utt.assert_allclose(t_result, n_result)

def test_reordering(self):
def test_reordering(self, benchmark):
"""Test re-ordering of inputs.

some rnn with multiple outputs and multiple inputs; other
@@ -2722,14 +2713,14 @@ def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, y_tm3, W_in1):
v_x[i] = np.dot(v_u1[i], vW_in1) + v_u2[i] * vW_in2 + np.dot(v_x[i - 1], vW)
v_y[i] = np.dot(v_x[i - 1], vWout) + v_y[i - 1]

(pytensor_dump1, pytensor_dump2, pytensor_x, pytensor_y) = f4(
v_u1, v_u2, v_x0, v_y0, vW_in1
(pytensor_dump1, pytensor_dump2, pytensor_x, pytensor_y) = benchmark(
f4, v_u1, v_u2, v_x0, v_y0, vW_in1
)

utt.assert_allclose(pytensor_x, v_x)
utt.assert_allclose(pytensor_y, v_y)

def test_scan_as_tensor_on_gradients(self):
def test_scan_as_tensor_on_gradients(self, benchmark):
to_scan = dvector("to_scan")
seq = dmatrix("seq")
f1 = dscalar("f1")
@@ -2743,7 +2734,12 @@ def scanStep(prev, seq, f1):
function(inputs=[to_scan, seq, f1], outputs=scanned, allow_input_downcast=True)

t_grad = grad(scanned.sum(), wrt=[to_scan, f1], consider_constant=[seq])
function(inputs=[to_scan, seq, f1], outputs=t_grad, allow_input_downcast=True)
benchmark(
function,
inputs=[to_scan, seq, f1],
outputs=t_grad,
allow_input_downcast=True,
)

def caching_nsteps_by_scan_op(self):
W = matrix("weights")
@@ -3060,7 +3056,7 @@ def inner_fn(tap_m3, tap_m2, tap_m1):
utt.assert_allclose(outputs, expected_outputs)

@pytest.mark.slow
def test_hessian_bug_grad_grad_two_scans(self):
def test_hessian_bug_grad_grad_two_scans(self, benchmark):
# Bug reported by Bitton Tenessi
# NOTE : The test to reproduce the bug reported by Bitton Tenessi
# was modified from its original version to be faster to run.
@@ -3094,7 +3090,7 @@ def loss_inner(sum_inner, W):
H = hessian(cost, W)
print(".", file=sys.stderr)
f = function([W, n_steps], H)
f(np.ones((8,), dtype="float32"), 1)
benchmark(f, np.ones((8,), dtype="float32"), 1)

def test_grad_connectivity_matrix(self):
def inner_fn(x_tm1, y_tm1, z_tm1):
@@ -3710,7 +3706,7 @@ def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, W_in1):
utt.assert_allclose(pytensor_x, v_x)
utt.assert_allclose(pytensor_y, v_y)

def test_multiple_outs_taps(self):
def test_multiple_outs_taps(self, benchmark):
l = 5
rng = np.random.default_rng(utt.fetch_seed())

@@ -3753,8 +3749,6 @@ def f_rnn_cmpl(u1_t, u2_tm1, u2_t, u2_tp1, x_tm1, y_tm1, y_tm3, W_in1):
[u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True
)

f(v_u1, v_u2, v_x0, v_y0, vW_in1)

ny0 = np.zeros((5, 2))
ny1 = np.zeros((5,))
ny2 = np.zeros((5, 2))
@@ -3802,7 +3796,12 @@ def f_rnn_cmpl(u1_t, u2_tm1, u2_t, u2_tp1, x_tm1, y_tm1, y_tm3, W_in1):
ny1[4] = (ny1[3] + ny1[1]) * np.dot(ny0[3], vWout)
ny2[4] = np.dot(v_u1[4], vW_in1)

# TODO FIXME: What is this testing? At least assert something.
res = f(v_u1, v_u2, v_x0, v_y0, vW_in1)
np.testing.assert_almost_equal(res[0], ny0)
np.testing.assert_almost_equal(res[1], ny1)
np.testing.assert_almost_equal(res[2], ny2)

benchmark(f, v_u1, v_u2, v_x0, v_y0, vW_in1)

def _grad_mout_helper(self, n_iters, mode):
rng = np.random.default_rng(utt.fetch_seed())
11 changes: 7 additions & 4 deletions tests/scan/test_rewriting.py
@@ -620,7 +620,7 @@ def test_sum_dot(self):
vB = rng.uniform(size=(5, 5)).astype(config.floatX)
utt.assert_allclose(f(vA, vB), np.dot(vA.T, vB))

def test_pregreedy_optimizer(self):
def test_pregreedy_optimizer(self, benchmark):
W = at.zeros((5, 4))
bv = at.zeros((5,))
bh = at.zeros((4,))
@@ -634,7 +634,9 @@ def test_pregreedy_optimizer(self):
n_steps=2,
)
# TODO FIXME: Make this a real test and assert something.
function([v], chain)(np.zeros((3, 5), dtype=config.floatX))
chain_fn = function([v], chain)

benchmark(chain_fn, np.zeros((3, 5), dtype=config.floatX))

def test_machine_translation(self):
"""
@@ -1291,15 +1293,16 @@ def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
]
assert len(scan_nodes) == 1

def test_savemem_opt(self):
def test_savemem_opt(self, benchmark):
y0 = shared(np.ones((2, 10)))
[y1, y2], updates = scan(
lambda y: [y, y],
outputs_info=[dict(initial=y0, taps=[-2]), None],
n_steps=5,
)
# TODO FIXME: Make this a real test and assert something.
function([], y2.sum(), mode=self.mode)()
fn = function([], y2.sum(), mode=self.mode)
benchmark(fn)

def test_savemem_opt_0_step(self):
"""
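The same benchmarks can be tracked locally without the CI cache; a possible workflow using pytest-benchmark's saved-run support (flag and subcommand names per its CLI, shown as a sketch):

```bash
# Run only the benchmark tests and save results under .benchmarks/
pytest tests/scan/test_basic.py --runslow --benchmark-only --benchmark-autosave

# After making changes, run again, then compare the saved runs
pytest-benchmark compare
```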