diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 590625445f..7cf495b56a 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -2,6 +2,7 @@
 import itertools
 import operator
+from collections import defaultdict
 from functools import partial, reduce
 
 import numpy as np
 
@@ -423,6 +424,100 @@ def local_sumsqr2dot(fgraph, node):
         return [new_out]
 
 
+@register_specialize
+@node_rewriter([mul, true_div])
+def local_mul_exp_to_exp_add(fgraph, node):
+    """
+    This rewrite detects e^x * e^y and converts it to e^(x+y).
+    Similarly, e^x / e^y becomes e^(x-y).
+    """
+    exps = [
+        n.owner.inputs[0]
+        for n in node.inputs
+        if n.owner
+        and hasattr(n.owner.op, "scalar_op")
+        and isinstance(n.owner.op.scalar_op, aes.Exp)
+    ]
+    # We can only rewrite anything if there are at least two exp-s
+    if len(exps) >= 2:
+        # Mul -> add; TrueDiv -> sub
+        orig_op, new_op = mul, add
+        if isinstance(node.op.scalar_op, aes.TrueDiv):
+            orig_op, new_op = true_div, sub
+        new_out = exp(new_op(*exps))
+        if new_out.dtype != node.outputs[0].dtype:
+            new_out = cast(new_out, dtype=node.outputs[0].dtype)
+        # The original Mul may have more than two factors, some of which may not be exp nodes.
+        # If so, we keep multiplying them with the new exp(sum) node.
+        # E.g.: e^x * y * e^z * w --> e^(x+z) * y * w
+        rest = [
+            n
+            for n in node.inputs
+            if not n.owner
+            or not hasattr(n.owner.op, "scalar_op")
+            or not isinstance(n.owner.op.scalar_op, aes.Exp)
+        ]
+        if len(rest) > 0:
+            new_out = orig_op(new_out, *rest)
+            if new_out.dtype != node.outputs[0].dtype:
+                new_out = cast(new_out, dtype=node.outputs[0].dtype)
+        return [new_out]
+
+
+@register_specialize
+@node_rewriter([mul, true_div])
+def local_mul_pow_to_pow_add(fgraph, node):
+    """
+    This rewrite detects a^x * a^y and converts it to a^(x+y).
+    Similarly, a^x / a^y becomes a^(x-y).
+    """
+    # Search for pow-s and group them by their bases
+    pow_nodes = defaultdict(list)
+    rest = []
+    for n in node.inputs:
+        if (
+            n.owner
+            and hasattr(n.owner.op, "scalar_op")
+            and isinstance(n.owner.op.scalar_op, aes.Pow)
+        ):
+            base_node = n.owner.inputs[0]
+            # The exponent is at n.owner.inputs[1], but we need to store the full node
+            # in case this particular pow node remains alone and can't be rewritten
+            pow_nodes[base_node].append(n)
+        else:
+            rest.append(n)
+
+    # We can only rewrite anything if there are at least two pow-s with the same base
+    can_rewrite = [k for k, v in pow_nodes.items() if len(v) >= 2]
+    if len(can_rewrite) >= 1:
+        # Mul -> add; TrueDiv -> sub
+        orig_op, new_op = mul, add
+        if isinstance(node.op.scalar_op, aes.TrueDiv):
+            orig_op, new_op = true_div, sub
+        pow_factors = []
+        # Rewrite the pow-s sharing a base, once for each different base
+        # E.g.: a^x * a^y --> a^(x+y)
+        for base in can_rewrite:
+            exponents = [n.owner.inputs[1] for n in pow_nodes[base]]
+            new_node = base ** new_op(*exponents)
+            if new_node.dtype != node.outputs[0].dtype:
+                new_node = cast(new_node, dtype=node.outputs[0].dtype)
+            pow_factors.append(new_node)
+        # Don't forget about those sole pow-s that couldn't be rewritten
+        sole_pows = [v[0] for k, v in pow_nodes.items() if k not in can_rewrite]
+        # Combine the rewritten pow-s and the other, non-pow factors of the original Mul
+        # E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+w) * b^(z+t) * y * v
+        if len(pow_factors) > 1 or len(sole_pows) > 0 or len(rest) > 0:
+            new_out = orig_op(*pow_factors, *sole_pows, *rest)
+            if new_out.dtype != node.outputs[0].dtype:
+                new_out = cast(new_out, dtype=node.outputs[0].dtype)
+        else:
+            # If all factors of the original Mul were pow-s with the same base,
+            # we can get rid of the Mul completely.
+            new_out = pow_factors[0]
+        return [new_out]
+
+
 @register_stabilize
 @register_specialize
 @register_canonicalize
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 3f508af79b..74a6624077 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -4014,6 +4014,161 @@ def test_local_sumsqr2dot():
     )
 
 
+def test_local_mul_exp_to_exp_add():
+    # Default and FAST_RUN modes put a Composite op into the final graph,
+    # whereas FAST_COMPILE doesn't. To keep the graphs the test cases analyze
+    # consistent across modes, we avoid Composite insertion by excluding fusion rewrites.
+    mode = get_default_mode().excluding("fusion").including("local_mul_exp_to_exp_add")
+
+    x = scalar("x")
+    y = scalar("y")
+    z = scalar("z")
+    w = scalar("w")
+    expx = exp(x)
+    expy = exp(y)
+    expz = exp(z)
+    expw = exp(w)
+
+    # e^x * e^y * e^z * e^w = e^(x+y+z+w)
+    op = expx * expy * expz * expw
+    f = function([x, y, z, w], op, mode)
+    pytensor.dprint(f)
+    utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 + 6))
+    graph = f.maker.fgraph.toposort()
+    assert all(isinstance(n.op, Elemwise) for n in graph)
+    assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
+    assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
+
+    # e^x * e^y * e^z / e^w = e^(x+y+z-w)
+    op = expx * expy * expz / expw
+    f = function([x, y, z, w], op, mode)
+    utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 - 6))
+    graph = f.maker.fgraph.toposort()
+    assert all(isinstance(n.op, Elemwise) for n in graph)
+    assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
+    assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph)
+    assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
+    assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph)
+
+    # e^x * e^y / e^z * e^w = e^(x+y-z+w)
+    op = expx * expy / expz * expw
+    f = function([x, y, z, w], op, mode)
+    utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 - 5 + 6))
+    graph = f.maker.fgraph.toposort()
+    assert all(isinstance(n.op, Elemwise) for n in graph)
+    assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
+    assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph)
+    assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
+    assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph)
+
+    # e^x / e^y / e^z = (e^x / e^y) / e^z = e^(x-y-z)
+    op = expx / expy / expz
+    f = function([x, y, z], op, mode)
+    utt.assert_allclose(f(3, 4, 5), np.exp(3 - 4 - 5))
+    graph = f.maker.fgraph.toposort()
+    assert all(isinstance(n.op, Elemwise) for n in graph)
+    assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph)
+    assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph)
+
+    # e^x * y * e^z * w = e^(x+z) * y * w
+    op = expx * y * expz * w
+    f = function([x, y, z, w], op, mode)
+    utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 5) * 4 * 6)
+    graph = f.maker.fgraph.toposort()
+    assert all(isinstance(n.op, Elemwise) for n in graph)
+    assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
+    assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
+
+    # Expect the same for matrices as well
+    mx = matrix("mx")
+    my = matrix("my")
+    f = function([mx, my], exp(mx) * exp(my), mode, allow_input_downcast=True)
+    M1 = np.array([[1.0, 2.0], [3.0, 4.0]])
+    M2 = np.array([[5.0, 6.0], [7.0, 8.0]])
+    utt.assert_allclose(f(M1, M2), np.exp(M1 + M2))
+    graph = f.maker.fgraph.toposort()
+    assert all(isinstance(n.op, Elemwise) for n in graph)
+    assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
+    assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
+
+    # Check that further rewrites can proceed after this one, as one would expect
+    # e^x * e^(-x) = e^(x-x) = e^0 = 1
+    f = function([x], expx * exp(neg(x)), mode)
+    utt.assert_allclose(f(42), 1)
+    graph = f.maker.fgraph.toposort()
+    assert isinstance(graph[0].inputs[0], TensorConstant)
+
+    # e^x / e^x = e^(x-x) = e^0 = 1
+    f = function([x], expx / expx, mode)
+    utt.assert_allclose(f(42), 1)
+    graph = f.maker.fgraph.toposort()
+    assert isinstance(graph[0].inputs[0], TensorConstant)
+
+
+def test_local_mul_pow_to_pow_add():
+    # Default and FAST_RUN modes put a Composite op into the final graph,
+    # whereas FAST_COMPILE doesn't. To keep the graphs the test cases analyze
+    # consistent across modes, we avoid Composite insertion by excluding fusion rewrites.
+    mode = (
+        get_default_mode()
+        .excluding("fusion")
+        .including("local_mul_exp_to_exp_add")
+        .including("local_mul_pow_to_pow_add")
+    )
+
+    x = scalar("x")
+    y = scalar("y")
+    z = scalar("z")
+    w = scalar("w")
+    v = scalar("v")
+    u = scalar("u")
+    t = scalar("t")
+    s = scalar("s")
+    a = scalar("a")
+    b = scalar("b")
+    c = scalar("c")
+
+    # 2^x * 2^y * 2^z * 2^w = 2^(x+y+z+w)
+    op = 2**x * 2**y * 2**z * 2**w
+    f = function([x, y, z, w], op, mode)
+    utt.assert_allclose(f(3, 4, 5, 6), 2 ** (3 + 4 + 5 + 6))
+    graph = f.maker.fgraph.toposort()
+    assert all(isinstance(n.op, Elemwise) for n in graph)
+    assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
+    assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
+
+    # 2^x * a^y * 2^z * b^w * c^v * a^u * s * b^t = 2^(x+z) * a^(y+u) * b^(w+t) * c^v * s
+    op = 2**x * a**y * 2**z * b**w * c**v * a**u * s * b**t
+    f = function([x, y, z, w, v, u, t, s, a, b, c], op, mode)
+    utt.assert_allclose(
+        f(4, 5, 6, 7, 8, 9, 10, 11, 2.5, 3, 3.5),
+        2 ** (4 + 6) * 2.5 ** (5 + 9) * 3 ** (7 + 10) * 3.5**8 * 11,
+    )
+    graph = f.maker.fgraph.toposort()
+    assert all(isinstance(n.op, Elemwise) for n in graph)
+    assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Add)]) == 3
+    assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Pow)]) == 4
+    assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
+
+    # (2^x / 2^y) * (a^z / a^w) = 2^(x-y) * a^(z-w)
+    op = 2**x / 2**y * (a**z / a**w)
+    f = function([x, y, z, w, a], op, mode)
+    utt.assert_allclose(f(3, 5, 6, 4, 7), 2 ** (3 - 5) * 7 ** (6 - 4))
+    graph = f.maker.fgraph.toposort()
+    assert all(isinstance(n.op, Elemwise) for n in graph)
+    assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Sub)]) == 2
+    assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
+
+    # a^x * a^y * exp(z) * exp(w) = a^(x+y) * exp(z+w)
+    op = a**x * a**y * exp(z) * exp(w)
+    f = function([x, y, z, w, a], op, mode)
+    utt.assert_allclose(f(3, 4, 5, 6, 2), 2 ** (3 + 4) * np.exp(5 + 6))
+    graph = f.maker.fgraph.toposort()
+    assert all(isinstance(n.op, Elemwise) for n in graph)
+    assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Add)]) == 2
+    assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
+
+
 def test_local_expm1():
     x = matrix("x")
     u = scalar("u")
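
Usage sketch (not part of the diff): a minimal example of how one might exercise the new rewrites from user code with this branch installed. The variable names and input values are illustrative only; the expected graph shapes are the ones the tests above assert, and in default mode the rewritten ops may end up fused into a Composite.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.scalar("x")
    y = pt.scalar("y")
    a = pt.scalar("a")

    # e^x * e^y should compile down to e^(x+y) via local_mul_exp_to_exp_add
    f = pytensor.function([x, y], pt.exp(x) * pt.exp(y))
    pytensor.dprint(f)  # expect a single Exp fed by an Add (possibly inside a Composite)
    assert np.isclose(f(2.0, 3.0), np.exp(2.0 + 3.0))

    # a^x / a^y with a shared base should compile down to a^(x-y) via local_mul_pow_to_pow_add
    g = pytensor.function([a, x, y], a**x / a**y)
    assert np.isclose(g(2.0, 5.0, 3.0), 2.0 ** (5.0 - 3.0))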