Commit ce0b503

Extend log_softmax rewrite and run it in stabilize
1 parent 39bda72 commit ce0b503

File tree

2 files changed (+91, -53 lines):
  pytensor/tensor/rewriting/special.py
  tests/tensor/rewriting/test_special.py

pytensor/tensor/rewriting/special.py

Lines changed: 58 additions & 29 deletions
@@ -1,47 +1,78 @@
-from pytensor import scalar as aes
 from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
-from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.math import Sum, exp
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.math import Sum, exp, log
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.math import true_div
-from pytensor.tensor.rewriting.basic import register_specialize
+from pytensor.tensor.rewriting.basic import register_stabilize
 from pytensor.tensor.rewriting.math import local_mul_canonizer
-from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
-from pytensor.tensor.subtensor import AdvancedIncSubtensor
+from pytensor.tensor.special import Softmax, SoftmaxGrad, log_softmax
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+    Subtensor,
+)
 from pytensor.tensor.type import (
     values_eq_approx_remove_inf,
     values_eq_approx_remove_nan,
 )
 
 
-# This is not registered in stabilize, as it cause some crossentropy
-# optimization to not be inserted.
-@register_specialize("stabilize", "fast_compile")
-@node_rewriter([Elemwise])
+subtensor_ops = (
+    Subtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+)
+
+
+@register_stabilize
+@node_rewriter([log])
 def local_logsoftmax(fgraph, node):
     """
     Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
 
+    This also lifts Subtensor or Dimshuffle operations that could be in between log and softmax
+
     Note: only forward pass is affected
     """
-    if (
-        isinstance(node.op, Elemwise)
-        and isinstance(node.op.scalar_op, aes.Log)
-        and len(node.inputs) == 1
-        and node.inputs[0].owner is not None
-        and isinstance(node.inputs[0].owner.op, Softmax)
-    ):
-        inVars = node.inputs[0].owner.inputs[0]
-        new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
-        ret = new_op(inVars)
-        ret.tag.values_eq_approx = values_eq_approx_remove_inf
-        copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
-        return [ret]
+
+    def find_softmax_under_lifteable_ops(inp_node, ops_to_lift):
+        if inp_node is None:
+            return
+
+        if isinstance(inp_node.op, Softmax):
+            return inp_node
+
+        if isinstance(inp_node.op, subtensor_ops):
+            ops_to_lift.append((inp_node.op, inp_node.inputs[1:]))
+            return find_softmax_under_lifteable_ops(
+                inp_node.inputs[0].owner, ops_to_lift
+            )
+
+        if isinstance(inp_node.op, DimShuffle):
+            ops_to_lift.append((inp_node.op, ()))
+            return find_softmax_under_lifteable_ops(
+                inp_node.inputs[0].owner, ops_to_lift
+            )
+
+    ops_to_lift = []
+    softmax_node = find_softmax_under_lifteable_ops(node.inputs[0].owner, ops_to_lift)
+
+    if softmax_node is None:
+        return
+
+    ret = log_softmax(softmax_node.inputs[0], axis=softmax_node.op.axis)
+    ret.tag.values_eq_approx = values_eq_approx_remove_inf
+
+    # Lift ops that used to be between log and softmax
+    for op_to_lift, parameters in reversed(ops_to_lift):
+        ret = op_to_lift(ret, *parameters)
+
+    copy_stack_trace(node.outputs, ret)
+    return [ret]
 
 
-# This is not registered in stabilize, as it cause some crossentropy
-# optimization to not be inserted.
-@register_specialize("stabilize", "fast_compile")
+@register_stabilize
 @node_rewriter([SoftmaxGrad])
 def local_logsoftmax_grad(fgraph, node):
     """
@@ -50,9 +81,7 @@ def local_logsoftmax_grad(fgraph, node):
     Note: only grad is affected
     """
     if (
-        isinstance(node.op, SoftmaxGrad)
-        and len(node.inputs) == 2
-        and node.inputs[0].owner is not None
+        node.inputs[0].owner is not None
         and node.inputs[0].owner.op == true_div
        and len(node.inputs[0].owner.inputs) >= 2
        and node.inputs[0].owner.inputs[1].owner is not None
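
Why running this rewrite in stabilize matters numerically: once any softmax probability underflows to zero, its log is -inf, while log_softmax evaluates x - logsumexp(x) directly and stays finite. A minimal NumPy/SciPy sketch of the failure mode, using the same kind of extreme logits as the new test case below (variable names are illustrative only):

import numpy as np
import scipy.special

x = np.array([[-10.0, -10.0, 999.0]])

# Naive log(softmax(x)): even with the usual max-subtraction trick,
# exp(-1009) underflows to 0.0, so the log of those entries is -inf.
e = np.exp(x - x.max(axis=-1, keepdims=True))
naive = np.log(e / e.sum(axis=-1, keepdims=True))  # [[-inf, -inf, 0.]]

# log_softmax computes x - logsumexp(x) and stays finite.
stable = scipy.special.log_softmax(x, axis=-1)  # [[-1009., -1009., 0.]]
print(naive, stable)

This is also why the rewritten output keeps the values_eq_approx_remove_inf tag: value comparisons against the unrewritten graph must tolerate the -inf entries that graph produces.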

tests/tensor/rewriting/test_special.py

Lines changed: 33 additions & 24 deletions
@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+import scipy.special
 
 import pytensor
 from pytensor import shared
@@ -35,6 +36,37 @@ def test_local_logsoftmax_rewrite(self, axis):
         _fast_run_rewrites.rewrite(fgraph)
         assert isinstance(fgraph.outputs[0].owner.op, LogSoftmax)
         assert check_stack_trace(fgraph, ops_to_check=LogSoftmax)
+        assert check_stack_trace(fgraph, ops_to_check="all")
+
+    @pytest.mark.parametrize("axis", [None, 0, -1])
+    @pytest.mark.parametrize("idx0", [0, slice(1, None), slice(None)])
+    @pytest.mark.parametrize("idx1", [None, [0, 1, 1, -1]])
+    def test_logsoftmax_subtensor_dimshuffle(self, axis, idx0, idx1):
+        """Test that stabilization is introduced even when subtensor or dimshuffle operations
+        are present between log and softmax.
+        """
+        logit_p = matrix("logit_p")
+        p = softmax(logit_p, axis=axis)
+        p_indexed = p[(idx0, idx1)]
+        out = log(p_indexed)
+
+        # Don't waste time with C compilation
+        with config.change_flags(cxx=""):
+            mode = get_mode(None).including("stabilize")
+            fn = pytensor.function([logit_p], out, mode=mode)
+
+        assert not any(
+            isinstance(node.op, Softmax) for node in fn.maker.fgraph.apply_nodes
+        )
+
+        # This range would lead to underflow to -inf without the stabilization
+        test_logit_p = np.array(
+            [[-10.0, -10.0, 999.0], [999.0, 990.0, -10.0]], dtype=config.floatX
+        )
+        np.testing.assert_allclose(
+            fn(logit_p=test_logit_p),
+            scipy.special.log_softmax(test_logit_p, axis=axis)[(idx0, idx1)],
+        )
 
     @pytest.mark.parametrize("axis", [None, 0, -1])
     def test_local_logsoftmax_grad_rewrite(self, axis):
@@ -46,7 +78,7 @@ def test_local_logsoftmax_grad_rewrite(self, axis):
         """
 
         m = config.mode
-        m = get_mode(m)
+        m = get_mode(m).including("stabilize")
         m.check_isfinite = False
         # some inputs that are large to make the gradient explode in the non
         # rewritten case
@@ -91,29 +123,6 @@ def test_logsoftmax_grad_true_div_elemwise(self):
         assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]
 
 
-def test_log_softmax_stabilization():
-    mode = pytensor.compile.mode.get_default_mode()
-    mode = mode.including("local_log_softmax", "specialize")
-
-    x = matrix()
-    y = softmax(x, axis=-1)
-    z = log(y)
-
-    fgraph = FunctionGraph([x], [z])
-    _fast_run_rewrites(fgraph)
-    assert check_stack_trace(fgraph, ops_to_check="all")
-
-    # Check that the softmax has been rewritten
-    for node in fgraph.toposort():
-        assert not isinstance(node.op, Softmax)
-
-    # Call the function so debug mode can verify the rewritten version matches
-    # the un-rewritten version
-    f = pytensor.function([x], z, mode=mode)
-    rng = np.random.default_rng(utt.fetch_seed())
-    f(np.cast[config.floatX](rng.random((2, 3))))
-
-
 def test_softmax_graph():
     """Make sure that sotfmax expressions are turned into
     a softmax Op.
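
The committed test exercises the Subtensor/AdvancedSubtensor path; the rewrite's docstring also claims to lift DimShuffle. A hedged sketch of that case, mirroring the test setup above (the transpose and the variable names are illustrative assumptions, not part of the committed tests):

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_mode
from pytensor.tensor.special import Softmax, softmax

x = pt.matrix("x")
# A DimShuffle (here a transpose) sits between softmax and log; based on the
# rewrite above it should be lifted on top of the new log_softmax node.
out = pt.log(softmax(x, axis=-1).T)

with pytensor.config.change_flags(cxx=""):  # skip C compilation, as in the test
    fn = pytensor.function([x], out, mode=get_mode(None).including("stabilize"))

# Expected: no plain Softmax node survives in the rewritten graph.
print(any(isinstance(node.op, Softmax) for node in fn.maker.fgraph.apply_nodes))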
