@@ -142,7 +142,7 @@ def _fftn_onnx(
142
142
143
143
@torch_op ("aten::_fft_c2c" , trace_only = True , complex = True )
144
144
def aten__fft_c2c (
145
- transformed : TFloat , dim : Sequence [int ], normalization : int , forward : bool
145
+ self : TFloat , dim : Sequence [int ], normalization : int , forward : bool
146
146
) -> TFloat :
147
147
"""_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
148
148
@@ -153,26 +153,27 @@ def aten__fft_c2c(
153
153
154
154
# ONNX DFT input assumes the last dimension is the complex dimension.
155
155
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
156
- self_rank = len (transformed .shape )
157
- signal_size = op .CastLike (op .Size (transformed ), transformed )
156
+ self_rank = len (self .shape )
158
157
159
158
# ONNX DFT input assumes the last dimension is the complex dimension.
160
159
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
161
160
dim = [(d - 1 ) + self_rank if d < 0 else d for d in dim ]
162
161
163
162
unsqueeze_first_dim = 0 in dim
164
163
if unsqueeze_first_dim :
165
- transformed = op .Unsqueeze (transformed , axes = [0 ])
164
+ transformed = op .Unsqueeze (self , axes = [0 ])
166
165
# Add 1 to account for the batch dimension when counting axes from the left
167
166
dim = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dim ]
167
+ else :
168
+ transformed = self
168
169
169
170
for dimension in reversed (dim ):
170
171
transformed = op .DFT (transformed , axis = dimension , inverse = not forward , onesided = False )
171
172
if forward :
172
- transformed = _fftn_onnx_normalization (transformed , normalization , signal_size )
173
+ transformed = _fftn_onnx_normalization (transformed , normalization , op . CastLike ( self . shape [ dimension - unsqueeze_first_dim ], transformed ) )
173
174
else :
174
175
transformed = _fftn_onnx_inverse_normalization (
175
- transformed , normalization , signal_size
176
+ transformed , normalization , op . CastLike ( self . shape [ dimension - unsqueeze_first_dim ], transformed )
176
177
)
177
178
178
179
if unsqueeze_first_dim :
@@ -183,7 +184,7 @@ def aten__fft_c2c(
183
184
184
185
@torch_op ("aten::_fft_c2r" , trace_only = True , complex = True )
185
186
def aten__fft_c2r (
186
- transformed : TFloat ,
187
+ self : TFloat ,
187
188
dim : Sequence [int ],
188
189
normalization : int ,
189
190
last_dim_size : INT64 ,
@@ -208,7 +209,7 @@ def aten__fft_c2r(
208
209
209
210
@torch_op ("aten::_fft_r2c" , trace_only = True )
210
211
def aten__fft_r2c (
211
- transformed : TFloat , dim : Sequence [int ], normalization : int , onesided : bool
212
+ self : TFloat , dim : Sequence [int ], normalization : int , onesided : bool
212
213
) -> TFloat :
213
214
"""_fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
214
215
@@ -218,11 +219,10 @@ def aten__fft_r2c(
218
219
# No need to fill the imaginary part because ONNX DFT accepts real inputs
219
220
# https://onnx.ai/onnx/operators/onnx__DFT.html#inputs
220
221
221
- self_rank = len (transformed .shape )
222
- signal_size = op .CastLike (op .Size (transformed ), transformed )
222
+ self_rank = len (self .shape )
223
223
224
224
# Add a new dimension at the end
225
- transformed = op .Unsqueeze (transformed , axes = [- 1 ])
225
+ transformed = op .Unsqueeze (self , axes = [- 1 ])
226
226
227
227
# ONNX DFT input assumes the last dimension is the complex dimension.
228
228
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
@@ -234,13 +234,13 @@ def aten__fft_r2c(
234
234
# Add 1 to account for the batch dimension when counting axes from the left
235
235
dim = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dim ]
236
236
237
- # Torch computes one-sided FFT on the last dimension only.
238
- transformed = op . DFT ( transformed , axis = dim [ - 1 ], inverse = False , onesided = onesided )
239
- transformed = _fftn_onnx_normalization (transformed , normalization , signal_size )
240
-
241
- for dimension in reversed ( dim [: - 1 ]):
242
- transformed = op .DFT (transformed , axis = dimension , inverse = False , onesided = False )
243
- transformed = _fftn_onnx_normalization (transformed , normalization , signal_size )
237
+ for idx , dimension in enumerate ( reversed ( dim )):
238
+ if idx > 0 :
239
+ transformed = op . DFT (transformed , axis = dimension , inverse = False , onesided = False )
240
+ else :
241
+ # Torch computes one-sided FFT on the last dimension only.
242
+ transformed = op .DFT (transformed , axis = dimension , inverse = False , onesided = onesided )
243
+ transformed = _fftn_onnx_normalization (transformed , normalization , op . CastLike ( self . shape [ dimension - unsqueeze_first_dim ], transformed ) )
244
244
245
245
if unsqueeze_first_dim :
246
246
transformed = op .Squeeze (transformed , axes = [0 ])
0 commit comments