Rewrite products of exponents as exponent of sum #186

Merged
merged 8 commits into from
Feb 5, 2023
95 changes: 95 additions & 0 deletions pytensor/tensor/rewriting/math.py
@@ -2,6 +2,7 @@

import itertools
import operator
from collections import defaultdict
from functools import partial, reduce

import numpy as np
@@ -423,6 +424,100 @@ def local_sumsqr2dot(fgraph, node):
    return [new_out]


@register_specialize
Member

Rethinking this, it may make sense to register these new rewrites in canonicalize as well.

Contributor Author

Hi, after looking at how this is resolved elsewhere and experimenting with different options, I ended up adding the following line at the beginning of the test function, which seems to have resolved the issue in my local runs both with and without the FAST_COMPILE flag.

mode = get_mode("FAST_RUN") if config.mode == "FAST_COMPILE" else get_default_mode()

(And, of course, passing the mode parameter to all function() calls.)

If you're OK with this, on which branch do you want me to commit the change? The original one, or the one you rebased earlier?

Also, do you want me to add back the @register_canonicalize decorators?

Thanks!

Member

@ricardoV94 ricardoV94 Jan 31, 2023

I think a cleaner solution is to use mode=get_default_mode().excluding("fusion")

Then the tests will also be more straightforward, because you don't need to check the graph inside the Composites.

And yes, let's try adding register_canonicalize to these rewrites and see whether it breaks anything.

Contributor Author

@tamastokes tamastokes Jan 31, 2023

An interesting side effect of making the rewrites canonical is that one of the test cases in test_math.py::TestSigmoidRewrites::test_local_sigm_times_exp breaks, because my rewrites now take precedence over the sigma_times_exp rewrites. More specifically,
sigma(x) * sigma(-y) * (-exp(-x)) * exp(xy) * exp(y)
is supposed to become
(-1) * sigma(-x) * sigma(y) * exp(xy)
but my rewrites contract the exp-s before the sigma*exp rewrite can take action, and the result becomes
(-1) * exp(xy+y-x) * sigma(x) * sigma(-y)

So I guess this is a trade-off. How should we resolve it? (Having run all tests, this is the only one the canonicalisation broke.)
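
The three expressions in the comment above are algebraically equal, so the test failure concerns the shape of the final graph rather than its value. A minimal numeric check in plain Python, assuming that sigma is the logistic sigmoid and that exp(xy) means exp(x*y); the function names are hypothetical and nothing here uses PyTensor:

```python
import math


def sigma(t):
    """Logistic sigmoid, 1 / (1 + e^-t)."""
    return 1.0 / (1.0 + math.exp(-t))


def original(x, y):
    # The expression built by the existing test case.
    return sigma(x) * sigma(-y) * (-math.exp(-x)) * math.exp(x * y) * math.exp(y)


def expected_by_test(x, y):
    # What the sigmoid-times-exp rewrite is expected to produce.
    return -1 * sigma(-x) * sigma(y) * math.exp(x * y)


def after_exp_contraction(x, y):
    # What results when the exp factors are merged first.
    return -1 * math.exp(x * y + y - x) * sigma(x) * sigma(-y)


x, y = 0.3, -1.2
values = [f(x, y) for f in (original, expected_by_test, after_exp_contraction)]
print(values)
```

The equality rests on the identity sigma(t) * exp(-t) = sigma(-t), applied once with t = x and once with t = -y.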

Member

@ricardoV94 ricardoV94 Feb 1, 2023

We could maybe extend the sigma_times_exp rewrite to check if the redundant term is somewhere inside the exponentiation (instead of being all there is)?

exp(y - x) * sigmoid(x) -> exp(y) * sigmoid(-x)

Does that make sense? If canonicalize doesn't already do it, we might need to represent the contents of the exponent as a flat series of additions (with negated terms instead of subtraction) so that we can match them more easily.

Might have other ramifications I am not seeing.
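
One way to picture the flat representation suggested above is a list of signed terms. The sketch below is a toy in plain Python, not PyTensor code: expressions are hypothetical nested tuples such as ("sub", a, b), and `flatten_signed_terms` is an illustrative helper name.

```python
def flatten_signed_terms(expr, sign=1):
    """Flatten nested add/sub/neg tuples into a list of (sign, leaf) pairs.

    Any other tuple (for example ("mul", "x", "y")) is treated as an opaque leaf.
    """
    if isinstance(expr, tuple):
        op, *args = expr
        if op == "add":
            return [t for a in args for t in flatten_signed_terms(a, sign)]
        if op == "sub":
            left, right = args
            return flatten_signed_terms(left, sign) + flatten_signed_terms(right, -sign)
        if op == "neg":
            (arg,) = args
            return flatten_signed_terms(arg, -sign)
    return [(sign, expr)]


# Exponent x*y + y - x, written as ((x*y + y) - x)
exponent = ("sub", ("add", ("mul", "x", "y"), "y"), "x")
terms = flatten_signed_terms(exponent)
print(terms)
```

With the exponent in this form, checking whether a sigmoid(x) factor has a matching -x term reduces to a membership test such as `(-1, "x") in terms`.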

Member

@ricardoV94 ricardoV94 Feb 1, 2023

Looking at the rewrite, such a change does not seem trivial, so let's leave it as is: don't canonicalize.

We can open a new issue to investigate if it's worth it.

Let's just change the mode thing in the tests.

Contributor Author

I'm wondering how frequent / marginal such use cases are. Are there any kind of statistics available on how often real world use cases trigger which rewrites?

Meanwhile, I've just committed the changes for the tests, hope they'll be fine this time.

Thanks!

Member

@ricardoV94 ricardoV94 Feb 1, 2023

Unfortunately, we have no telemetry to study the types of graphs users create :D

@node_rewriter([mul, true_div])
def local_mul_exp_to_exp_add(fgraph, node):
    """
    This rewrite detects e^x * e^y and converts it to e^(x+y).
    Similarly, e^x / e^y becomes e^(x-y).
    """
    exps = [
        n.owner.inputs[0]
        for n in node.inputs
        if n.owner
        and hasattr(n.owner.op, "scalar_op")
        and isinstance(n.owner.op.scalar_op, aes.Exp)
    ]
    # Can only do any rewrite if there are at least two exp-s
    if len(exps) >= 2:
        # Mul -> add; TrueDiv -> sub
        orig_op, new_op = mul, add
        if isinstance(node.op.scalar_op, aes.TrueDiv):
            orig_op, new_op = true_div, sub
        new_out = exp(new_op(*exps))
        if new_out.dtype != node.outputs[0].dtype:
            new_out = cast(new_out, dtype=node.outputs[0].dtype)
        # The original Mul may have more than two factors, some of which may not be exp nodes.
        # If so, we keep multiplying them with the new exp(sum) node.
        # E.g.: e^x * y * e^z * w --> e^(x+z) * y * w
        rest = [
            n
            for n in node.inputs
            if not n.owner
            or not hasattr(n.owner.op, "scalar_op")
            or not isinstance(n.owner.op.scalar_op, aes.Exp)
        ]
        if len(rest) > 0:
            new_out = orig_op(new_out, *rest)
            if new_out.dtype != node.outputs[0].dtype:
                new_out = cast(new_out, dtype=node.outputs[0].dtype)
        return [new_out]
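
The partition-and-merge idea in the rewrite above can be illustrated without PyTensor. The following is a toy sketch only: the factors of a product are a plain list, an exp factor is a hypothetical ("exp", arg) tuple, and `contract_exps` is an illustrative name. It covers the multiplication case and ignores dtypes and division.

```python
def is_exp(factor):
    """True for toy factors of the form ("exp", arg)."""
    return isinstance(factor, tuple) and len(factor) == 2 and factor[0] == "exp"


def contract_exps(factors):
    """Merge all exp factors of a product into a single exp of a sum."""
    exp_args = [f[1] for f in factors if is_exp(f)]
    # As in the rewrite above, do nothing unless there are at least two exp-s.
    if len(exp_args) < 2:
        return list(factors)
    rest = [f for f in factors if not is_exp(f)]
    return [("exp", ("add", *exp_args)), *rest]


# e^x * y * e^z * w  -->  e^(x+z) * y * w
print(contract_exps([("exp", "x"), "y", ("exp", "z"), "w"]))
```

The merged exp factor is placed first and the remaining factors keep their relative order, mirroring `orig_op(new_out, *rest)` in the rewrite.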


@register_specialize
@node_rewriter([mul, true_div])
def local_mul_pow_to_pow_add(fgraph, node):
    """
    This rewrite detects a^x * a^y and converts it to a^(x+y).
    Similarly, a^x / a^y becomes a^(x-y).
    """
    # search for pow-s and group them by their bases
    pow_nodes = defaultdict(list)
    rest = []
    for n in node.inputs:
        if (
            n.owner
            and hasattr(n.owner.op, "scalar_op")
            and isinstance(n.owner.op.scalar_op, aes.Pow)
        ):
            base_node = n.owner.inputs[0]
            # exponent is at n.owner.inputs[1], but we need to store the full node
            # in case this particular power node remains alone and can't be rewritten
            pow_nodes[base_node].append(n)
        else:
            rest.append(n)

    # Can only do any rewrite if there are at least two pow-s with the same base
    can_rewrite = [k for k, v in pow_nodes.items() if len(v) >= 2]
    if len(can_rewrite) >= 1:
        # Mul -> add; TrueDiv -> sub
        orig_op, new_op = mul, add
        if isinstance(node.op.scalar_op, aes.TrueDiv):
            orig_op, new_op = true_div, sub
        pow_factors = []
        # Rewrite pow-s having the same base for each different base
        # E.g.: a^x * a^y --> a^(x+y)
        for base in can_rewrite:
            exponents = [n.owner.inputs[1] for n in pow_nodes[base]]
            new_node = base ** new_op(*exponents)
            if new_node.dtype != node.outputs[0].dtype:
                new_node = cast(new_node, dtype=node.outputs[0].dtype)
            pow_factors.append(new_node)
        # Don't forget about those sole pow-s that couldn't be rewritten
        sole_pows = [v[0] for k, v in pow_nodes.items() if k not in can_rewrite]
        # Combine the rewritten pow-s and other, non-pow factors of the original Mul
        # E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+w) * b^(z+t) * y * v
        if len(pow_factors) > 1 or len(sole_pows) > 0 or len(rest) > 0:
            new_out = orig_op(*pow_factors, *sole_pows, *rest)
            if new_out.dtype != node.outputs[0].dtype:
                new_out = cast(new_out, dtype=node.outputs[0].dtype)
        else:
            # if all factors of the original mul were pow-s with the same base,
            # we can get rid of the mul completely.
            new_out = pow_factors[0]
        return [new_out]
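
Grouping powers by base is the core step of the rewrite above. The following toy sketch shows it in plain Python, not PyTensor: a power factor is a hypothetical ("pow", base, exponent) tuple and `contract_pows` is an illustrative name. Unlike the rewrite, it keeps bases in first-seen order rather than placing merged powers before sole ones, and it handles multiplication only.

```python
from collections import defaultdict


def contract_pows(factors):
    """Group ("pow", base, exponent) factors by base and add exponents of repeated bases."""
    exponents_by_base = defaultdict(list)
    rest = []
    for f in factors:
        if isinstance(f, tuple) and len(f) == 3 and f[0] == "pow":
            _, base, exponent = f
            exponents_by_base[base].append(exponent)
        else:
            rest.append(f)

    merged = []
    for base, exponents in exponents_by_base.items():
        if len(exponents) >= 2:
            # Repeated base: a^x * a^w --> a^(x+w)
            merged.append(("pow", base, ("add", *exponents)))
        else:
            # A sole power is kept unchanged.
            merged.append(("pow", base, exponents[0]))
    return merged + rest


# a^x * y * b^z * a^w * v * b^t  -->  a^(x+w) * b^(z+t) * y * v
product = [("pow", "a", "x"), "y", ("pow", "b", "z"), ("pow", "a", "w"), "v", ("pow", "b", "t")]
print(contract_pows(product))
```

A `defaultdict(list)` keyed by base plays the same role here as `pow_nodes` in the rewrite: each base collects everything raised to it, and only bases with two or more entries are merged.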


@register_stabilize
@register_specialize
@register_canonicalize
155 changes: 155 additions & 0 deletions tests/tensor/rewriting/test_math.py
@@ -4014,6 +4014,161 @@ def test_local_sumsqr2dot():
    )


def test_local_mul_exp_to_exp_add():
    # Default and FAST_RUN modes put a Composite op into the final graph,
    # whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs,
    # we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites
    mode = get_default_mode().excluding("fusion").including("local_mul_exp_to_exp_add")

    x = scalar("x")
    y = scalar("y")
    z = scalar("z")
    w = scalar("w")
    expx = exp(x)
    expy = exp(y)
    expz = exp(z)
    expw = exp(w)

    # e^x * e^y * e^z * e^w = e^(x+y+z+w)
    op = expx * expy * expz * expw
    f = function([x, y, z, w], op, mode)
    pytensor.dprint(f)
    utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 + 6))
    graph = f.maker.fgraph.toposort()
    assert all(isinstance(n.op, Elemwise) for n in graph)
    assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
    assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)

    # e^x * e^y * e^z / e^w = e^(x+y+z-w)
    op = expx * expy * expz / expw
    f = function([x, y, z, w], op, mode)
    utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 - 6))
    graph = f.maker.fgraph.toposort()
    assert all(isinstance(n.op, Elemwise) for n in graph)
    assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
    assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph)
    assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
    assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph)

    # e^x * e^y / e^z * e^w = e^(x+y-z+w)
    op = expx * expy / expz * expw
    f = function([x, y, z, w], op, mode)
    utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 - 5 + 6))
    graph = f.maker.fgraph.toposort()
    assert all(isinstance(n.op, Elemwise) for n in graph)
    assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
    assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph)
    assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
    assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph)

    # e^x / e^y / e^z = (e^x / e^y) / e^z = e^(x-y-z)
    op = expx / expy / expz
    f = function([x, y, z], op, mode)
    utt.assert_allclose(f(3, 4, 5), np.exp(3 - 4 - 5))
    graph = f.maker.fgraph.toposort()
    assert all(isinstance(n.op, Elemwise) for n in graph)
    assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph)
    assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph)

    # e^x * y * e^z * w = e^(x+z) * y * w
    op = expx * y * expz * w
    f = function([x, y, z, w], op, mode)
    utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 5) * 4 * 6)
    graph = f.maker.fgraph.toposort()
    assert all(isinstance(n.op, Elemwise) for n in graph)
    assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
    assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)

    # expect same for matrices as well
    mx = matrix("mx")
    my = matrix("my")
    f = function([mx, my], exp(mx) * exp(my), mode, allow_input_downcast=True)
    M1 = np.array([[1.0, 2.0], [3.0, 4.0]])
    M2 = np.array([[5.0, 6.0], [7.0, 8.0]])
    utt.assert_allclose(f(M1, M2), np.exp(M1 + M2))
    graph = f.maker.fgraph.toposort()
    assert all(isinstance(n.op, Elemwise) for n in graph)
    assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
    assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)

    # checking whether further rewrites can proceed after this one as one would expect
    # e^x * e^(-x) = e^(x-x) = e^0 = 1
    f = function([x], expx * exp(neg(x)), mode)
    utt.assert_allclose(f(42), 1)
    graph = f.maker.fgraph.toposort()
    assert isinstance(graph[0].inputs[0], TensorConstant)

    # e^x / e^x = e^(x-x) = e^0 = 1
    f = function([x], expx / expx, mode)
    utt.assert_allclose(f(42), 1)
    graph = f.maker.fgraph.toposort()
    assert isinstance(graph[0].inputs[0], TensorConstant)


def test_local_mul_pow_to_pow_add():
    # Default and FAST_RUN modes put a Composite op into the final graph,
    # whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs,
    # we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites
    mode = (
        get_default_mode()
        .excluding("fusion")
        .including("local_mul_exp_to_exp_add")
        .including("local_mul_pow_to_pow_add")
    )

    x = scalar("x")
    y = scalar("y")
    z = scalar("z")
    w = scalar("w")
    v = scalar("v")
    u = scalar("u")
    t = scalar("t")
    s = scalar("s")
    a = scalar("a")
    b = scalar("b")
    c = scalar("c")

    # 2^x * 2^y * 2^z * 2^w = 2^(x+y+z+w)
    op = 2**x * 2**y * 2**z * 2**w
    f = function([x, y, z, w], op, mode)
    utt.assert_allclose(f(3, 4, 5, 6), 2 ** (3 + 4 + 5 + 6))
    graph = f.maker.fgraph.toposort()
    assert all(isinstance(n.op, Elemwise) for n in graph)
    assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
    assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)

    # 2^x * a^y * 2^z * b^w * c^v * a^u * s * b^t = 2^(x+z) * a^(y+u) * b^(w+t) * c^v * s
    op = 2**x * a**y * 2**z * b**w * c**v * a**u * s * b**t
    f = function([x, y, z, w, v, u, t, s, a, b, c], op, mode)
    utt.assert_allclose(
        f(4, 5, 6, 7, 8, 9, 10, 11, 2.5, 3, 3.5),
        2 ** (4 + 6) * 2.5 ** (5 + 9) * 3 ** (7 + 10) * 3.5**8 * 11,
    )
    graph = f.maker.fgraph.toposort()
    assert all(isinstance(n.op, Elemwise) for n in graph)
    assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Add)]) == 3
    assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Pow)]) == 4
    assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)

    # (2^x / 2^y) * (a^z / a^w) = 2^(x-y) * a^(z-w)
    op = 2**x / 2**y * (a**z / a**w)
    f = function([x, y, z, w, a], op, mode)
    utt.assert_allclose(f(3, 5, 6, 4, 7), 2 ** (3 - 5) * 7 ** (6 - 4))
    graph = f.maker.fgraph.toposort()
    assert all(isinstance(n.op, Elemwise) for n in graph)
    assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Sub)]) == 2
    assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)

    # a^x * a^y * exp(z) * exp(w) = a^(x+y) * exp(z+w)
    op = a**x * a**y * exp(z) * exp(w)
    f = function([x, y, z, w, a], op, mode)
    utt.assert_allclose(f(3, 4, 5, 6, 2), 2 ** (3 + 4) * np.exp(5 + 6))
    graph = f.maker.fgraph.toposort()
    assert all(isinstance(n.op, Elemwise) for n in graph)
    assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Add)]) == 2
    assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)


def test_local_expm1():
    x = matrix("x")
    u = scalar("u")