Skip to content

Commit 39afc8d

Browse files
committed
2 functions pass tests
1 parent 17c667d commit 39afc8d

File tree

1 file changed

+18
-18
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+18
-18
lines changed

onnxscript/function_libs/torch_lib/ops/fft.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def _fftn_onnx(
142142

143143
@torch_op("aten::_fft_c2c", trace_only=True, complex=True)
144144
def aten__fft_c2c(
145-
transformed: TFloat, dim: Sequence[int], normalization: int, forward: bool
145+
self: TFloat, dim: Sequence[int], normalization: int, forward: bool
146146
) -> TFloat:
147147
"""_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
148148
@@ -153,26 +153,27 @@ def aten__fft_c2c(
153153

154154
# ONNX DFT input assumes the last dimension is the complex dimension.
155155
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
156-
self_rank = len(transformed.shape)
157-
signal_size = op.CastLike(op.Size(transformed), transformed)
156+
self_rank = len(self.shape)
158157

159158
# ONNX DFT input assumes the last dimension is the complex dimension.
160159
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
161160
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
162161

163162
unsqueeze_first_dim = 0 in dim
164163
if unsqueeze_first_dim:
165-
transformed = op.Unsqueeze(transformed, axes=[0])
164+
transformed = op.Unsqueeze(self, axes=[0])
166165
# Add 1 to account for the batch dimension when counting axes from the left
167166
dim = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dim]
167+
else:
168+
transformed = self
168169

169170
for dimension in reversed(dim):
170171
transformed = op.DFT(transformed, axis=dimension, inverse=not forward, onesided=False)
171172
if forward:
172-
transformed = _fftn_onnx_normalization(transformed, normalization, signal_size)
173+
transformed = _fftn_onnx_normalization(transformed, normalization, op.CastLike(self.shape[dimension - unsqueeze_first_dim], transformed))
173174
else:
174175
transformed = _fftn_onnx_inverse_normalization(
175-
transformed, normalization, signal_size
176+
transformed, normalization, op.CastLike(self.shape[dimension - unsqueeze_first_dim], transformed)
176177
)
177178

178179
if unsqueeze_first_dim:
@@ -183,7 +184,7 @@ def aten__fft_c2c(
183184

184185
@torch_op("aten::_fft_c2r", trace_only=True, complex=True)
185186
def aten__fft_c2r(
186-
transformed: TFloat,
187+
self: TFloat,
187188
dim: Sequence[int],
188189
normalization: int,
189190
last_dim_size: INT64,
@@ -208,7 +209,7 @@ def aten__fft_c2r(
208209

209210
@torch_op("aten::_fft_r2c", trace_only=True)
210211
def aten__fft_r2c(
211-
transformed: TFloat, dim: Sequence[int], normalization: int, onesided: bool
212+
self: TFloat, dim: Sequence[int], normalization: int, onesided: bool
212213
) -> TFloat:
213214
"""_fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
214215
@@ -218,11 +219,10 @@ def aten__fft_r2c(
218219
# No need to fill the imaginary part because ONNX DFT accepts real inputs
219220
# https://onnx.ai/onnx/operators/onnx__DFT.html#inputs
220221

221-
self_rank = len(transformed.shape)
222-
signal_size = op.CastLike(op.Size(transformed), transformed)
222+
self_rank = len(self.shape)
223223

224224
# Add a new dimension at the end
225-
transformed = op.Unsqueeze(transformed, axes=[-1])
225+
transformed = op.Unsqueeze(self, axes=[-1])
226226

227227
# ONNX DFT input assumes the last dimension is the complex dimension.
228228
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
@@ -234,13 +234,13 @@ def aten__fft_r2c(
234234
# Add 1 to account for the batch dimension when counting axes from the left
235235
dim = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dim]
236236

237-
# Torch computes one-sided FFT on the last dimension only.
238-
transformed = op.DFT(transformed, axis=dim[-1], inverse=False, onesided=onesided)
239-
transformed = _fftn_onnx_normalization(transformed, normalization, signal_size)
240-
241-
for dimension in reversed(dim[:-1]):
242-
transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=False)
243-
transformed = _fftn_onnx_normalization(transformed, normalization, signal_size)
237+
for idx, dimension in enumerate(reversed(dim)):
238+
if idx > 0:
239+
transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=False)
240+
else:
241+
# Torch computes one-sided FFT on the last dimension only.
242+
transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=onesided)
243+
transformed = _fftn_onnx_normalization(transformed, normalization, op.CastLike(self.shape[dimension - unsqueeze_first_dim], transformed))
244244

245245
if unsqueeze_first_dim:
246246
transformed = op.Squeeze(transformed, axes=[0])

0 commit comments

Comments
 (0)