
Commit 42b5b5f
Merge branch 'main' into numba_slinalg
2 parents: 7ccd3df + 071eadd


50 files changed: +1726 -613 lines

.github/workflows/pypi.yml

Lines changed: 2 additions & 2 deletions

@@ -22,7 +22,7 @@ jobs:
     name: Make SDist
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
          submodules: true
@@ -44,7 +44,7 @@ jobs:
          - windows-2022
          - ubuntu-20.04
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
        with:
          fetch-depth: 0


.github/workflows/test.yml

Lines changed: 6 additions & 6 deletions

@@ -22,7 +22,7 @@ jobs:
     outputs:
       changes: ${{ steps.changes.outputs.src }}
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - uses: dorny/paths-filter@v2
@@ -54,7 +54,7 @@ jobs:
       matrix:
         python-version: ["3.9", "3.10", "3.11"]
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
      - uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
@@ -116,7 +116,7 @@ jobs:
          float32: 0
          part: "tests/link/jax"
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Set up Python ${{ matrix.python-version }}
@@ -139,7 +139,7 @@ jobs:
      - name: Install dependencies
        shell: bash -l {0}
        run: |
-          mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
+          mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl "numpy<1.26" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
          # numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.9, but
          # not numpy, even though scipy 1.7 requires numpy<1.23. When installing
          # PyTensor next, pip installs a lower version of numpy via the PyPI.
@@ -187,7 +187,7 @@ jobs:
     strategy:
       fail-fast: false
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Set up Python 3.9
@@ -244,7 +244,7 @@ jobs:
     needs: [changes, all-checks]
     if: ${{ needs.changes.outputs.changes == 'true' && needs.all-checks.result == 'success' }}
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4

environment.yml

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ channels:
 dependencies:
   - python
   - compilers
-  - numpy>=1.17.0
+  - numpy>=1.17.0,<1.26.0
   - scipy>=0.14
   - filelock
   - etuples

pyproject.toml

Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@
 requires = [
     "setuptools>=48.0.0",
     "cython",
-    "numpy>=1.17.0",
+    "numpy>=1.17.0,<1.26",
     "versioneer[toml]==0.28",
 ]
 build-backend = "setuptools.build_meta"
@@ -52,7 +52,7 @@ keywords = [
 dependencies = [
     "setuptools>=48.0.0",
     "scipy>=0.14",
-    "numpy>=1.17.0",
+    "numpy>=1.17.0,<1.26",
     "filelock",
     "etuples",
     "logical-unification",

pytensor/compile/mode.py

Lines changed: 6 additions & 2 deletions

@@ -251,13 +251,14 @@ def apply(self, fgraph):
 # especially constant merge
 optdb.register("merge2", MergeOptimizer(), "fast_run", "merge", position=49)
 
+optdb.register("py_only", EquilibriumDB(), "fast_compile", position=49.1)
+
 optdb.register(
     "add_destroy_handler", AddDestroyHandler(), "fast_run", "inplace", position=49.5
 )
 
 # final pass just to make sure
 optdb.register("merge3", MergeOptimizer(), "fast_run", "merge", position=100)
-optdb.register("py_only", EquilibriumDB(), "fast_compile", position=100)
 
 _tags: Union[Tuple[str, str], Tuple]
 
@@ -463,7 +464,10 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 )
 NUMBA = Mode(
     NumbaLinker(),
-    RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
+    RewriteDatabaseQuery(
+        include=["fast_run"],
+        exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
+    ),
 )

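For context, a minimal sketch (not part of this commit) of how the NUMBA mode defined above is selected when compiling a function; with this change, the "local_careduce_fusion" rewrite is simply skipped when that mode's rewrite query runs:

# Illustrative only: compile a small graph with the predefined NUMBA mode.
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.sum(pt.exp(x), axis=0)

# mode="NUMBA" resolves to the Mode(NumbaLinker(), RewriteDatabaseQuery(...))
# instance shown in the diff above.
fn = pytensor.function([x], y, mode="NUMBA")
print(fn([[0.0, 1.0], [2.0, 3.0]]))
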
pytensor/gradient.py

Lines changed: 9 additions & 11 deletions

@@ -196,12 +196,13 @@ def Rop(
 
     Returns
     -------
+    :class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
         A symbolic expression such obeying
         ``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``,
         where the indices in that expression are magic multidimensional
         indices that specify both the position within a list and all
         coordinates of the tensor elements.
-        If `wrt` is a list/tuple, then return a list/tuple with the results.
+        If `f` is a list/tuple, then return a list/tuple with the results.
     """
 
     if not isinstance(wrt, (list, tuple)):
@@ -384,6 +385,7 @@ def Lop(
 
     Returns
     -------
+    :class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
         A symbolic expression satisfying
         ``L_op[i] = sum_i (d f[i] / d wrt[j]) eval_point[i]``
         where the indices in that expression are magic multidimensional
@@ -481,10 +483,10 @@ def grad(
 
     Returns
     -------
+    :class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
         A symbolic expression for the gradient of `cost` with respect to each
         of the `wrt` terms. If an element of `wrt` is not differentiable with
         respect to the output, then a zero variable is returned.
-
     """
     t0 = time.perf_counter()
 
@@ -701,7 +703,6 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
 
     Parameters
     ----------
-
     wrt : list of variables
         Gradients are computed with respect to `wrt`.
 
@@ -876,7 +877,6 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
 
         (A variable in consider_constant is not a function of
         anything)
-
     """
 
     # Validate and format consider_constant
@@ -1035,7 +1035,6 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
     -------
     list of Variables
         A list of gradients corresponding to `wrt`
-
     """
     # build a dict mapping node to the terms node contributes to each of
     # its inputs' gradients
@@ -1423,8 +1422,9 @@ def access_grad_cache(var):
 
 
 def _float_zeros_like(x):
-    """Like zeros_like, but forces the object to have a
-    a floating point dtype"""
+    """Like zeros_like, but forces the object to have
+    a floating point dtype
+    """
 
     rval = x.zeros_like()
 
@@ -1436,7 +1436,8 @@ def _float_zeros_like(x):
 
 def _float_ones_like(x):
     """Like ones_like, but forces the object to have a
-    floating point dtype"""
+    floating point dtype
+    """
 
     dtype = x.type.dtype
     if dtype not in pytensor.tensor.type.float_dtypes:
@@ -1613,7 +1614,6 @@ def abs_rel_errors(self, g_pt):
 
     Corresponding ndarrays in `g_pt` and `self.gf` must have the same
     shape or ValueError is raised.
-
     """
     if len(g_pt) != len(self.gf):
         raise ValueError("argument has wrong number of elements", len(g_pt))
@@ -1740,7 +1740,6 @@ def verify_grad(
     This function does not support multiple outputs. In `tests.scan.test_basic`
     there is an experimental `verify_grad` that covers that case as well by
     using random projections.
-
     """
     from pytensor.compile.function import function
     from pytensor.compile.sharedvalue import shared
@@ -2267,7 +2266,6 @@ def grad_clip(x, lower_bound, upper_bound):
     -----
     We register an opt in tensor/opt.py that remove the GradClip.
     So it have 0 cost in the forward and only do work in the grad.
-
     """
     return GradClip(lower_bound, upper_bound)(x)

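The Rop docstring fix above now states that a list/tuple is returned when `f` (rather than `wrt`) is a list/tuple. A minimal sketch, not part of this commit, illustrating that behaviour with the existing `pytensor.gradient.Rop` API:

# Illustrative only: Rop returns one Jacobian-vector product per output in `f`.
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import Rop

x = pt.vector("x")
v = pt.vector("v")  # direction vector for the R-operator

f1 = 2 * x
f2 = x ** 2

# `f` is a list, so a list of Variables is returned.
jvp1, jvp2 = Rop([f1, f2], x, v)

fn = pytensor.function([x, v], [jvp1, jvp2])
print(fn([1.0, 2.0], [1.0, 0.0]))  # expected: [array([2., 0.]), array([2., 0.])]
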
pytensor/graph/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
     clone,
     ancestors,
 )
-from pytensor.graph.replace import clone_replace, graph_replace
+from pytensor.graph.replace import clone_replace, graph_replace, vectorize
 from pytensor.graph.op import Op
 from pytensor.graph.type import Type
 from pytensor.graph.fg import FunctionGraph

pytensor/graph/replace.py

Lines changed: 112 additions & 2 deletions

@@ -1,8 +1,9 @@
-from functools import partial
-from typing import Iterable, Optional, Sequence, Union, cast, overload
+from functools import partial, singledispatch
+from typing import Iterable, Mapping, Optional, Sequence, Union, cast, overload
 
 from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
 from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.op import Op
 
 
 ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
@@ -198,3 +199,112 @@ def toposort_key(
         return list(fg.outputs)
     else:
         return fg.outputs[0]
+
+
+@singledispatch
+def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
+    # Default implementation is provided in pytensor.tensor.blockwise
+    raise NotImplementedError
+
+
+def vectorize_node(node: Apply, *batched_inputs) -> Apply:
+    """Returns vectorized version of node with new batched inputs."""
+    op = node.op
+    return _vectorize_node(op, node, *batched_inputs)
+
+
+@overload
+def vectorize(
+    outputs: Variable,
+    replace: Mapping[Variable, Variable],
+) -> Variable:
+    ...
+
+
+@overload
+def vectorize(
+    outputs: Sequence[Variable],
+    replace: Mapping[Variable, Variable],
+) -> Sequence[Variable]:
+    ...
+
+
+def vectorize(
+    outputs: Union[Variable, Sequence[Variable]],
+    replace: Mapping[Variable, Variable],
+) -> Union[Variable, Sequence[Variable]]:
+    """Vectorize outputs graph given mapping from old variables to expanded counterparts version.
+
+    Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import pytensor
+        import pytensor.tensor as pt
+
+        from pytensor.graph import vectorize
+
+        # Original graph
+        x = pt.vector("x")
+        y = pt.exp(x) / pt.sum(pt.exp(x))
+
+        # Vectorized graph
+        new_x = pt.matrix("new_x")
+        new_y = vectorize(y, replace={x: new_x})
+
+        fn = pytensor.function([new_x], new_y)
+        fn([[0, 1, 2], [2, 1, 0]])
+        # array([[0.09003057, 0.24472847, 0.66524096],
+        #        [0.66524096, 0.24472847, 0.09003057]])
+
+
+    .. code-block:: python
+
+        import pytensor
+        import pytensor.tensor as pt
+
+        from pytensor.graph import vectorize
+
+        # Original graph
+        x = pt.vector("x")
+        y1 = x[0]
+        y2 = x[-1]
+
+        # Vectorized graph
+        new_x = pt.matrix("new_x")
+        [new_y1, new_y2] = vectorize([y1, y2], replace={x: new_x})
+
+        fn = pytensor.function([new_x], [new_y1, new_y2])
+        fn([[-10, 0, 10], [-11, 0, 11]])
+        # [array([-10., -11.]), array([10., 11.])]
+
+    """
+    if isinstance(outputs, Sequence):
+        seq_outputs = outputs
+    else:
+        seq_outputs = [outputs]
+
+    inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
+    new_inputs = [replace.get(inp, inp) for inp in inputs]
+
+    def transform(var: Variable) -> Variable:
+        if var in inputs:
+            return new_inputs[inputs.index(var)]
+
+        node = var.owner
+        batched_inputs = [transform(inp) for inp in node.inputs]
+        batched_node = vectorize_node(node, *batched_inputs)
+        batched_var = batched_node.outputs[var.owner.outputs.index(var)]
+
+        return cast(Variable, batched_var)
+
+    # TODO: MergeOptimization or node caching?
+    seq_vect_outputs = [transform(out) for out in seq_outputs]
+
+    if isinstance(outputs, Sequence):
+        return seq_vect_outputs
+    else:
+        [vect_output] = seq_vect_outputs
+        return vect_output

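The `_vectorize_node` hook above is a `functools.singledispatch` function, so individual Ops can register their own vectorization rule; the default implementation lives in pytensor.tensor.blockwise. A hypothetical sketch of such a registration, where `IdentityOp` is an assumed placeholder Op used only for illustration and not part of this commit:

# Illustrative only: plugging a custom Op into the new singledispatch hook.
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node


class IdentityOp(Op):
    """Toy op that returns its input unchanged (placeholder for the example)."""

    def make_node(self, x):
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = inputs[0]


@_vectorize_node.register(IdentityOp)
def _vectorize_identity(op, node, *batched_inputs):
    # For an elementwise-identity op, vectorizing is just re-applying the op to
    # the batched inputs; ops with core dimensions would typically wrap
    # themselves in Blockwise here instead.
    return op.make_node(*batched_inputs)
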
pytensor/ifelse.py

Lines changed: 1 addition & 0 deletions

@@ -482,6 +482,7 @@ def cond_make_inplace(fgraph, node):
     at.basic.Alloc,
     at.elemwise.Elemwise,
     at.elemwise.DimShuffle,
+    at.blockwise.Blockwise,
 )
 
 

pytensor/link/jax/dispatch/scan.py

Lines changed: 3 additions & 2 deletions

@@ -1,6 +1,7 @@
 import jax
 import jax.numpy as jnp
 
+from pytensor.compile.mode import JAX
 from pytensor.link.jax.dispatch.basic import jax_funcify
 from pytensor.scan.op import Scan
 
@@ -17,8 +18,8 @@ def jax_funcify_Scan(op: Scan, **kwargs):
             "Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
         )
 
-    # Optimize inner graph
-    rewriter = op.mode_instance.optimizer
+    # Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode)
+    rewriter = op.mode_instance.excluding(*JAX._optimizer.exclude).optimizer
     rewriter(op.fgraph)
     scan_inner_func = jax_funcify(op.fgraph, **kwargs)
 
