- from pytensor import scalar as aes
  from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
- from pytensor.tensor.elemwise import DimShuffle, Elemwise
- from pytensor.tensor.math import Sum, exp
+ from pytensor.tensor.elemwise import DimShuffle
+ from pytensor.tensor.math import Sum, exp, log
  from pytensor.tensor.math import sum as at_sum
  from pytensor.tensor.math import true_div
- from pytensor.tensor.rewriting.basic import register_specialize
+ from pytensor.tensor.rewriting.basic import register_stabilize
  from pytensor.tensor.rewriting.math import local_mul_canonizer
- from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
- from pytensor.tensor.subtensor import AdvancedIncSubtensor
+ from pytensor.tensor.special import Softmax, SoftmaxGrad, log_softmax
+ from pytensor.tensor.subtensor import (
+     AdvancedIncSubtensor,
+     AdvancedSubtensor,
+     AdvancedSubtensor1,
+     Subtensor,
+ )
  from pytensor.tensor.type import (
      values_eq_approx_remove_inf,
      values_eq_approx_remove_nan,
  )


- # This is not registered in stabilize, as it cause some crossentropy
- # optimization to not be inserted.
- @register_specialize("stabilize", "fast_compile")
- @node_rewriter([Elemwise])
+ subtensor_ops = (
+     Subtensor,
+     AdvancedSubtensor,
+     AdvancedSubtensor1,
+ )
+
+
+ @register_stabilize
+ @node_rewriter([log])
  def local_logsoftmax(fgraph, node):
      """
      Detect Log(Softmax(x)) and replace it with LogSoftmax(x)

+     This also lifts Subtensor or Dimshuffle operations that could be in between log and softmax
+
      Note: only forward pass is affected
      """
-     if (
-         isinstance(node.op, Elemwise)
-         and isinstance(node.op.scalar_op, aes.Log)
-         and len(node.inputs) == 1
-         and node.inputs[0].owner is not None
-         and isinstance(node.inputs[0].owner.op, Softmax)
-     ):
-         inVars = node.inputs[0].owner.inputs[0]
-         new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
-         ret = new_op(inVars)
-         ret.tag.values_eq_approx = values_eq_approx_remove_inf
-         copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
-         return [ret]
+
+     def find_softmax_under_lifteable_ops(inp_node, ops_to_lift):
+         if inp_node is None:
+             return
+
+         if isinstance(inp_node.op, Softmax):
+             return inp_node
+
+         if isinstance(inp_node.op, subtensor_ops):
+             ops_to_lift.append((inp_node.op, inp_node.inputs[1:]))
+             return find_softmax_under_lifteable_ops(
+                 inp_node.inputs[0].owner, ops_to_lift
+             )
+
+         if isinstance(inp_node.op, DimShuffle):
+             ops_to_lift.append((inp_node.op, ()))
+             return find_softmax_under_lifteable_ops(
+                 inp_node.inputs[0].owner, ops_to_lift
+             )
+
+     ops_to_lift = []
+     softmax_node = find_softmax_under_lifteable_ops(node.inputs[0].owner, ops_to_lift)
+
+     if softmax_node is None:
+         return
+
+     ret = log_softmax(softmax_node.inputs[0], axis=softmax_node.op.axis)
+     ret.tag.values_eq_approx = values_eq_approx_remove_inf
+
+     # Lift ops that used to be between log and softmax
+     for op_to_lift, parameters in reversed(ops_to_lift):
+         ret = op_to_lift(ret, *parameters)
+
+     copy_stack_trace(node.outputs, ret)
+     return [ret]


- # This is not registered in stabilize, as it cause some crossentropy
- # optimization to not be inserted.
- @register_specialize("stabilize", "fast_compile")
+ @register_stabilize
  @node_rewriter([SoftmaxGrad])
  def local_logsoftmax_grad(fgraph, node):
      """
@@ -50,9 +81,7 @@ def local_logsoftmax_grad(fgraph, node):
      Note: only grad is affected
      """
      if (
-         isinstance(node.op, SoftmaxGrad)
-         and len(node.inputs) == 2
-         and node.inputs[0].owner is not None
+         node.inputs[0].owner is not None
          and node.inputs[0].owner.op == true_div
          and len(node.inputs[0].owner.inputs) >= 2
          and node.inputs[0].owner.inputs[1].owner is not None
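
For context, a minimal usage sketch (not part of the PR) of the kind of graph the rewritten local_logsoftmax targets. It assumes a standard PyTensor installation and uses only public API (pytensor.function, pytensor.tensor.log, pytensor.tensor.special.softmax, pytensor.printing.debugprint). Because the rewrite is now registered with register_stabilize, it should be applied during normal compilation, and the Subtensor created by the indexing below should be lifted above a LogSoftmax node:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.printing import debugprint
    from pytensor.tensor.special import softmax

    x = pt.matrix("x")
    # log(softmax(x)[0]): the indexing sits between log and softmax,
    # which the rewrite should now lift out of the way
    y = pt.log(softmax(x, axis=-1)[0])

    f = pytensor.function([x], y)
    debugprint(f)  # the compiled graph should show LogSoftmax rather than Log(Softmax)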