From daabeb34ddbca22ba0b9abbeb43b72d4cc418fca Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 18 Jan 2023 15:29:55 +0100 Subject: [PATCH 01/13] Fix failing TestHyp2F1Broadcast --- pytensor/scalar/math.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 6733c62a86..b24d9d0341 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -1638,9 +1638,9 @@ def compute_grad_2f1(a, b, c, z, wrt): return compute_grad_2f1(a, b, c, z, wrt=wrt) - def __call__(self, a, b, c, z, wrt): + def __call__(self, a, b, c, z, wrt, **kwargs): # This allows wrt to be a keyword argument - return super().__call__(a, b, c, z, wrt) + return super().__call__(a, b, c, z, wrt, **kwargs) def c_code(self, *args, **kwargs): raise NotImplementedError() From a2f101adbf8677fe2cdc9bc3a5844e96c33cb09a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 16 Dec 2022 10:23:02 +0100 Subject: [PATCH 02/13] Cleanup Fusion rewrites * Move local_add_mul_fusion to `rewriting/elemwise` and remove unused/duplicated TestAddMulFusion tests * Use EquilibriumGraphRewriter for local_add_mul_fusion * Do not register optional rewrites if tensor__local_elemwise_fusion flag is disabled --- pytensor/graph/rewriting/db.py | 2 +- pytensor/tensor/rewriting/elemwise.py | 71 ++- pytensor/tensor/rewriting/math.py | 61 -- tests/tensor/rewriting/test_elemwise.py | 40 +- tests/tensor/rewriting/test_math.py | 764 +----------------------- 5 files changed, 84 insertions(+), 854 deletions(-) diff --git a/pytensor/graph/rewriting/db.py b/pytensor/graph/rewriting/db.py index f303c1840e..645faf9911 100644 --- a/pytensor/graph/rewriting/db.py +++ b/pytensor/graph/rewriting/db.py @@ -427,7 +427,7 @@ def query( position_cutoff = tags[0].position_cutoff # The RewriteDatabaseQuery instance might contain extra rewrites which need - # to be added the the sequence of rewrites (don't alter the + # to be added to the sequence of rewrites (don't alter the # original dictionary) if len(tags[0].extra_rewrites) > 0: position_dict = position_dict.copy() diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index e9952a3908..47f2de28d7 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -13,6 +13,7 @@ from pytensor.graph.features import ReplaceValidate from pytensor.graph.op import compute_test_value, get_test_value from pytensor.graph.rewriting.basic import ( + EquilibriumGraphRewriter, GraphRewriter, copy_stack_trace, in2out, @@ -529,6 +530,60 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): return rval +@node_rewriter([Elemwise]) +def local_add_mul_fusion(fgraph, node): + """Fuse consecutive add or mul in one such node with more inputs. + + It is better to fuse add/mul that way then in a Composite node as + this make the inner graph of the Composite smaller. This allows to + put more computation in a Composite before hitting the max + recursion limit when pickling Composite. + + This rewrite is almost useless after the AlgebraicCanonizer is used, + but it catches a few edge cases that are not canonicalized by it + """ + if not isinstance(node.op, Elemwise) or not isinstance( + node.op.scalar_op, (aes.Add, aes.Mul) + ): + return False + + s_op = node.op.scalar_op.__class__ + new_inp = [] + fused = False + nb_inputs = len(node.inputs) + max_inputs = float("inf") + if hasattr(node.op, "max_inputs"): + max_inputs = node.op.max_inputs(node) + for inp in node.inputs: + if ( + inp.owner + and isinstance(inp.owner.op, Elemwise) + and isinstance(inp.owner.op.scalar_op, s_op) + and + # Do not duplicate the operation. + len(fgraph.clients[inp]) == 1 + and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs + ): + new_inp.extend(inp.owner.inputs) + fused = True + else: + new_inp.append(inp) + + # We can not compare the number of inputs as Mul and Add could have + # 0 or 1 inputs in some corner cases. + if fused: + output = node.op(*new_inp) + copy_stack_trace(node.outputs[0], output) + + # Do the recursion here to help lower the number of + # FusionOptimizer iteration. + if output.owner: + output2 = local_add_mul_fusion.transform(fgraph, output.owner) + if output2: + return output2 + return [output] + + def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None): r"""Create a recursive function that fuses `Elemwise` `Op`\s. @@ -901,6 +956,13 @@ def print_profile(cls, stream, prof, level=0): if config.tensor__local_elemwise_fusion: # Must be after gpu(48.5) and before AddDestroyHandler(49.5) fuse_seqopt = SequenceDB() + fuse_seqopt.register( + "local_add_mul_fusion", + EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000), + "fast_run", + "fusion", + position=0, + ) fuse_seqopt.register( "composite_elemwise_fusion", FusionOptimizer(local_elemwise_fusion), @@ -917,15 +979,6 @@ def print_profile(cls, stream, prof, level=0): "FusionOptimizer", position=49, ) -else: - compile.optdb.register( # type: ignore - "elemwise_fusion", - FusionOptimizer(local_elemwise_fusion), - "fusion", - "local_elemwise_fusion", - "FusionOptimizer", - position=49, - ) @register_canonicalize diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 7cf495b56a..c2efda11d0 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -92,7 +92,6 @@ register_uncanonicalize, register_useless, ) -from pytensor.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt from pytensor.tensor.shape import Shape, Shape_i from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.type import ( @@ -2966,66 +2965,6 @@ def check_input(inputs): return [ret] -def local_add_mul_fusion(fgraph, node): - """Fuse consecutive add or mul in one such node with more inputs. - - It is better to fuse add/mul that way then in a Composite node as - this make the inner graph of the Composite smaller. This allow to - put more computation in a Composite before hitting the max - recursion limit when pickling Composite. - - """ - if not isinstance(node.op, Elemwise) or not isinstance( - node.op.scalar_op, (aes.Add, aes.Mul) - ): - return False - - s_op = node.op.scalar_op.__class__ - new_inp = [] - fused = False - nb_inputs = len(node.inputs) - max_inputs = float("inf") - if hasattr(node.op, "max_inputs"): - max_inputs = node.op.max_inputs(node) - for inp in node.inputs: - if ( - inp.owner - and isinstance(inp.owner.op, Elemwise) - and isinstance(inp.owner.op.scalar_op, s_op) - and - # Do not duplicate the operation. - len(fgraph.clients[inp]) == 1 - and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs - ): - new_inp.extend(inp.owner.inputs) - fused = True - else: - new_inp.append(inp) - - # We can not compare the number of inputs as Mul and Add could have - # 0 or 1 inputs in some corner cases. - if fused: - output = node.op(*new_inp) - copy_stack_trace(node.outputs[0], output) - - # Do the recursion here to help lower the number of - # FusionOptimizer iteration. - if output.owner: - output2 = local_add_mul_fusion(fgraph, output.owner) - if output2: - return output2 - return [output] - - -fuse_seqopt.register( - "local_add_mul_fusion", - FusionOptimizer(local_add_mul_fusion), - "fast_run", - "fusion", - position=0, -) - - def _skip_mul_1(r): if r.owner and r.owner.op == mul: not_is_1 = [i for i in r.owner.inputs if not _is_1(i)] diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index e8dce3e5ff..613a956384 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -4,9 +4,9 @@ import pytest import pytensor -import pytensor.scalar as aes -import pytensor.tensor as at +from pytensor import scalar as aes from pytensor import shared +from pytensor import tensor as at from pytensor.compile.function import function from pytensor.compile.mode import Mode, get_default_mode from pytensor.configdefaults import config @@ -263,9 +263,8 @@ def test_local_useless_dimshuffle_in_reshape(): class TestFusion: rewrites = RewriteDatabaseQuery( include=[ - "local_elemwise_fusion", - "composite_elemwise_fusion", "canonicalize", + "fusion", "inplace", ], exclude=["cxx_only", "BlasOpt"], @@ -1007,22 +1006,10 @@ def test_big_fusion(self): ) def test_add_mul_fusion_inplace(self): - - rewrites = RewriteDatabaseQuery( - include=[ - "local_elemwise_fusion", - "composite_elemwise_fusion", - "canonicalize", - "inplace", - ], - exclude=["cxx_only", "BlasOpt"], - ) - - mode = Mode(self.mode.linker, rewrites) - x, y, z = dmatrices("xyz") out = dot(x, y) + x + y + z - f = function([x, y, z], out, mode=mode) + + f = function([x, y, z], out, mode=self.mode) topo = [n for n in f.maker.fgraph.toposort()] assert len(topo) == 2 assert topo[-1].op.inplace_pattern @@ -1050,8 +1037,7 @@ def impl(self, x): mode = Mode(linker="cvm") mode._optimizer = mode._optimizer.including( - "local_elemwise_fusion", - "composite_elemwise_fusion", + "fusion", "canonicalize", "inplace", ) @@ -1073,18 +1059,6 @@ def test_test_values(self, test_value): are checked. """ - - rewrites = RewriteDatabaseQuery( - include=[ - "local_elemwise_fusion", - "composite_elemwise_fusion", - "canonicalize", - ], - exclude=["cxx_only", "BlasOpt"], - ) - - mode = Mode(self.mode.linker, rewrites) - x, y, z = dmatrices("xyz") x.tag.test_value = test_value @@ -1101,7 +1075,7 @@ def test_test_values(self, test_value): ): out = x * y + z with cm: - f = function([x, y, z], out, mode=mode) + f = function([x, y, z], out, mode=self.mode) if test_value.size != 0: # Confirm that the fusion happened diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 116c3e4ad0..850ba0db6b 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -16,7 +16,7 @@ from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.ops import DeepCopyOp, deep_copy_op from pytensor.configdefaults import config -from pytensor.graph.basic import Apply, Constant, equal_computations +from pytensor.graph.basic import Apply, equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import ( SequentialNodeRewriter, @@ -46,7 +46,6 @@ bitwise_or, bitwise_xor, conj, - cos, cosh, deg2rad, dot, @@ -59,14 +58,10 @@ ge, gt, int_div, - invert, - iround, le, log, log1mexp, log1p, - log2, - log10, lt, ) from pytensor.tensor.math import max as at_max @@ -74,11 +69,20 @@ from pytensor.tensor.math import min as at_min from pytensor.tensor.math import minimum, mul, neg, neq from pytensor.tensor.math import pow as at_pow -from pytensor.tensor.math import prod, rad2deg, reciprocal -from pytensor.tensor.math import round as at_round -from pytensor.tensor.math import sgn, sigmoid, sin, sinh, softplus, sqr, sqrt, sub +from pytensor.tensor.math import ( + prod, + rad2deg, + reciprocal, + sgn, + sigmoid, + sinh, + softplus, + sqr, + sqrt, + sub, +) from pytensor.tensor.math import sum as at_sum -from pytensor.tensor.math import tan, tanh, true_div, xor +from pytensor.tensor.math import tanh, true_div, xor from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift from pytensor.tensor.rewriting.math import ( compute_mul, @@ -102,7 +106,6 @@ dvector, fmatrices, fmatrix, - fscalar, ftensor4, fvector, imatrices, @@ -1072,745 +1075,6 @@ def test_cast_in_mul_canonizer(): f([1], [1]) -class TestFusion: - rewrites = RewriteDatabaseQuery( - include=[ - "local_elemwise_fusion", - "composite_elemwise_fusion", - "canonicalize", - "inplace", - ], - exclude=["cxx_only", "BlasOpt"], - ) - mode = Mode(get_default_mode().linker, rewrites) - _shared = staticmethod(shared) - topo_exclude = () - - def do(self, mode, shared_fn, shp, nb_repeat=1, assert_len_topo=True, slice=None): - """ - param shared_fn: if None, will use function - verify that the elemwise fusion work - Test with and without DimShuffle - """ - # TODO: disable the canonizer? - def my_init(shp, dtype="float64", num=0): - ret = np.zeros(shp, dtype=dtype) + num - return ret - - fw, fx, fy, fz = ( - tensor(dtype="float32", shape=(None,) * len(shp), name=n) for n in "wxyz" - ) - dw, dx, dy, dz = ( - tensor(dtype="float64", shape=(None,) * len(shp), name=n) for n in "wxyz" - ) - ix, iy, iz = ( - tensor(dtype="int32", shape=(None,) * len(shp), name=n) for n in "xyz" - ) - fv = fvector("v") - fs = fscalar("s") - - fwv = my_init(shp, "float32", 1) - fxv = my_init(shp, "float32", 2) - fyv = my_init(shp, "float32", 3) - fzv = my_init(shp, "float32", 4) - fvv = _asarray(np.random.random(shp[0]), dtype="float32") - fsv = np.asarray(np.random.random(), dtype="float32") - dwv = my_init(shp, "float64", 5) - ixv = _asarray(my_init(shp, num=60), dtype="int32") - iyv = _asarray(my_init(shp, num=70), dtype="int32") - izv = _asarray(my_init(shp, num=70), dtype="int32") - fwx = fw + fx - ftanx = tan(fx) - cases = [ - ( - fx + fy + fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + fzv, - "float32", - ), # 0 - ( - fx * fy * fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv * fyv * fzv, - "float32", - ), # 1 - ( - fx + fy * fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv * fzv, - "float32", - ), # 2 - ( - fx * fy + fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv * fyv + fzv, - "float32", - ), # 3 - ( - fw + fx + fy + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + fx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), # 5 - ( - ((fw + fx) + fy) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + (fx + fy)) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + (fx + fy) + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - fw + (fx + (fy + fz)), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + fx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), # 10 - ( - fw * fx * fy * fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv * fxv * fyv * fzv, - "float32", - ), - ( - fw + fx * fy * fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv * fyv * fzv, - "float32", - ), - ( - fx + fy * fz * fx, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv * fzv * fxv, - "float32", - ), - ( - fx * fy + fz + fy, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv * fyv + fzv + fyv, - "float32", - ), - ( - fx * fy * fz * fw + fx + fy + fz + fw, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fxv * fyv * fzv * fwv + fxv + fyv + fzv + fwv, - "float32", - ), # 15 - # test with constant - ( - (fw + fx) + (fy + fz) + 2.0, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - ((fw + fx) + 2.0 + fy) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - (fw + (fx + 2.0 + fy)) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - (fw + (fx + fy) + 2 + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - fw + (fx + (fy + fz) + 2.0), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), # 20 - ( - 2 + (fw + fx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - # mix float32 and float64 - ( - 2 + (dw + fx) + (fy + fz), - (dw, fx, fy, fz), - (dwv, fxv, fyv, fzv), - 1, - dwv + fxv + fyv + fzv + 2, - "float64", - ), - ( - 2 + (fw + dw) + (fy + fz), - (fw, dw, fy, fz), - (fwv, dwv, fyv, fzv), - 1, - fwv + dwv + fyv + fzv + 2, - "float64", - ), - ( - 2 + (fw + fx) + (dw + fz), - (fw, fx, dw, fz), - (fwv, fxv, dwv, fzv), - 1, - fwv + fxv + dwv + fzv + 2, - "float64", - ), - ( - 2 + (fw + fx) + (fy + dw), - (fw, fx, fy, dw), - (fwv, fxv, fyv, dwv), - 1, - fwv + fxv + fyv + dwv + 2, - "float64", - ), # 25 - # test when their is other op then elemwise. - ( - (fwx.sum()) + (fwx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 4, - (fwv + fxv).sum() + fwv + fxv + fyv + fzv, - "float32", - ), - # test other elemwise op - ( - fx + fy + cos(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.cos(fzv), - "float32", - ), - ( - fx + fy + cosh(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.cosh(fzv), - "float32", - ), - ( - fx + fy + abs(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.absolute(fzv), - "float32", - ), - ( - ix + iy + abs(iz), - (ix, iy, iz), - (ixv, iyv, izv), - 1, - ixv + iyv + np.absolute(izv), - "int32", - ), # 30 - ( - fx + fy + log(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.log(fzv), - "float32", - ), - ( - fx + fy + log2(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.log2(fzv), - "float32", - ), - ( - fx + fy + log10(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.log10(fzv), - "float32", - ), - ( - fx + fy**fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv**fzv, - "float32", - ), # pow - ( - fx + fy + exp(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.exp(fzv), - "float32", - ), # 35 - ( - fx - fy - fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv - fzv, - "float32", - ), - ( - fx - (fy / fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv / fzv), - "float32", - ), - ( - fx - true_div(fy, 2), - (fx, fy), - (fxv, fyv), - 1, - fxv - (fyv / 2), - "float32", - ), - ( - fx - true_div(fy, fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv / fzv), - "float32", - ), - ( - fx - int_div(ix * 100, iy * 1000), - (fx, ix, iy), - (fxv, ixv, iyv), - 1, - fxv - ((ixv * 100) // (iyv * 1000)), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), # 40 - (fx - (fy / 2), (fx, fy), (fxv, fyv), 1, fxv - (fyv / 2), "float32"), - ( - fx - (fy % fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv % fzv), - "float32", - ), - ( - fx - (fy > fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv > fzv), - "float32", - ), - ( - fx - (fy >= fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv >= fzv), - "float32", - ), - ( - fx - (fy < fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv < fzv), - "float32", - ), # 45 - ( - fx - (fy <= fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv <= fzv), - "float32", - ), - ( - fx - eq(fy, fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv == fzv), - "float32", - ), - ( - fx - neq(fy, fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv != fzv), - "float32", - ), - ( - fx - fy + tan(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.tan(fzv), - "float32", - ), - ( - fx - fy + tanh(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.tanh(fzv), - "float32", - ), # 50 - ( - fx - fy + sin(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.sin(fzv), - "float32", - ), - ( - fx - fy + sinh(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.sinh(fzv), - "float32", - ), - ( - fx - fy + sqr(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + (fzv * fzv), - "float32", - ), - ( - fx - fy + sqrt(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.sqrt(fzv), - "float32", - ), - ( - fx - fy + reciprocal(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + (1 / fzv), - "float32", - ), # 55 - ( - fx - fy + neg(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + (-fzv), - "float32", - ), - ( - fx - fy + at_round(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.round(fzv), - "float32", - ), - ( - ix - iy + iround(fz), - (ix, iy, fz), - (ixv, iyv, fzv), - 1, - ixv - iyv + np.round(fzv), - "int64", - ), - # Bit op - ( - fx - bitwise_or(iy, iz), - (fx, iy, iz), - (fxv, iyv, izv), - 1, - fxv - (iyv | izv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), - ( - fx - xor(iy, iz), - (fx, iy, iz), - (fxv, iyv, izv), - 1, - fxv - (iyv ^ izv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), # 60 - ( - fx - bitwise_and(iy, iz), - (fx, iy, iz), - (fxv, iyv, izv), - 1, - fxv - (iyv & izv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), - ( - fx - invert(iy), - (fx, iy), - (fxv, iyv), - 1, - fxv - (~iyv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), - ( - fx - at.cast(fy, dtype="float64"), - (fx, fy), - (fxv, fyv), - 1, - fxv - np.asarray(fyv, "float64"), - "float64", - ), - ( - at_pow(fx * fy + fz, fx * fy), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - np.power(fxv * fyv + fzv, fxv * fyv), - "float32", - ), - ( - fv + fy**fz, - (fv, fy, fz), - (fvv, fyv, fzv), - 2, - fvv + fyv**fzv, - "float32", - ), # fused with a dimshuffle #65 - ( - fv - fy + tanh(fz), - (fv, fy, fz), - (fvv, fyv, fzv), - 2, - fvv - fyv + np.tanh(fzv), - "float32", - ), # fused with a dimshuffle - # Cases where the same input is reused many times. - ( - mul(fx, fx, fx, fx), - (fx,), - (fxv,), - 1, - fxv * fxv * fxv * fxv, - "float32", - ), - ( - mul(fx, ftanx, ftanx), - (fx,), - (fxv,), - 1, - fxv * np.tan(fxv) * np.tan(fxv), - "float32", - ), - ( - mul(fx, ftanx, ftanx, fx), - (fx,), - (fxv,), - 1, - fxv * np.tan(fxv) * np.tan(fxv) * fxv, - "float32", - ), - ( - mul(ftanx, ftanx, fx + fy), - (fx, fy), - (fxv, fyv), - 1, - np.tan(fxv) * np.tan(fxv) * (fxv + fyv), - "float32", - ), # 70 - # Cases with different broadcast pattern. They should not - # be merged as this would duplicate computation - # The graph should have 2 elemwise and 1 dimshuffle - ( - fx * sin(fs), - (fx, fs), - (fxv, fsv), - 3, - fxv * np.sin(fsv), - "float32", - ), - ] - if slice: - cases = cases[slice] - times = np.zeros(len(cases)) - fail1 = [] - fail2 = [] - fail3 = [] - fail4 = [] - for ( - id, - [g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype], - ) in enumerate(cases): - if isinstance(out_dtype, dict): - out_dtype = out_dtype[config.cast_policy] - - if shared_fn is None: - f = function(list(sym_inputs), g, mode=mode) - for x in range(nb_repeat): - out = f(*val_inputs) - t1 = time.perf_counter() - else: - out = shared_fn(np.zeros(shp, dtype=out_dtype), "out") - assert out.dtype == g.dtype - f = function(sym_inputs, [], updates=[(out, g)], mode=mode) - t0 = time.perf_counter() - for x in range(nb_repeat): - f(*val_inputs) - t1 = time.perf_counter() - out = out.get_value() - - times[id] = t1 - t0 - atol = 1e-8 - if out_dtype == "float32": - atol = 1e-6 - if not np.allclose(out, answer * nb_repeat, atol=atol): - fail1.append(id) - topo = f.maker.fgraph.toposort() - topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)] - if assert_len_topo: - if len(topo_) != nb_elemwise: - fail3.append((id, topo_, nb_elemwise)) - if nb_elemwise == 1: - # if no variable appears multiple times in the - # input of g, - # check that the number of input to the Composite - # Elemwise is ok - if len(set(g.owner.inputs)) == len(g.owner.inputs): - expected_len_sym_inputs = sum( - not isinstance(x, Constant) for x in topo_[0].inputs - ) - assert expected_len_sym_inputs == len(sym_inputs) - - if out_dtype != out.dtype: - fail4.append((id, out_dtype, out.dtype)) - - assert len(fail1 + fail2 + fail3 + fail4) == 0 - - return times - - def test_add_mul_fusion_inplace(self): - - rewrites_query = RewriteDatabaseQuery( - include=[ - "local_elemwise_fusion", - "composite_elemwise_fusion", - "canonicalize", - "inplace", - ], - exclude=["cxx_only", "BlasOpt"], - ) - - mode = Mode(self.mode.linker, rewrites_query) - - x, y, z = dmatrices("xyz") - out = dot(x, y) + x + y + z - f = function([x, y, z], out, mode=mode) - topo = [n for n in f.maker.fgraph.toposort()] - assert len(topo) == 2 - assert topo[-1].op.inplace_pattern - - new_out = f.maker.fgraph.outputs[0] - assert isinstance(new_out.owner.op, Elemwise) - assert isinstance(new_out.owner.op.scalar_op, aes.basic.Add) - assert len(new_out.owner.inputs) == 4 - - # TODO: Do we really need to do this? - _ = f( - np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5)) - ) - - @utt.assertFailure_fast def test_log1p(): m = config.mode From 00b5b902d28e82243094e202ee7adbbcaa0e03db Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 7 Oct 2022 10:50:03 +0200 Subject: [PATCH 03/13] Add direct test for nested broadcasted Composite graphs --- tests/tensor/rewriting/test_elemwise.py | 31 +++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 613a956384..943aac9970 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1167,6 +1167,37 @@ def test_CAReduce_multiple_inputs(self, linker, axis): assert out_val.shape == exp_res.shape assert np.allclose(out_val, exp_res) + def test_not_fusing_broadcasted_subgraphs(self): + """Test that broadcasted Elemwise subgraphs are not fused in a single Elemwise Composite Op. + + There are some cases in self.test_elemwise_fusion, but this test confirms that the + fused subgraphs are exactly the expected ones. + """ + xs = vector("xm") + xm = matrix("xs") + + es = log(xs + 5) + em = exp(xm * 5) + esm = es - em + + f = pytensor.function([xs, xm], esm, mode=self.mode) + apply_nodes = f.maker.fgraph.toposort() + assert len(apply_nodes) == 3 + assert isinstance(apply_nodes[0].op, DimShuffle) + # Inner Vector output Composite + assert isinstance(apply_nodes[1].op.scalar_op, Composite) + assert {node.op for node in apply_nodes[1].op.scalar_op.fgraph.apply_nodes} == { + aes.add, + aes.log, + } + # Outer Matrix output Composite + assert isinstance(apply_nodes[2].op.scalar_op, Composite) + assert {node.op for node in apply_nodes[2].op.scalar_op.fgraph.apply_nodes} == { + aes.sub, + aes.exp, + aes.mul, + } + class TimesN(aes.basic.UnaryScalarOp): """ From da87b0cc93b1bbe5ce4156691305b8d016a7dcb4 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 27 Oct 2022 20:28:40 +0200 Subject: [PATCH 04/13] Fix bug in Composite when multiple outputs are identical --- pytensor/scalar/basic.py | 15 +++++++++++++++ tests/scalar/test_basic.py | 11 +++++++++++ 2 files changed, 26 insertions(+) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index f428f2528b..fb2656e50e 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -4146,6 +4146,21 @@ def fgraph(self): "The fgraph to Composite must be exclusively" " composed of ScalarOp instances." ) + + # Clone identical outputs that have been merged + if len(set(fgraph.outputs)) != len(self.outputs): + old_outputs = fgraph.outputs + new_outputs = [] + for output in old_outputs: + if output not in new_outputs: + new_outputs.append(output) + else: + node = output.owner + output_idx = node.outputs.index(output) + new_output = node.clone().outputs[output_idx] + new_outputs.append(new_output) + fgraph = FunctionGraph(fgraph.inputs, new_outputs, clone=False) + self._fgraph = fgraph return self._fgraph diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index 02dc2001f8..6e694f5be6 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -156,6 +156,17 @@ def test_many_outputs(self): fn = make_function(DualLinker().accept(g)) assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5] + def test_identical_outputs(self): + x, y, z = floats("xyz") + e0 = x + y + z + e1 = x + y + z + e2 = x / y + C = Composite([x, y, z], [e0, e1, e2]) + c = C.make_node(x, y, z) + g = FunctionGraph([x, y, z], c.outputs) + fn = make_function(DualLinker().accept(g)) + assert fn(1.0, 2.0, 3.0) == [6.0, 6.0, 0.5] + def test_composite_printing(self): x, y, z = floats("xyz") e0 = x + y + z From 1f490944144765acbf5b7782341b48a308d60f07 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 11 Oct 2022 14:40:41 +0200 Subject: [PATCH 05/13] Disable invalid inplace logic for multiple-output Composites --- pytensor/tensor/rewriting/elemwise.py | 12 ++++++++-- tests/tensor/rewriting/test_elemwise.py | 29 +++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 47f2de28d7..6e7a1763da 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -59,6 +59,14 @@ def print_profile(cls, stream, prof, level=0): for n in sorted(ndim.keys()): print(blanc, n, ndim[n], file=stream) + def candidate_input_idxs(self, node): + if isinstance(node.op.scalar_op, aes.Composite) and len(node.outputs) > 1: + # TODO: Implement specialized InplaceCompositeOptimizer with logic + # needed to correctly assign inplace for multi-output Composites + return [] + else: + return range(len(node.outputs)) + def apply(self, fgraph): r""" @@ -149,7 +157,7 @@ def apply(self, fgraph): baseline = op.inplace_pattern candidate_outputs = [ - i for i in range(len(node.outputs)) if i not in baseline + i for i in self.candidate_input_idxs(node) if i not in baseline ] # node inputs that are Constant, already destroyed, # or fgraph protected inputs and fgraph outputs can't be used as @@ -167,7 +175,7 @@ def apply(self, fgraph): ] else: baseline = [] - candidate_outputs = list(range(len(node.outputs))) + candidate_outputs = self.candidate_input_idxs(node) # node inputs that are Constant, already destroyed, # fgraph protected inputs and fgraph outputs can't be used as inplace # target. diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 943aac9970..f2955e1a26 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -4,6 +4,7 @@ import pytest import pytensor +from pytensor import In from pytensor import scalar as aes from pytensor import shared from pytensor import tensor as at @@ -1024,6 +1025,34 @@ def test_add_mul_fusion_inplace(self): np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5)) ) + def test_fusion_multiout_inplace(self): + x = vector("x") + + # Create Composite where inplacing the first non-constant output would corrupt the second output + xs = aes.float64("xs") + outs = ( + Elemwise(Composite([xs], [xs + 1, aes.cos(xs + 1) + xs])) + .make_node(x) + .outputs + ) + + f = pytensor.function( + [In(x, mutable=True)], + outs, + mode=self.mode.including("inplace"), + ) + (composite_node,) = f.maker.fgraph.apply_nodes + + # Destroy map must be None or the last toposorted output + destroy_map = composite_node.op.destroy_map + assert (destroy_map == {}) or ( + destroy_map == {1: [composite_node.inputs.index(x)]} + ) + + res = f([0, 1, 2]) + assert np.allclose(res[0], [1, 2, 3]) + assert np.allclose(res[1], np.cos([1, 2, 3]) + np.array([0, 1, 2])) + @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") def test_no_c_code(self): r"""Make sure we avoid fusions for `Op`\s without C code implementations.""" From dd4509909e6cf421de616efa1fc01c99c1f7ce9e Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 7 Feb 2023 17:35:37 +0100 Subject: [PATCH 06/13] Exclude unnecessary inputs in useless_composite rewrite --- pytensor/tensor/rewriting/elemwise.py | 30 ++++++++++++++++--------- tests/tensor/rewriting/test_elemwise.py | 29 ++++++++++++++++++------ 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 6e7a1763da..de223c895b 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -990,23 +990,33 @@ def print_profile(cls, stream, prof, level=0): @register_canonicalize +@register_specialize @node_rewriter([Elemwise]) def local_useless_composite(fgraph, node): - """For elemwise Composite that have multiple outputs, remove the - outputs that are not used. - - """ + """Remove inputs and outputs of Composite Ops that are not used anywhere.""" if not isinstance(node.op, Elemwise) or not isinstance( node.op.scalar_op, aes.Composite ): return comp = node.op.scalar_op - idx = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]] - if len(idx) < len(node.outputs): - new_outputs = [comp.outputs[i] for i in idx] - c = aes.Composite(inputs=comp.inputs, outputs=new_outputs) - e = Elemwise(scalar_op=c)(*node.inputs, return_list=True) - return dict(zip([node.outputs[i] for i in idx], e)) + used_outputs_idxs = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]] + used_inner_outputs = [comp.outputs[i] for i in used_outputs_idxs] + comp_fgraph = FunctionGraph( + inputs=comp.inputs, outputs=used_inner_outputs, clone=False + ) + used_inputs_idxs = [ + i + for i, i_intern in enumerate(comp_fgraph.inputs) + if comp_fgraph.clients[i_intern] + ] + used_inner_inputs = [comp.inputs[i] for i in used_inputs_idxs] + if len(used_inner_inputs) < len(node.inputs) or len(used_inner_outputs) < len( + node.outputs + ): + used_inputs = [node.inputs[i] for i in used_inputs_idxs] + c = aes.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs) + e = Elemwise(scalar_op=c)(*used_inputs, return_list=True) + return dict(zip([node.outputs[i] for i in used_outputs_idxs], e)) @node_rewriter([CAReduce]) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index f2955e1a26..899d056588 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1292,22 +1292,37 @@ def test_nested_composite(self): def test_local_useless_composite(self): x = aes.float32() - c = aes.Composite([x], [x + 1, x - 1]) - X = matrix() - o = Elemwise(scalar_op=c)(X) + y = aes.float32() + z = aes.float32() + c = aes.Composite([x, y, z], [x + 1, y - 1]) + X = matrix("X") + Y = matrix("Y") + Z = matrix("Z") + o1, o2 = Elemwise(scalar_op=c)(X, Y, Z) mode = get_default_mode().including("local_useless_composite") - f = function([X], o[0], mode=mode) + f = function([X, Y, Z], [o1, o2], mode=mode) topo = f.maker.fgraph.toposort() assert len(topo) == 1 + assert len(topo[0].inputs) == 2 + assert len(topo[0].outputs) == 2 + res1, res2 = f([[1.0]], [[1.0]], [[np.nan]]) + utt.assert_allclose(res1, [[2.0]]) + utt.assert_allclose(res2, [[0.0]]) + + f = function([X, Y, Z], o1, mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert len(topo[0].inputs) == 1 assert len(topo[0].outputs) == 1 - utt.assert_allclose(f([[1.0]]), [[2.0]]) + utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]]) - f = function([X], o[1], mode=mode) + f = function([X, Y, Z], o2, mode=mode) topo = f.maker.fgraph.toposort() assert len(topo) == 1 + assert len(topo[0].inputs) == 1 assert len(topo[0].outputs) == 1 - utt.assert_allclose(f([[1.0]]), [[0.0]]) + utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]]) def test_local_useless_dimshuffle_makevector(): From 3304baf9c1055293f8b792f87a086fd4d9f4e628 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 6 Oct 2022 20:05:02 +0200 Subject: [PATCH 07/13] Fuse consecutive Elemwise nodes with multiple clients --- pytensor/tensor/elemwise.py | 22 +- pytensor/tensor/rewriting/elemwise.py | 749 +++++++++++++---------- scripts/mypy-failing.txt | 1 - tests/compile/function/test_pfunc.py | 9 +- tests/tensor/rewriting/test_elemwise.py | 185 ++++-- tests/tensor/rewriting/test_subtensor.py | 5 +- tests/test_printing.py | 4 +- 7 files changed, 596 insertions(+), 379 deletions(-) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 4ac1ffdd33..17d4c5776b 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -652,10 +652,10 @@ def transform(r): def prepare_node(self, node, storage_map, compute_map, impl): # Postpone the ufunc building to the last minutes due to: - # - NumPy ufunc support only up to 31 inputs. + # - NumPy ufunc support only up to 32 operands (inputs and outputs) # But our c code support more. # - nfunc is reused for scipy and scipy is optional - if len(node.inputs) > 32 and self.ufunc and impl == "py": + if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py": impl = "c" if getattr(self, "nfunc_spec", None) and impl != "c": @@ -677,7 +677,7 @@ def prepare_node(self, node, storage_map, compute_map, impl): self.nfunc = module if ( - len(node.inputs) < 32 + (len(node.inputs) + len(node.outputs)) <= 32 and (self.nfunc is None or self.scalar_op.nin != len(node.inputs)) and self.ufunc is None and impl == "py" @@ -727,28 +727,18 @@ def prepare_node(self, node, storage_map, compute_map, impl): self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl) def perform(self, node, inputs, output_storage): - if len(node.inputs) >= 32: + if (len(node.inputs) + len(node.outputs)) > 32: # Some versions of NumPy will segfault, other will raise a - # ValueError, if the number of inputs to a ufunc is 32 or more. + # ValueError, if the number of operands in an ufunc is more than 32. # In that case, the C version should be used, or Elemwise fusion # should be disabled. + # FIXME: This no longer calls the C implementation! super().perform(node, inputs, output_storage) for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))): if len(set(dim_shapes) - {1}) > 1: raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}") - # Determine the shape of outputs - out_shape = [] - for values in zip(*[input.shape for input in inputs]): - if any(v == 0 for v in values): - # All non-broadcasted dimensions should be zero - assert max(values) <= 1 - out_shape.append(0) - else: - out_shape.append(max(values)) - out_shape = tuple(out_shape) - ufunc_args = inputs ufunc_kwargs = {} # We supported in the past calling manually op.perform. diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index de223c895b..19c5eabd03 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -1,17 +1,18 @@ import sys -import time -from collections import defaultdict -from typing import Optional +from collections import defaultdict, deque +from functools import lru_cache +from typing import DefaultDict, Generator, List, Set, Tuple, TypeVar from warnings import warn import pytensor import pytensor.scalar.basic as aes -from pytensor import compile +from pytensor import clone_replace, compile from pytensor.compile.mode import get_target_language from pytensor.configdefaults import config -from pytensor.graph.basic import Apply, Constant, io_toposort +from pytensor.graph import FunctionGraph +from pytensor.graph.basic import Apply, Constant, Variable, ancestors, io_toposort from pytensor.graph.features import ReplaceValidate -from pytensor.graph.op import compute_test_value, get_test_value +from pytensor.graph.fg import ApplyOrOutput from pytensor.graph.rewriting.basic import ( EquilibriumGraphRewriter, GraphRewriter, @@ -20,7 +21,7 @@ node_rewriter, ) from pytensor.graph.rewriting.db import SequenceDB -from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError +from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError @@ -592,333 +593,438 @@ def local_add_mul_fusion(fgraph, node): return [output] -def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None): - r"""Create a recursive function that fuses `Elemwise` `Op`\s. - - The basic idea is that we loop through an `Elemwise` node's inputs, find - other `Elemwise` nodes, determine the scalars input types for all of the - `Elemwise` `Op`\s, construct a new scalar `Op` using the scalar input types - and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a - new "fused" `Elemwise`. - - It's parameterized in order to work for `Elemwise` `Op`\s. - - Parameters - ---------- - op_class : type - `Elemwise` class (the one that we want to fuse) - max_input_fct : callable - A function that returns the maximum number of inputs that this `Elemwise` - can take. - On the CPU we limit to 32 input variables since that is the maximum - NumPy support. - - maker: callable - A function with the signature ``(node, *args)`` that constructs an - `op_class` instance (e.g. ``op_class(*args)``). - - """ - if maker is None: - - def maker(node, scalar_op): - return op_class(scalar_op) - - def local_fuse(fgraph, node): - r"""Fuse `Elemwise` `Op`\s in a node. - - As part of specialization, we fuse two consecutive `Elemwise` `Op`\s of the - same shape. +def elemwise_max_operands_fct(node) -> int: + # `Elemwise.perform` uses NumPy ufuncs and they are limited to 32 operands (inputs and outputs) + if not config.cxx: + return 32 + return 1024 - For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C - compiler do the cast. - The number of dimensions is validated at call time by PyTensor itself. +class FusionOptimizer(GraphRewriter): + """Graph optimizer that fuses consecutive Elemwise operations.""" - """ - # TODO: use broadcast flag? + def add_requirements(self, fgraph): + fgraph.attach_feature(ReplaceValidate()) - # TODO: don't do this rewrite as a `NodeRewriter`. - # Analyze the graph in terms of elemwise subgraphs, and then - # replace each subgraph with a Composite version. + @staticmethod + def elemwise_to_scalar(inputs, outputs): + replace_inputs = [(inp, inp.clone()) for inp in inputs] + outputs = clone_replace(outputs, replace=replace_inputs) - # TODO: use malloc and copy to transfer arguments that don't - # fit within the parameter space of 256 bytes - # - # TODO: Merge with multiple output to merge when an inputs - # have multiple clients. This can't be done with a `NodeRewriter` - - # TODO: Related: Support composites with multiple outputs - - # TODO: Use Composite to combine Elemwise and Reduce - # operations. We have to loop over the data anyway... might - # as well sum it up while we're at it (this can be trickier - # than i'm making it seound here. The data-traversal should be - # done contiguously, and the summing-up might not be easy or - # worthwhile if the summation axis doesn't line up with a - # contiguous dimension) - - if type(node.op) is not op_class: - return False - - if len(node.outputs) > 1: - # We don't support fusion for nodes with multiple outputs. - return - - inputs = [] # inputs of the new Elemwise op. - s_inputs = [] # inputs of the new scalar op used by the Composite. - # Inputs of the new scalar op that represents the current node. - s_g = [] - - # There is a hard limit of 256 bytes for the formal argument list to a - # GPU kernel function. - max_nb_input = max_input_fct(node) - # The number of inputs to the new fused op if we do not fuse more - # inputs. - new_nb_input = len(node.inputs) - # Did we fuse something? - # Needed as we can fuse unary op that don't change the number of - # inputs. - # And there is a case where the inputs are the same as the current - # node. That won't change the number of inputs of the new op. - fused = False - - for i in node.inputs: - scalar_node: Optional[Apply] = None - # Will store inputs of the fused node that are not currently inputs - # of the node we want to create (to avoid duplicating inputs). - tmp_input = [] - # Same as tmp_input, but for scalars. - tmp_scalar = [] - - # We should not check the number of inputs here - # As fusing op don't always change the number of input. - # If a variable is used as multiple into to the same node, - # we still want to fusion. So we take the set. - if ( - i.owner - and isinstance(i.owner.op, op_class) - and len({n for n, idx in fgraph.clients[i]}) == 1 - and - # Do not merge elemwise that don't have the same - # broadcastable pattern to don't redo duplicate - # computation due to broadcast. - i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable - ): - try: - tmp_s_input = [] - # we should not put duplicate input into s_inputs and inputs - for ii in i.owner.inputs: - if ii in inputs: - tmp_s_input.append(s_inputs[inputs.index(ii)]) - elif ii in tmp_input: - tmp_s_input.append(tmp_scalar[tmp_input.index(ii)]) - else: - tmp = aes.get_scalar_type(ii.type.dtype).make_variable() - - try: - tv = get_test_value(ii) - # Sometimes the original inputs have - # zero-valued shapes in some dimensions, which - # implies that this whole scalar thing doesn't - # make sense (i.e. we're asking for the scalar - # value of an entry in a zero-dimensional - # array). - # This will eventually lead to an error in the - # `compute_test_value` call below when/if - # `config.compute_test_value_opt` is enabled - # (for debugging, more or less) - tmp.tag.test_value = tv.item() - except (TestValueError, ValueError): - pass - - tmp_s_input.append(tmp) - tmp_input.append(ii) - tmp_scalar.append(tmp_s_input[-1]) - - # Use the `Op.make_node` interface in case `Op.__call__` - # has been customized - scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input) - - if config.compute_test_value_opt != "off": - # This is required because `Op.make_node` won't do it - compute_test_value(scalar_node) - - # If the scalar_op doesn't have a C implementation, we skip - # its fusion to allow fusion of the other ops - i.owner.op.scalar_op.c_code( - scalar_node, - "test_presence_of_c_code", - ["x" for x in i.owner.inputs], - ["z" for z in i.owner.outputs], - {"fail": "%(fail)s"}, - ) + inputs = [inp for _, inp in replace_inputs] + fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False) + middle_inputs = [] - except (NotImplementedError, MethodNotDefined): - warn( - "Rewrite warning: " - f"The Op {i.owner.op.scalar_op} does not provide a C implementation." - " As well as being potentially slow, this also disables " - "loop fusion." + scalar_inputs = [ + aes.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs + ] + middle_scalar_inputs = [] + + for node in fg.toposort(): + node_scalar_inputs = [] + for inp in node.inputs: + if inp in inputs: + node_scalar_inputs.append(scalar_inputs[inputs.index(inp)]) + elif inp in middle_inputs: + node_scalar_inputs.append( + middle_scalar_inputs[middle_inputs.index(inp)] ) - scalar_node = None - - # Compute the number of inputs in case we fuse this input. - # We subtract 1 because we replace the existing input with the new - # inputs from `tmp_input`. - new_nb_input_ = new_nb_input + len(tmp_input) - 1 - - # If the new input is already an input of the current node, it was - # already counted when `new_nb_input` was initialized to - # len(node.inputs). - # This can happen when a variable is used both by the Elemwise to - # fuse and the current node. - for x in tmp_input: - if x in node.inputs: - new_nb_input_ -= 1 - - if scalar_node and (new_nb_input_ <= max_nb_input): - fused = True - new_nb_input = new_nb_input_ - inputs.extend(tmp_input) - s_inputs.extend(tmp_scalar) - s_g.extend(scalar_node.outputs) - else: - # We must support the case where the same variable appears many - # times within the inputs - if inputs.count(i) == node.inputs.count(i): - s = s_inputs[inputs.index(i)] else: - s = aes.get_scalar_type(i.type.dtype).make_variable() - if config.compute_test_value_opt != "off": - try: - v = get_test_value(i) - # See the zero-dimensional test value situation - # described above. - s.tag.test_value = v.item() - except (TestValueError, ValueError): - pass - - inputs.append(i) - s_inputs.append(s) - s_g.append(s) - - if not fused: - return False - - if new_nb_input != len(inputs) or len(s_inputs) != len(inputs): - # TODO FIXME: This shouldn't be a generic `Exception` - raise Exception( - "Something has gone wrong with the elemwise fusion rewrite; skipping." - ) - - s_new_out = node.op.scalar_op(*s_g, return_list=True) - try: - s_new_out[0].owner.op.c_code( - s_new_out[0].owner, - "test_presence_of_c_code", - ["x" for x in s_g], - ["z" for x in s_new_out], - {"fail": "%(fail)s"}, - ) - except (NotImplementedError, MethodNotDefined): - name = str(s_new_out[0].owner.op) - warn( - "Rewrite warning: " - f"The Op {name} does not provide a C implementation." - " As well as being potentially slow, this also disables " - "loop fusion." - ) - return False - - # create the composite op. - composite_op = aes.Composite(s_inputs, s_new_out) - - # create the new node. - # Do not call make_node to have test_value - new_node = maker(node, composite_op)(*inputs).owner - - assert len(new_node.outputs) == 1 - assert node.outputs[0].type.dtype == new_node.outputs[0].type.dtype + new_scalar_input = aes.get_scalar_type( + inp.type.dtype + ).make_variable() + node_scalar_inputs.append(new_scalar_input) + middle_scalar_inputs.append(new_scalar_input) + middle_inputs.append(inp) + + new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs) + middle_scalar_inputs.append(new_scalar_node.outputs[0]) + middle_inputs.append(node.outputs[0]) + + scalar_outputs = [ + middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs + ] + return scalar_inputs, scalar_outputs - if len(new_node.inputs) > max_nb_input: - warn( - "Loop fusion failed because the resulting node " - "would exceed the kernel argument limit." - ) - return False - - # we fuse as many that we can at the same time to make debug mode faster - # debug mode will be faster as it won't test all intermediate step. - while True: - ret = local_fuse(fgraph, new_node) - if ret is not False and ret is not None: - assert len(ret) == len(new_node.outputs) - assert len(ret) == 1 - new_node = ret[0].owner - else: - break + def apply(self, fgraph): + nb_replacement = 0 - return new_node.outputs + if fgraph.profile: + validate_before = fgraph.profile.validate_time + callbacks_before = fgraph.execute_callbacks_times.copy() + callback_before = fgraph.execute_callbacks_time - return local_fuse + max_operands = elemwise_max_operands_fct(None) + + def find_next_fuseable_subgraph( + fg: FunctionGraph, + ) -> Generator[Tuple[List[Variable], List[Variable]], None, None]: + """Find all subgraphs in a FunctionGraph that can be fused together + + Yields + ------- + List of inputs and outputs that determine subgraphs which can be fused. + This generator assumes that such subgraph is replaced by a single + Elemwise Composite before being accessed again in the next iteration. + """ + + FUSEABLE_MAPPING = DefaultDict[Variable, List[Apply]] + UNFUSEABLE_MAPPING = DefaultDict[Variable, Set[ApplyOrOutput]] + + def initialize_fuseable_mappings( + *, fg: FunctionGraph + ) -> Tuple[FUSEABLE_MAPPING, UNFUSEABLE_MAPPING]: + @lru_cache(maxsize=None) + def elemwise_scalar_op_has_c_code(node: Apply) -> bool: + # TODO: This should not play a role in non-c backends! + if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): + return True + else: + warn( + "Optimization Warning: " + f"The Op {node.op.scalar_op} does not provide a C implementation." + " As well as being potentially slow, this also disables " + "loop fusion." + ) + return False + + # Fuseable nodes have to be accessed in a deterministic manner + # to ensure the rewrite remains deterministic. + # This is not a problem from unfuseable ones, as they can never + # become part of the graph. + fuseable_clients: FUSEABLE_MAPPING = defaultdict(list) + unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set) + for out, clients in fg.clients.items(): + out_maybe_fuseable = ( + out.owner + and isinstance(out.owner.op, Elemwise) + # and not isinstance(out.owner.op.scalar_op, aes.Composite) + and len(out.owner.outputs) == 1 + and elemwise_scalar_op_has_c_code(out.owner) + ) + for client, _ in clients: + if ( + out_maybe_fuseable + and not isinstance(client, str) # "output" + and isinstance(client.op, Elemwise) + # and not isinstance(client.op.scalar_op, aes.Composite) + and len(client.outputs) == 1 + and out.type.broadcastable + == client.outputs[0].type.broadcastable + and elemwise_scalar_op_has_c_code(client) + ): + if client not in fuseable_clients[out]: + fuseable_clients[out].append(client) + else: + unfuseable_clients[out].add(client) + + return fuseable_clients, unfuseable_clients + + def find_fuseable_subgraph( + *, + fg: FunctionGraph, + visited_nodes: Set[Apply], + fuseable_clients: FUSEABLE_MAPPING, + unfuseable_clients: UNFUSEABLE_MAPPING, + ) -> Tuple[List[Variable], List[Variable]]: + + KT = TypeVar("KT") + VT = TypeVar("VT", list, set) + + def shallow_clone_defaultdict( + d: DefaultDict[KT, VT] + ) -> DefaultDict[KT, VT]: + new_dict: DefaultDict[KT, VT] = defaultdict(d.default_factory) + new_dict.update({k: v.copy() for k, v in d.items()}) + return new_dict + + def variables_depend_on( + variables, depend_on, stop_search_at=None + ) -> bool: + return any( + a in depend_on + for a in ancestors(variables, blockers=stop_search_at) + ) + toposort = fg.toposort() + for starting_node in toposort: + if starting_node in visited_nodes: + continue -def elemwise_max_input_fct(node): - # `Elemwise.perform` uses NumPy ufuncs and they are limited to 31 inputs. - if not config.cxx: - return 31 - return 1024 + starting_out = starting_node.outputs[0] + if not fuseable_clients.get(starting_out): + visited_nodes.add(starting_node) + continue + subgraph_inputs: List[Variable] = [] + subgraph_outputs: List[Variable] = [] + unfuseable_clients_subgraph: Set[Variable] = set() -local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct) + # Shallow cloning of maps so that they can be manipulated in place + fuseable_clients_temp = shallow_clone_defaultdict(fuseable_clients) + unfuseable_clients_clone = shallow_clone_defaultdict( + unfuseable_clients + ) + fuseable_nodes_to_visit = deque([starting_node]) + + # We now try to expand as much as possible towards the potentially + # fuseable clients and ancestors to detect the largest possible + # subgraph that can be Composed together into a single `Op`. The + # largest issue to watch out is for cyclical dependencies, where + # some inputs or clients may depend on other nodes of the same + # subgraph via a path that cannot be included in the Composite + # (unfuseable) + while fuseable_nodes_to_visit: + next_node = fuseable_nodes_to_visit.popleft() + visited_nodes.add(next_node) + next_out = next_node.outputs[0] + + # If the output variable of next_node has no fuseable clients + # or has unfuseable clients, then next_node must become an output + # if it is to be fused. + must_become_output = ( + next_out not in fuseable_clients_temp + or next_out in unfuseable_clients_clone + ) -class FusionOptimizer(GraphRewriter): - """Graph rewriter that simply runs node fusion operations. + # We have backtracked to this node, and it may no longer be a viable output, + # so we remove it and check again as if we had never seen this node + if must_become_output and next_out in subgraph_outputs: + subgraph_outputs.remove(next_out) + + required_unfuseable_inputs = [ + inp + for inp in next_node.inputs + if next_node in unfuseable_clients_clone.get(inp, ()) + ] + new_required_unfuseable_inputs = [ + inp + for inp in required_unfuseable_inputs + if inp not in subgraph_inputs + ] + + must_backtrack = False + if new_required_unfuseable_inputs and subgraph_outputs: + # We need to check that any new inputs required by this node + # do not depend on other outputs of the current subgraph, + # via an unfuseable path. + if variables_depend_on( + [next_out], + depend_on=unfuseable_clients_subgraph, + stop_search_at=subgraph_outputs, + ): + must_backtrack = True + + if not must_backtrack: + implied_unfuseable_clients = { + c + for client in unfuseable_clients_clone.get(next_out, ()) + if not isinstance(client, str) # "output" + for c in client.outputs + } + + new_implied_unfuseable_clients = ( + implied_unfuseable_clients - unfuseable_clients_subgraph + ) - TODO: This is basically an `EquilibriumGraphRewriter`; we should just use that. + if new_implied_unfuseable_clients and subgraph_inputs: + # We need to check that any inputs of the current subgraph + # do not depend on other clients of this node, + # via an unfuseable path. + if variables_depend_on( + subgraph_inputs, + depend_on=new_implied_unfuseable_clients, + ): + must_backtrack = True + + if must_backtrack: + for inp in next_node.inputs: + if ( + inp.owner in visited_nodes + # next_node could have the same input repeated + and next_node in fuseable_clients_temp[inp] + ): + fuseable_clients_temp[inp].remove(next_node) + unfuseable_clients_clone[inp].add(next_node) + # This input must become an output of the subgraph, + # because it can't be merged with next_node. + # We will revisit it to make sure this is safe. + fuseable_nodes_to_visit.appendleft(inp.owner) + + for client in fuseable_clients_temp[next_out]: + if client in visited_nodes: + fuseable_clients_temp[next_out].remove(client) + unfuseable_clients_clone[next_out].add(client) + # next_out must become an input of the subgraph. + # We will revisit any of its clients currently + # in the subgraph to make sure this is safe. + fuseable_nodes_to_visit.appendleft(client) + + # Revisit node at a later time + visited_nodes.remove(next_node) + continue + + # Adding next_node to subgraph does not result in any + # immediate dependency problems. Update subgraph + # mappings as if it next_node was part of it. + # Useless inputs will be removed by the useless Composite rewrite + for inp in new_required_unfuseable_inputs: + if inp not in subgraph_inputs: + subgraph_inputs.append(inp) + + if must_become_output: + subgraph_outputs.append(next_out) + unfuseable_clients_subgraph.update( + new_implied_unfuseable_clients + ) - """ + # Expand through unvisited fuseable ancestors + for inp in sorted( + ( + inp + for inp in next_node.inputs + if ( + inp not in required_unfuseable_inputs + and inp.owner not in visited_nodes + ) + ), + key=lambda inp: toposort.index(inp.owner), + reverse=True, + ): + fuseable_nodes_to_visit.appendleft(inp.owner) + + # Expand through unvisited fuseable clients + for next_node in sorted( + ( + node + for node in fuseable_clients_temp.get(next_out, ()) + if node not in visited_nodes + ), + key=lambda node: toposort.index(node), + ): + fuseable_nodes_to_visit.append(next_node) + + # Don't return if final subgraph is just the original Elemwise + if len(subgraph_outputs) == 1 and set( + subgraph_outputs[0].owner.inputs + ) == set(subgraph_inputs): + # Update global fuseable mappings + # No input was actually fuseable + for inp in starting_node.inputs: + if starting_node in fuseable_clients.get(inp, ()): + fuseable_clients[inp].remove(starting_node) + unfuseable_clients[inp].add(starting_node) + # No client was actually fuseable + unfuseable_clients[starting_out].update( + fuseable_clients.pop(starting_out, ()) + ) + continue - def __init__(self, node_rewriter): - super().__init__() - self.node_rewriter = node_rewriter + return subgraph_inputs, subgraph_outputs + raise ValueError + + def update_fuseable_mappings_after_fg_replace( + *, + fg: FunctionGraph, + visited_nodes: Set[Apply], + fuseable_clients: FUSEABLE_MAPPING, + unfuseable_clients: UNFUSEABLE_MAPPING, + starting_nodes: Set[Apply], + ) -> None: + # Find new composite node and dropped intermediate nodes + # by comparing the current fg.apply nodes with the cached + # original nodes + next_nodes = fg.apply_nodes + (new_composite_node,) = next_nodes - starting_nodes + dropped_nodes = starting_nodes - next_nodes + + # Remove intermediate Composite nodes from mappings + for dropped_node in dropped_nodes: + (dropped_out,) = dropped_node.outputs + fuseable_clients.pop(dropped_out, None) + unfuseable_clients.pop(dropped_out, None) + visited_nodes.remove(dropped_node) + + # Update fuseable information for subgraph inputs + for inp in subgraph_inputs: + if inp in fuseable_clients: + new_fuseable_clients = [ + client + for client in fuseable_clients[inp] + if client not in dropped_nodes + ] + if new_fuseable_clients: + fuseable_clients[inp] = new_fuseable_clients + else: + fuseable_clients.pop(inp) + unfuseable_clients[inp] = ( + unfuseable_clients[inp] - dropped_nodes + ) | {new_composite_node} + + # Update fuseable information for subgraph outputs + for out in new_composite_node.outputs: + unfuseable_clients[out] = {client for client, _ in fg.clients[out]} + + visited_nodes.add(new_composite_node) + return + + # We start by creating two maps, 1) from each node to each potentially + # fuseable client (both nodes must be single output Elemwise with same + # broadcast type) and 2) from each node to each certainly unfuseable + # client (those that don't fit into 1)) + fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) + visited_nodes: Set[Apply] = set() + while True: + starting_nodes = fg.apply_nodes.copy() + try: + subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( + fg=fg, + visited_nodes=visited_nodes, + fuseable_clients=fuseable_clients, + unfuseable_clients=unfuseable_clients, + ) + except ValueError: + return + else: + # The caller is now expected to update fg in place, + # by replacing the subgraph with a Composite Op + yield subgraph_inputs, subgraph_outputs + + # This is where we avoid repeated work by using a stateful + # generator. For large models (as in `TestFusion.test_big_fusion`) + # this can provide huge speedups + update_fuseable_mappings_after_fg_replace( + fg=fg, + visited_nodes=visited_nodes, + fuseable_clients=fuseable_clients, + unfuseable_clients=unfuseable_clients, + starting_nodes=starting_nodes, + ) - def add_requirements(self, fgraph): - fgraph.attach_feature(ReplaceValidate()) + for inputs, outputs in find_next_fuseable_subgraph(fgraph): + if (len(inputs) + len(outputs)) > max_operands: + warn( + "Loop fusion failed because the resulting node would exceed " + "the kernel argument limit." + ) + break - def apply(self, fgraph): - did_something = True - nb_iter = 0 - nb_replacement = 0 - nb_inconsistency_replace = 0 - time_toposort = 0 - if fgraph.profile: - validate_before = fgraph.profile.validate_time - callbacks_before = fgraph.execute_callbacks_times.copy() - callback_before = fgraph.execute_callbacks_time - while did_something: - t0 = time.perf_counter() - nodelist = list(fgraph.toposort()) - time_toposort += time.perf_counter() - t0 - nodelist.reverse() - did_something = False - for node in nodelist: - # Don't try to fuse node that have already been fused. - if node in fgraph.apply_nodes: - new_outputs = self.node_rewriter(fgraph, node) - if new_outputs: - assert len(new_outputs) == len(node.outputs) - try: - fgraph.replace_all_validate( - list(zip(node.outputs, new_outputs)), - reason=self.__class__.__name__, - ) - did_something = True - nb_replacement += 1 - except InconsistencyError: - nb_inconsistency_replace += 1 - nb_iter += 1 + scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) + composite_outputs = Elemwise(aes.Composite(scalar_inputs, scalar_outputs))( + *inputs + ) + if not isinstance(composite_outputs, list): + composite_outputs = [composite_outputs] + for old_out, composite_out in zip(outputs, composite_outputs): + if old_out.name: + composite_out.name = old_out.name + + fgraph.replace_all_validate( + list(zip(outputs, composite_outputs)), + reason=self.__class__.__name__, + ) + nb_replacement += 1 if fgraph.profile: validate_time = fgraph.profile.validate_time - validate_before @@ -933,21 +1039,22 @@ def apply(self, fgraph): validate_time = None callback_time = None callbacks_time = {} + return ( self, - nb_iter, + 1, # nb_iter nb_replacement, - nb_inconsistency_replace, + 0, # nb_inconsintency_replace validate_time, callback_time, callbacks_time, - time_toposort, + -1, # toposort_time ) - @classmethod - def print_profile(cls, stream, prof, level=0): + @staticmethod + def print_profile(stream, prof, level=0): blanc = " " * level - print(blanc, cls.__name__, file=stream) + print(blanc, "FusionOptimizer", file=stream) print(blanc, " nb_iter", prof[1], file=stream) print(blanc, " nb_replacement", prof[2], file=stream) print(blanc, " nb_inconsistency_replace", prof[3], file=stream) @@ -973,7 +1080,7 @@ def print_profile(cls, stream, prof, level=0): ) fuse_seqopt.register( "composite_elemwise_fusion", - FusionOptimizer(local_elemwise_fusion), + FusionOptimizer(), "fast_run", "fusion", position=1, @@ -999,7 +1106,9 @@ def local_useless_composite(fgraph, node): ): return comp = node.op.scalar_op - used_outputs_idxs = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]] + used_outputs_idxs = [ + i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern] + ] used_inner_outputs = [comp.outputs[i] for i in used_outputs_idxs] comp_fgraph = FunctionGraph( inputs=comp.inputs, outputs=used_inner_outputs, clone=False diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index 307dcb5572..8800294c74 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -27,7 +27,6 @@ pytensor/tensor/random/basic.py pytensor/tensor/random/op.py pytensor/tensor/random/utils.py pytensor/tensor/rewriting/basic.py -pytensor/tensor/rewriting/elemwise.py pytensor/tensor/shape.py pytensor/tensor/slinalg.py pytensor/tensor/subtensor.py diff --git a/tests/compile/function/test_pfunc.py b/tests/compile/function/test_pfunc.py index c37ed77dea..d9c4548d24 100644 --- a/tests/compile/function/test_pfunc.py +++ b/tests/compile/function/test_pfunc.py @@ -2,7 +2,7 @@ import pytest import pytensor.tensor as at -from pytensor.compile import UnusedInputError +from pytensor.compile import UnusedInputError, get_mode from pytensor.compile.function import function, pfunc from pytensor.compile.function.pfunc import rebuild_collect_shared from pytensor.compile.io import In @@ -200,7 +200,12 @@ def test_shared_mutable(self): bval = np.arange(5) b.set_value(bval, borrow=True) bval = data_of(b) - f = pfunc([], [b_out], updates=[(b, (b_out + 3))], mode="FAST_RUN") + f = pfunc( + [], + [b_out], + updates=[(b, (b_out + 3))], + mode=get_mode("FAST_RUN").excluding("fusion"), + ) assert (f() == (np.arange(5) * 2)).all() # because of the update assert (b.get_value(borrow=True) == ((np.arange(5) * 2) + 3)).all() diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 899d056588..7484cd75d9 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1,5 +1,3 @@ -import contextlib - import numpy as np import pytest @@ -17,11 +15,14 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.misc.safe_asarray import _asarray +from pytensor.raise_op import assert_op from pytensor.scalar.basic import Composite from pytensor.tensor.basic import MakeVector from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.math import abs as at_abs +from pytensor.tensor.math import add +from pytensor.tensor.math import all as at_all from pytensor.tensor.math import ( - add, bitwise_and, bitwise_or, cos, @@ -29,6 +30,7 @@ dot, eq, exp, + ge, int_div, invert, iround, @@ -900,6 +902,72 @@ def my_init(dtype="float64", num=0): fxv * np.sin(fsv), "float32", ), + # Multiple output cases # 72 + ( + ( + # sum(logp) + at_sum(-((fx - fy) ** 2) / 2), + # grad(logp) + at.grad(at_sum(-((fx - fy) ** 2) / 2), wrt=fx), + ), + (fx, fy), + (fxv, fyv), + 3, + ( + np.sum(-((fxv - fyv) ** 2) / 2), + -(fxv - fyv), + ), + ("float32", "float32"), + ), + # Two Composite graphs that share the same input, but are split by + # a non-elemwise operation (Assert) + ( + ( + log( + ge( + assert_op( + at_abs(fx), + at_all(ge(at_abs(fx), 0)), + ), + 0, + ) + ), + ), + (fx,), + (fxv,), + 4, + (np.zeros_like(fxv),), + ("float32",), + ), + # Two subgraphs that share the same non-fuseable input, but are otherwise + # completely independent + ( + ( + true_div( + mul( + at_sum(fx + 5), # breaks fusion + exp(fx), + ), + (fx + 5), + ), + ), + (fx,), + (fxv,), + 4, + (np.sum(fxv + 5) * np.exp(fxv) / (fxv + 5),), + ("float32",), + ), + pytest.param( + ( + (sin(exp(fx)), exp(sin(fx))), + (fx,), + (fxv,), + 1, + (np.sin(np.exp(fxv)), np.exp(np.sin(fxv))), + ("float32", "float32"), + ), + marks=pytest.mark.xfail, # Not implemented yet + ), ], ) def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): @@ -910,23 +978,34 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): if isinstance(out_dtype, dict): out_dtype = out_dtype[config.cast_policy] + if not isinstance(g, (tuple, list)): + g = (g,) + answer = (answer,) + out_dtype = (out_dtype,) + if self._shared is None: f = function(list(sym_inputs), g, mode=self.mode) for x in range(nb_repeat): out = f(*val_inputs) + if not isinstance(out, list): + out = (out,) else: - out = self._shared(np.zeros((5, 5), dtype=out_dtype), "out") - assert out.dtype == g.dtype - f = function(sym_inputs, [], updates=[(out, g)], mode=self.mode) + out = [ + self._shared(np.zeros((5,) * g_.ndim, dtype=od), "out") + for g_, od in zip(g, out_dtype) + ] + assert all(o.dtype == g_.dtype for o, g_ in zip(out, g)) + f = function(sym_inputs, [], updates=list(zip(out, g)), mode=self.mode) for x in range(nb_repeat): f(*val_inputs) - out = out.get_value() + out = [o.get_value() for o in out] atol = 1e-8 - if out_dtype == "float32": + if any(o == "float32" for o in out_dtype): atol = 1e-6 - assert np.allclose(out, answer * nb_repeat, atol=atol) + for o, a in zip(out, answer): + np.testing.assert_allclose(o, a * nb_repeat, atol=atol) topo = f.maker.fgraph.toposort() topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)] @@ -939,13 +1018,15 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): # input of g, # check that the number of input to the Composite # Elemwise is ok - if len(set(g.owner.inputs)) == len(g.owner.inputs): - expected_len_sym_inputs = sum( - not isinstance(x, Constant) for x in topo_[0].inputs - ) - assert expected_len_sym_inputs == len(sym_inputs) + for g_ in g: + if len(set(g_.owner.inputs)) == len(g_.owner.inputs): + expected_len_sym_inputs = sum( + not isinstance(x, Constant) for x in topo_[0].inputs + ) + assert expected_len_sym_inputs == len(sym_inputs) - assert out_dtype == out.dtype + for od, o in zip(out_dtype, out): + assert od == o.dtype def test_fusion_35_inputs(self): r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit.""" @@ -1006,6 +1087,30 @@ def test_big_fusion(self): for node in dlogp.maker.fgraph.toposort() ) + @pytest.mark.xfail(reason="Fails due to #1244") + def test_add_mul_fusion_precedence(self): + """Test that additions and multiplications are "fused together" before + a `Composite` `Op` is introduced. This fusion is done by canonicalization + """ + x, y, z = vectors("x", "y", "z") + out = log((x + y + z) / (x * y * z)) + f = pytensor.function([x, y, z], out, mode=self.mode) + # There should be a single Composite Op + nodes = f.maker.fgraph.apply_nodes + assert len(nodes) == 1 + (node,) = nodes + assert isinstance(node.op, Elemwise) + scalar_op = node.op.scalar_op + assert isinstance(scalar_op, Composite) + assert [node.op for node in scalar_op.fgraph.toposort()] == [ + # There should be a single mul + aes.mul, + # There should be a single add + aes.add, + aes.true_div, + aes.log, + ] + def test_add_mul_fusion_inplace(self): x, y, z = dmatrices("xyz") out = dot(x, y) + x + y + z @@ -1082,11 +1187,8 @@ def impl(self, x): @pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]]) def test_test_values(self, test_value): - """Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions. - - The test values we're talking about are the ones used when C implementations - are checked. - + """Make sure that `local_elemwise_fusion_op` uses test values correctly + when they have zero dimensions. """ x, y, z = dmatrices("xyz") @@ -1094,27 +1196,20 @@ def test_test_values(self, test_value): y.tag.test_value = test_value z.tag.test_value = test_value - if test_value.size == 0: - cm = pytest.raises(ValueError) - else: - cm = contextlib.suppress() - with config.change_flags( compute_test_value="raise", compute_test_value_opt="raise" ): out = x * y + z - with cm: - f = function([x, y, z], out, mode=self.mode) + f = function([x, y, z], out, mode=self.mode) - if test_value.size != 0: - # Confirm that the fusion happened - assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite) - assert len(f.maker.fgraph.toposort()) == 1 + # Confirm that the fusion happened + assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite) + assert len(f.maker.fgraph.toposort()) == 1 - x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs - assert np.array_equal( - f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]] - ) + assert np.array_equal( + f.maker.fgraph.outputs[0].tag.test_value, + np.full_like(test_value, 2.0), + ) @pytest.mark.parametrize("linker", ["cvm", "py"]) @pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)]) @@ -1227,6 +1322,26 @@ def test_not_fusing_broadcasted_subgraphs(self): aes.mul, } + def test_multiple_outputs_fused_root_elemwise(self): + """Test that a root elemwise output (single layer) is reused when + there is another fused output""" + + # By default, we do not introduce Composite for single layers of Elemwise + x = at.vector("x") + out1 = at.cos(x) + f = pytensor.function([x], out1, mode=self.mode) + nodes = tuple(f.maker.fgraph.apply_nodes) + assert len(nodes) == 1 + assert isinstance(nodes[0].op.scalar_op, aes.Cos) + + # However, when it can be composed with another output, we should not + # compute that root Elemwise twice + out2 = at.log(out1) + f = pytensor.function([x], [out1, out2], mode=self.mode) + nodes = tuple(f.maker.fgraph.apply_nodes) + assert len(nodes) == 1 + assert isinstance(nodes[0].op.scalar_op, Composite) + class TimesN(aes.basic.UnaryScalarOp): """ diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 8395fb4a60..aa275a03a6 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -887,10 +887,9 @@ def test_basic_6(self): prog = f.maker.fgraph.toposort() assert isinstance(prog[0].op, DimShuffle) assert isinstance(prog[1].op.scalar_op, aes.Composite) # Composite{add,exp} - assert prog[2].op == add or prog[3].op == add # first subtensor - assert isinstance(prog[2].op, Subtensor) or isinstance(prog[3].op, Subtensor) - assert len(prog) == 4 + assert isinstance(prog[2].op, Subtensor) + assert len(prog) == 3 f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something def test_basic_7(self): diff --git a/tests/test_printing.py b/tests/test_printing.py index f5a10c8aeb..cef7b872fb 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -273,8 +273,7 @@ def test_debugprint(): s = s.getvalue() exp_res = dedent( r""" - Elemwise{Composite{(i0 + (i1 - i2))}} 4 - |A + Elemwise{Composite{(i2 + (i0 - i1))}} 4 |InplaceDimShuffle{x,0} v={0: [0]} 3 | |CGemv{inplace} d={0: [0]} 2 | |AllocEmpty{dtype='float64'} 1 @@ -285,6 +284,7 @@ def test_debugprint(): | | | |TensorConstant{0.0} |D + |A """ ).lstrip() From d7eadb463442c3595338bcf1980b8b50535e9077 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 14 Dec 2022 10:39:57 +0100 Subject: [PATCH 08/13] Lower precision in TestFusion parametrization This was not an issue in my local machine, but failed on the Github CI. It could be due to compiler optimizations. Case 69 used to look like this: ```python Elemwise{Composite{(i0 * tan(i0) * tan(i0) * i1)}} [id C] |x [id A] |x [id A] ``` And now looks like this ```python Elemwise{Composite{(i0 * tan(i0) * tan(i0) * i0)}} [id C] |x [id A] [None] ``` --- tests/tensor/rewriting/test_elemwise.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 7484cd75d9..1708c0fe47 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -882,6 +882,7 @@ def my_init(dtype="float64", num=0): 1, fxv * np.tan(fxv) * np.tan(fxv) * fxv, "float32", + 1e-5, ), ( mul(ftanx, ftanx, fx + fy), @@ -890,6 +891,7 @@ def my_init(dtype="float64", num=0): 1, np.tan(fxv) * np.tan(fxv) * (fxv + fyv), "float32", + 1e-5, ), # 70 # Cases with different broadcast pattern. They should not # be merged as this would duplicate computation @@ -973,7 +975,11 @@ def my_init(dtype="float64", num=0): def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): """Verify that `Elemwise` fusion works.""" - g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case + if len(case) == 6: + g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case + atol = None + else: + g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype, atol = case if isinstance(out_dtype, dict): out_dtype = out_dtype[config.cast_policy] @@ -1000,9 +1006,10 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): f(*val_inputs) out = [o.get_value() for o in out] - atol = 1e-8 - if any(o == "float32" for o in out_dtype): - atol = 1e-6 + if atol is None: + atol = 1e-8 + if any(o == "float32" for o in out_dtype): + atol = 1e-6 for o, a in zip(out, answer): np.testing.assert_allclose(o, a * nb_repeat, atol=atol) From 636fc840bbc9a66c97c917cc812e276c8dca6118 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 16 Dec 2022 17:35:49 +0100 Subject: [PATCH 09/13] Do not manually include fast_run in test_shape_i_* It doesn't make sense to include `fast_run` if `fast_compile` mode is being used. Some rewrites such as the FusionOptimizer are not compatible with `fast_compile` mode which prevents the creation of C thunks. The FusionOptimizer has no way of knowing this is the case, and assumes it is safe to return Composites with more than 32 operands, even though that's not the case with the Python perform method. --- tests/tensor/test_subtensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 1d0bfe9b21..431c236442 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -986,7 +986,7 @@ def test_adv_sub1_idx_broadcast(self): def test_shape_i_const(self): # Each axis is treated independently by shape_i/shape operators - mode_opt = self.mode.including("fast_run") + mode_opt = self.mode data = self.shared(np.array(np.arange(5), dtype=self.dtype)) for start in [None] + [-8, -5, -1, 0, 1, 5, 8]: outs = [] @@ -1004,7 +1004,7 @@ def test_shape_i_const(self): def test_shape_i_scalar(self): # Each axis is treated independently by shape_i/shape operators - mode_opt = self.mode.including("fast_run") + mode_opt = self.mode v_data = np.array(np.arange(5), dtype=self.dtype) t_data = self.shared(v_data) From e34ca2e1f90bdd20a13da360e0863e074f67b175 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 18 Jan 2023 12:22:51 +0100 Subject: [PATCH 10/13] Fix wrong backend in Numba logsumexp benchmark --- tests/link/numba/test_elemwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index d05b5ab95d..18f5950d96 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -548,7 +548,7 @@ def test_logsumexp_benchmark(size, axis, benchmark): rng = np.random.default_rng(23920) X_val = rng.normal(size=size) - X_lse_fn = pytensor.function([X], X_lse, mode="JAX") + X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA") # JIT compile first _ = X_lse_fn(X_val) From 2c63155844925e5d32163cc939328e9692afdd36 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 18 Jan 2023 12:23:40 +0100 Subject: [PATCH 11/13] Add benchmark tests for fused Elemwises --- tests/link/numba/test_elemwise.py | 16 ++++++++++++++++ tests/tensor/rewriting/test_elemwise.py | 13 +++++++++++++ 2 files changed, 29 insertions(+) diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 18f5950d96..8fbf026e11 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -11,6 +11,7 @@ from pytensor import config, function from pytensor.compile.ops import deep_copy_op from pytensor.compile.sharedvalue import SharedVariable +from pytensor.gradient import grad from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph from pytensor.tensor import elemwise as at_elemwise @@ -555,3 +556,18 @@ def test_logsumexp_benchmark(size, axis, benchmark): res = benchmark(X_lse_fn, X_val) exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True) np.testing.assert_array_almost_equal(res, exp_res) + + +def test_fused_elemwise_benchmark(benchmark): + rng = np.random.default_rng(123) + size = 100_000 + x = pytensor.shared(rng.normal(size=size), name="x") + mu = pytensor.shared(rng.normal(size=size), name="mu") + + logp = -((x - mu) ** 2) / 2 + grad_logp = grad(logp.sum(), x) + + func = pytensor.function([], [logp, grad_logp], mode="NUMBA") + # JIT compile first + func() + benchmark(func) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 1708c0fe47..9b6dc304f4 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -9,6 +9,7 @@ from pytensor.compile.function import function from pytensor.compile.mode import Mode, get_default_mode from pytensor.configdefaults import config +from pytensor.gradient import grad from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import check_stack_trace, out2in @@ -1349,6 +1350,18 @@ def test_multiple_outputs_fused_root_elemwise(self): assert len(nodes) == 1 assert isinstance(nodes[0].op.scalar_op, Composite) + def test_eval_benchmark(self, benchmark): + rng = np.random.default_rng(123) + size = 100_000 + x = pytensor.shared(rng.normal(size=size), name="x") + mu = pytensor.shared(rng.normal(size=size), name="mu") + + logp = -((x - mu) ** 2) / 2 + grad_logp = grad(logp.sum(), x) + + func = pytensor.function([], [logp, grad_logp], mode="FAST_RUN") + benchmark(func) + class TimesN(aes.basic.UnaryScalarOp): """ From 2d8d83b1a413c4aa947ba9e1851c23ff688eb750 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 7 Feb 2023 19:11:28 +0100 Subject: [PATCH 12/13] Add benchmark test for FusionRewriter --- tests/tensor/rewriting/test_elemwise.py | 65 ++++++++++++++----------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 9b6dc304f4..deaf92ef43 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -48,7 +48,7 @@ from pytensor.tensor.math import sin, sinh, sqr, sqrt from pytensor.tensor.math import sum as at_sum from pytensor.tensor.math import tan, tanh, true_div, xor -from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift +from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape from pytensor.tensor.shape import reshape from pytensor.tensor.type import ( @@ -302,6 +302,29 @@ def my_init(dtype="float64", num=0): fwx = fw + fx ftanx = tan(fx) + def large_fuseable_graph(self, n): + factors = [] + sd = dscalar() + means = dvector() + + cst_05 = at.constant(0.5) + cst_m05 = at.constant(-0.5) + cst_2 = at.constant(2) + cst_m2 = at.constant(-2) + ones = at.constant(np.ones(10)) + + for i in range(n): + f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log( + cst_05 * (sd**cst_m2) / np.pi + ) + factors.append(at_sum(f)) + + logp = add(*factors) + + vars = [sd, means] + dlogp = [pytensor.grad(logp, v) for v in vars] + return vars, dlogp + @pytest.mark.parametrize( "case", [ @@ -1059,35 +1082,9 @@ def test_fusion_35_inputs(self): @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") def test_big_fusion(self): - # In the past, pickle of Composite generated in that case - # crashed with max recursion limit. So we were not able to - # generate C code in that case. - factors = [] - sd = dscalar() - means = dvector() - - cst_05 = at.constant(0.5) - cst_m05 = at.constant(-0.5) - cst_2 = at.constant(2) - cst_m2 = at.constant(-2) - ones = at.constant(np.ones(10)) - n = 85 - if config.mode in ["DebugMode", "DEBUG_MODE"]: - n = 10 - - for i in range(n): - f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log( - cst_05 * (sd**cst_m2) / np.pi - ) - factors.append(at_sum(f)) - - logp = add(*factors) - - vars = [sd, means] - # Make sure that C compilation is used mode = Mode("cvm", self.rewrites) - dlogp = function(vars, [pytensor.grad(logp, v) for v in vars], mode=mode) + dlogp = function(*self.large_fuseable_graph(n=85), mode=mode) # Make sure something was fused assert any( @@ -1362,6 +1359,18 @@ def test_eval_benchmark(self, benchmark): func = pytensor.function([], [logp, grad_logp], mode="FAST_RUN") benchmark(func) + @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") + def test_rewrite_benchmark(self, benchmark): + inps, outs = self.large_fuseable_graph(n=25) + fg = FunctionGraph(inps, outs) + opt = FusionOptimizer() + + def rewrite_func(): + nb_replacement = opt.apply(fg.clone())[2] + return nb_replacement + + assert benchmark(rewrite_func) == 103 + class TimesN(aes.basic.UnaryScalarOp): """ From 11c2b2c4f9ce08b662d81cc0b58f0aba81070d8f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 20 Jan 2023 15:42:40 +0100 Subject: [PATCH 13/13] Render inner-graphs of Composite Ops in debugprint --- pytensor/printing.py | 24 ++++++++++----- pytensor/scalar/basic.py | 27 ++--------------- tests/scalar/test_basic.py | 7 +---- tests/scan/test_printing.py | 45 +++++++++++++++++----------- tests/tensor/rewriting/test_basic.py | 8 ++--- tests/tensor/rewriting/test_math.py | 7 +++-- tests/test_printing.py | 11 ++++++- 7 files changed, 65 insertions(+), 64 deletions(-) diff --git a/pytensor/printing.py b/pytensor/printing.py index 8b24884944..1042a5897a 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -312,7 +312,11 @@ def debugprint( ): if hasattr(var.owner, "op"): - if isinstance(var.owner.op, HasInnerGraph) and var not in inner_graph_vars: + if ( + isinstance(var.owner.op, HasInnerGraph) + or hasattr(var.owner.op, "scalar_op") + and isinstance(var.owner.op.scalar_op, HasInnerGraph) + ) and var not in inner_graph_vars: inner_graph_vars.append(var) if print_op_info: op_information.update(op_debug_information(var.owner.op, var.owner)) @@ -355,8 +359,12 @@ def debugprint( inner_inputs = inner_fn.maker.fgraph.inputs inner_outputs = inner_fn.maker.fgraph.outputs else: - inner_inputs = ig_var.owner.op.inner_inputs - inner_outputs = ig_var.owner.op.inner_outputs + if hasattr(ig_var.owner.op, "scalar_op"): + inner_inputs = ig_var.owner.op.scalar_op.inner_inputs + inner_outputs = ig_var.owner.op.scalar_op.inner_outputs + else: + inner_inputs = ig_var.owner.op.inner_inputs + inner_outputs = ig_var.owner.op.inner_outputs outer_inputs = ig_var.owner.inputs @@ -422,8 +430,9 @@ def debugprint( if ( isinstance(getattr(out.owner, "op", None), HasInnerGraph) - and out not in inner_graph_vars - ): + or hasattr(getattr(out.owner, "op", None), "scalar_op") + and isinstance(out.owner.op.scalar_op, HasInnerGraph) + ) and out not in inner_graph_vars: inner_graph_vars.append(out) _debugprint( @@ -664,8 +673,9 @@ def get_id_str( if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"): if ( isinstance(in_var.owner.op, HasInnerGraph) - and in_var not in inner_graph_ops - ): + or hasattr(in_var.owner.op, "scalar_op") + and isinstance(in_var.owner.op.scalar_op, HasInnerGraph) + ) and in_var not in inner_graph_ops: inner_graph_ops.append(in_var) _debugprint( diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index fb2656e50e..fee02684fe 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -4000,7 +4000,8 @@ class Composite(ScalarOp, HasInnerGraph): init_param: Tuple[str, ...] = ("inputs", "outputs") - def __init__(self, inputs, outputs): + def __init__(self, inputs, outputs, name="Composite"): + self.name = name # We need to clone the graph as sometimes its nodes already # contain a reference to an fgraph. As we want the Composite # to be pickable, we can't have reference to fgraph. @@ -4106,30 +4107,6 @@ def _perform(*inputs, outputs=[[None]]): self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert) return self._py_perform_fn - @property - def name(self): - if hasattr(self, "_name"): - return self._name - - # TODO FIXME: Just implement pretty printing for the `Op`; don't do - # this redundant, outside work in the `Op` itself. - for i, r in enumerate(self.fgraph.inputs): - r.name = f"i{int(i)}" - for i, r in enumerate(self.fgraph.outputs): - r.name = f"o{int(i)}" - io = set(self.fgraph.inputs + self.fgraph.outputs) - for i, r in enumerate(self.fgraph.variables): - if r not in io and len(self.fgraph.clients[r]) > 1: - r.name = f"t{int(i)}" - outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs]) - rval = f"Composite{{{outputs_str}}}" - self._name = rval - return self._name - - @name.setter - def name(self, name): - self._name = name - @property def fgraph(self): if hasattr(self, "_fgraph"): diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index 6e694f5be6..c27f220c06 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -183,12 +183,7 @@ def test_composite_printing(self): make_function(DualLinker().accept(g)) assert str(g) == ( - "FunctionGraph(*1 -> Composite{((i0 + i1) + i2)," - " (i0 + (i1 * i2)), (i0 / i1), " - "(i0 // 5), " - "(-i0), (i0 - i1), ((i0 ** i1) + (-i2))," - " (i0 % 3)}(x, y, z), " - "*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)" + "FunctionGraph(*1 -> Composite(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)" ) def test_non_scalar_error(self): diff --git a/tests/scan/test_printing.py b/tests/scan/test_printing.py index 2a8adbbe30..7373c78cb4 100644 --- a/tests/scan/test_printing.py +++ b/tests/scan/test_printing.py @@ -604,31 +604,40 @@ def no_shared_fn(n, x_tm1, M): out = pytensor.function([M], out, updates=updates, mode="FAST_RUN") expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0) - |TensorConstant{20000} [id B] (n_steps) - |TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0) - |IncSubtensor{InplaceSet;:int64:} [id D] 1 (outer_in_sit_sot-0) - | |AllocEmpty{dtype='int64'} [id E] 0 - | | |TensorConstant{20000} [id B] - | |TensorConstant{(1,) of 0} [id F] - | |ScalarConstant{1} [id G] - | [id H] (outer_in_non_seqs-0) + |TensorConstant{20000} [id B] (n_steps) + |TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0) + |IncSubtensor{InplaceSet;:int64:} [id D] 1 (outer_in_sit_sot-0) + | |AllocEmpty{dtype='int64'} [id E] 0 + | | |TensorConstant{20000} [id B] + | |TensorConstant{(1,) of 0} [id F] + | |ScalarConstant{1} [id G] + | [id H] (outer_in_non_seqs-0) Inner graphs: forall_inplace,cpu,scan_fn} [id A] (outer_out_sit_sot-0) - >Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0) - > |TensorConstant{0} [id J] - > |Subtensor{int64, int64, uint8} [id K] - > | |*2- [id L] -> [id H] (inner_in_non_seqs-0) - > | |ScalarFromTensor [id M] - > | | |*0- [id N] -> [id C] (inner_in_seqs-0) - > | |ScalarFromTensor [id O] - > | | |*1- [id P] -> [id D] (inner_in_sit_sot-0) - > | |ScalarConstant{0} [id Q] - > |TensorConstant{1} [id R] + >Elemwise{Composite} [id I] (inner_out_sit_sot-0) + > |TensorConstant{0} [id J] + > |Subtensor{int64, int64, uint8} [id K] + > | |*2- [id L] -> [id H] (inner_in_non_seqs-0) + > | |ScalarFromTensor [id M] + > | | |*0- [id N] -> [id C] (inner_in_seqs-0) + > | |ScalarFromTensor [id O] + > | | |*1- [id P] -> [id D] (inner_in_sit_sot-0) + > | |ScalarConstant{0} [id Q] + > |TensorConstant{1} [id R] + + Elemwise{Composite} [id I] + >Switch [id S] + > |LT [id T] + > | | [id U] + > | | [id V] + > | [id W] + > | [id U] """ output_str = debugprint(out, file="str", print_op_info=True) + print(output_str) lines = output_str.split("\n") for truth, out in zip(expected_output.split("\n"), lines): diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index f28c95037f..144fcf0eb9 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -16,7 +16,7 @@ from pytensor.graph.rewriting.basic import check_stack_trace, out2in from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.utils import rewrite_graph -from pytensor.printing import pprint +from pytensor.printing import debugprint, pprint from pytensor.raise_op import Assert, CheckAndRaise from pytensor.tensor.basic import ( Alloc, @@ -1105,7 +1105,7 @@ def test_elemwise_float_ops(self, op): s2 = at.switch(c, x, y) g = rewrite(FunctionGraph(mats, [op(s1, s2)])) - assert str(g).count("Switch") == 1 + assert debugprint(g, file="str").count("Switch") == 1 @pytest.mark.parametrize( "op", @@ -1122,7 +1122,7 @@ def test_elemwise_int_ops(self, op): s1 = at.switch(c, a, b) s2 = at.switch(c, x, y) g = rewrite(FunctionGraph(mats, [op(s1, s2)])) - assert str(g).count("Switch") == 1 + assert debugprint(g, file="str").count("Switch") == 1 @pytest.mark.parametrize("op", [add, mul]) def test_elemwise_multi_inputs(self, op): @@ -1134,7 +1134,7 @@ def test_elemwise_multi_inputs(self, op): u, v = matrices("uv") s3 = at.switch(c, u, v) g = rewrite(FunctionGraph(mats + [u, v], [op(s1, s2, s3)])) - assert str(g).count("Switch") == 1 + assert debugprint(g, file="str").count("Switch") == 1 class TestLocalOptAlloc: diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 850ba0db6b..a662b5b325 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -28,6 +28,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph from pytensor.misc.safe_asarray import _asarray +from pytensor.printing import debugprint from pytensor.tensor import inplace from pytensor.tensor.basic import Alloc, join, switch from pytensor.tensor.blas import Dot22, Gemv @@ -2416,7 +2417,7 @@ def test_elemwise(self): at_pow, ): g = rewrite(FunctionGraph(mats, [op(s1, s2)])) - assert str(g).count("Switch") == 1 + assert debugprint(g, file="str").count("Switch") == 1 # integer Ops mats = imatrices("cabxy") c, a, b, x, y = mats @@ -2428,13 +2429,13 @@ def test_elemwise(self): bitwise_xor, ): g = rewrite(FunctionGraph(mats, [op(s1, s2)])) - assert str(g).count("Switch") == 1 + assert debugprint(g, file="str").count("Switch") == 1 # add/mul with more than two inputs u, v = matrices("uv") s3 = at.switch(c, u, v) for op in (add, mul): g = rewrite(FunctionGraph(mats + [u, v], [op(s1, s2, s3)])) - assert str(g).count("Switch") == 1 + assert debugprint(g, file="str").count("Switch") == 1 class TestLocalSumProd: diff --git a/tests/test_printing.py b/tests/test_printing.py index cef7b872fb..d9592dd9af 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -273,7 +273,7 @@ def test_debugprint(): s = s.getvalue() exp_res = dedent( r""" - Elemwise{Composite{(i2 + (i0 - i1))}} 4 + Elemwise{Composite} 4 |InplaceDimShuffle{x,0} v={0: [0]} 3 | |CGemv{inplace} d={0: [0]} 2 | |AllocEmpty{dtype='float64'} 1 @@ -285,6 +285,15 @@ def test_debugprint(): | |TensorConstant{0.0} |D |A + + Inner graphs: + + Elemwise{Composite} + >add + > | + > |sub + > | + > | """ ).lstrip()