66# can be found in the PATENTS file in the same directory.
77
88import math
9+
910import torch
1011import torch.nn as nn
1112import torch.nn.functional as F
@@ -121,6 +122,8 @@ def __init__(self, input_size, kernel_size=1, padding_l=None, num_heads=1,
121122
122123 self.reset_parameters()
123124
125+ self.onnx_trace = False
126+
124127 def reset_parameters(self):
125128 nn.init.xavier_uniform_(self.weight)
126129 if self.bias is not None:
@@ -144,6 +147,9 @@ def forward(self, x, incremental_state=None, unfold=False):
144147 output = output + self.bias.view(1, 1, -1)
145148 return output
146149
150+ def prepare_for_onnx_export_(self):
151+ self.onnx_trace = True
152+
147153 def _forward_unfolded(self, x, incremental_state):
148154 '''The conventional implementation of convolutions.
149155 Unfolding the input by having a window shifting to the right.'''
@@ -167,7 +173,7 @@ def _forward_unfolded(self, x, incremental_state):
167173 x_unfold = x_unfold.view(T * B * H, R, K)
168174
169175 if self.weight_softmax:
170- weight = F.softmax(weight.float(), dim=1).type_as(weight)
176+ weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)
171177
172178 if incremental_state is not None:
173179 weight = weight[:, -x_unfold.size(2):]
@@ -192,7 +198,7 @@ def _forward_expanded(self, x, incremental_state):
192198
193199 weight = self.weight.view(H, K)
194200 if self.weight_softmax:
195- weight = F.softmax(weight.float(), dim=1).type_as(weight)
201+ weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)
196202 weight = weight.view(1, H, K).expand(T * B, H, K).contiguous()
197203 weight = weight.view(T, B * H, K).transpose(0, 1)
198204
0 commit comments