Skip to content

Commit 397baa1

Browse files
Implement fft torchop (#2141)
WIP - Implement aten__fft_r2c, aten__fft_c2r, aten__fft_c2c r2c = forwards, could be one-sided c2r = backwards/inverse, never one-sided c2c could be either forwards/backwards, never one-sided Must respect normalization method provided - however, op.DFT calls "backwards" normalization, if 'inverse' is set to True, so need to account for normalization being done by op.DFT When running above functions across multiple axes, need to run FFT in reverse order through op.DFT one-by-one Currently have issues with: - c2r has extra parameter of last_dim_size, so must truncate/zero-pad to ensure last dimension size matches last_dim_size -- still debugging this part to avoid triggering errors #1271 --------- Co-authored-by: Justin Chu <[email protected]>
1 parent 4905bfd commit 397baa1

File tree

4 files changed

+158
-144
lines changed

4 files changed

+158
-144
lines changed

onnxscript/function_libs/torch_lib/ops/fft.py

Lines changed: 118 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -21,98 +21,33 @@
2121
from onnxscript.onnx_types import TensorType
2222

2323

24-
@torch_op(
25-
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
26-
private=True,
27-
complex=True,
28-
trace_only=True,
29-
)
3024
def _fftn_onnx_normalization(
31-
self,
32-
transformed: TFloat,
25+
self: TFloat,
3326
normalization: int,
34-
forward: bool,
35-
dims: Sequence[int],
36-
) -> TFloat:
37-
# Obtain the total_sample_count (n) for normalization
38-
self_shape = op.Shape(self)
39-
total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0)
40-
total_sample_count = op.CastLike(total_sample_count, transformed)
41-
42-
# Normalize the result
43-
# Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
44-
# Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
45-
if normalization == 1:
46-
# "forward" - normalize by 1/n
47-
if forward:
48-
result = op.Div(transformed, op.Sqrt(total_sample_count))
49-
else:
50-
result = op.Mul(transformed, op.Sqrt(total_sample_count))
51-
elif normalization == 2:
52-
# "ortho" - normalize by 1/sqrt(n)
53-
if forward:
54-
result = op.Div(transformed, total_sample_count)
55-
else:
56-
result = transformed
57-
else:
58-
# "backward" - no normalization
59-
if forward:
60-
result = transformed
61-
else:
62-
result = op.Mul(transformed, total_sample_count)
63-
64-
return result
65-
66-
67-
@torch_op(
68-
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
69-
trace_only=True,
70-
private=True,
71-
complex=True,
72-
)
73-
def _fftn_onnx(
74-
self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool
27+
signal_size: INT64,
28+
inverse: bool = False,
7529
) -> TFloat:
76-
"""Standard complex to complex or real to complex FFT (forward or backward).
77-
78-
This is a private shared function for implementing the various FFT functions.
79-
80-
Args:
81-
self: The input tensor.
82-
dims: The dimensions to apply FFT.
83-
normalization: The normalization mode.
84-
inverse: Whether to compute the inverse FFT.
85-
onesided: Whether to compute the one-sided FFT, which retains only the
86-
positive frequencies.
87-
88-
Returns:
89-
The transformed tensor.
90-
"""
91-
92-
# NOTE: trace_only because we need to process each dimension in a loop
93-
# NOTE: SymInt dim is not support because DFT-17 needs a static axis
94-
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
95-
96-
# The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
97-
# dimension at the beginning to represent the batch dimension.
98-
transformed = op.Unsqueeze(self, axes=[0])
99-
100-
# Add 1 to account for the batch dimension when counting axes from the left
101-
new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]
102-
103-
for dim in new_dims[:-1]:
104-
transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False)
105-
106-
# Torch computers one-sided FFT on the last dimension only.
107-
if onesided:
108-
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True)
30+
"""Normalize in forward or backward direction."""
31+
# Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131
32+
# Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19
33+
# Modes:
34+
# 0: no normalization (backward)
35+
# 1: "ortho" - divide by 1/sqrt(signal_size) (ortho)
36+
# 2: divide by signal_size (forward)
37+
signal_size = op.CastLike(signal_size, self)
38+
if not inverse:
39+
# Forward normalization
40+
if normalization == 1:
41+
self = op.Div(self, op.Sqrt(signal_size))
42+
elif normalization == 2:
43+
self = op.Div(self, signal_size)
10944
else:
110-
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False)
111-
112-
# Remove the batch dimension
113-
transformed = op.Squeeze(transformed, axes=[0])
114-
115-
return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims)
45+
# Backward normalization, accounting for op.DFT already dividing by signal_size
46+
if normalization == 0:
47+
self = op.Mul(self, signal_size)
48+
elif normalization == 1:
49+
self = op.Mul(self, op.Sqrt(signal_size))
50+
return self
11651

11752

11853
@torch_op("aten::_fft_c2c", trace_only=True, complex=True)
@@ -124,39 +59,87 @@ def aten__fft_c2c(
12459
Standard complex to complex FFT (forward or backward).
12560
"""
12661

127-
# NOTE: trace_only because we need to negate forward
128-
# NOTE: SymInt dim is not support because DFT-17 needs a static axis
129-
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
62+
# NOTE: SymInt dim is not supported because DFT-17 needs a static axis
13063

13164
# ONNX DFT input assumes the last dimension is the complex dimension.
132-
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
133-
dim = [d - 1 if d < 0 else d for d in dim]
134-
return _fftn_onnx(self, dim, normalization, inverse=not forward, onesided=False)
65+
66+
unsqueeze_first_dim = 0 in dim
67+
# 1. Add a new dimension for the end and batch dimension, if needed
68+
# 2. ONNX DFT input assumes the last dimension is the complex dimension.
69+
# If needed, add 1 to account for the batch dimension.
70+
71+
if unsqueeze_first_dim:
72+
transformed = op.Unsqueeze(self, axes=[0])
73+
dim = [d + 1 for d in dim]
74+
else:
75+
transformed = self
76+
77+
for dimension in reversed(dim):
78+
transformed = op.DFT(transformed, axis=dimension, inverse=not forward, onesided=False)
79+
transformed = _fftn_onnx_normalization(
80+
transformed,
81+
normalization,
82+
op.Shape(transformed, start=dimension, end=dimension + 1),
83+
not forward,
84+
)
85+
86+
if unsqueeze_first_dim:
87+
transformed = op.Squeeze(transformed, axes=[0])
88+
89+
return transformed
13590

13691

13792
@torch_op("aten::_fft_c2r", trace_only=True, complex=True)
13893
def aten__fft_c2r(
13994
self: TFloat,
14095
dim: Sequence[int],
14196
normalization: int,
142-
last_dim_size: INT64, # pylint: disable=unused-argument
97+
last_dim_size: INT64,
14398
) -> TFloat:
14499
"""_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
145100
146-
Complex to real inverse FFT.
101+
Complex to real inverse FFT. Assumes that input tensor is output of previous FFT operation.
147102
"""
148-
149-
# TODO(justinchuby): Figure out what last_dim_size does
150-
151-
self_rank = len(self.shape)
152-
# ONNX DFT input assumes the last dimension is the complex dimension.
153-
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
154-
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
155-
transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False)
156-
# Take only the real part
157-
real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1])
158-
159-
return op.Squeeze(real_part, axes=[-1])
103+
if len(dim) != 1:
104+
raise NotImplementedError("Only one dimension is supported for inverse FFT")
105+
106+
dimension = dim[0]
107+
unsqueeze_first_dim = dimension == 0
108+
# 1. Add a new dimension for batch dimension, if needed
109+
# 2. ONNX DFT input assumes the last dimension is the complex dimension.
110+
# If needed, add 1 to account for the batch dimension.
111+
112+
if unsqueeze_first_dim:
113+
transformed = op.Unsqueeze(self, axes=[0])
114+
dimension = 1
115+
else:
116+
transformed = self
117+
118+
# Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed
119+
# into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we
120+
# place no such restriction on the ONNX side.
121+
transformed = op.DFT(
122+
transformed,
123+
dft_length=last_dim_size,
124+
axis=dimension,
125+
inverse=True,
126+
onesided=False,
127+
)
128+
transformed = _fftn_onnx_normalization(
129+
transformed,
130+
normalization,
131+
op.Shape(transformed, start=dimension, end=dimension + 1),
132+
inverse=True,
133+
)
134+
135+
if unsqueeze_first_dim:
136+
transformed = op.Squeeze(transformed, axes=[0])
137+
138+
# Remove the imaginary part
139+
transformed = op.Slice(transformed, [0], [1], [-1])
140+
transformed = op.Squeeze(transformed, axes=[-1])
141+
142+
return transformed
160143

161144

162145
@torch_op("aten::_fft_r2c", trace_only=True)
@@ -168,17 +151,37 @@ def aten__fft_r2c(
168151
Real to complex forward FFT.
169152
"""
170153

171-
# Add a new dimension at the end
172-
signal = op.Unsqueeze(self, axes=[-1])
173154
# No need to fill the imaginary part because ONNX DFT accepts real inputs
174155
# https://onnx.ai/onnx/operators/onnx__DFT.html#inputs
175156

176-
self_rank = len(self.shape)
177-
# ONNX DFT input assumes the last dimension is the complex dimension.
178-
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
179-
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
157+
unsqueeze_first_dim = 0 in dim
158+
# 1. Add a new dimension for the end and batch dimension, if needed
159+
# 2. ONNX DFT input assumes the last dimension is the complex dimension.
160+
# If needed, add 1 to account for the batch dimension.
161+
162+
if unsqueeze_first_dim:
163+
transformed = op.Unsqueeze(self, axes=[0, -1])
164+
dim = [d + 1 for d in dim]
165+
else:
166+
transformed = op.Unsqueeze(self, axes=[-1])
167+
168+
for idx, dimension in enumerate(reversed(dim)):
169+
transformed = _fftn_onnx_normalization(
170+
transformed,
171+
normalization,
172+
op.Shape(transformed, start=dimension, end=dimension + 1),
173+
inverse=False,
174+
)
175+
if idx > 0:
176+
transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=False)
177+
else:
178+
# Torch computes one-sided FFT on the last dimension only.
179+
transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=onesided)
180+
181+
if unsqueeze_first_dim:
182+
transformed = op.Squeeze(transformed, axes=[0])
180183

181-
return _fftn_onnx(signal, dim, normalization, inverse=False, onesided=onesided)
184+
return transformed
182185

183186

184187
def aten_fft_fft(

onnxscript/ir/tensor_adapters_test.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,25 +55,25 @@ def test_numpy_returns_correct_dtype(self, dtype: torch.dtype, np_dtype):
5555

5656
@parameterized.parameterized.expand(
5757
[
58-
(torch.bfloat16),
59-
(torch.bool),
60-
(torch.complex128),
61-
(torch.complex64),
62-
(torch.float16),
63-
(torch.float32),
64-
(torch.float64),
65-
(torch.float8_e4m3fn),
66-
(torch.float8_e4m3fnuz),
67-
(torch.float8_e5m2),
68-
(torch.float8_e5m2fnuz),
69-
(torch.int16),
70-
(torch.int32),
71-
(torch.int64),
72-
(torch.int8),
73-
(torch.uint16),
74-
(torch.uint32),
75-
(torch.uint64),
76-
(torch.uint8),
58+
(torch.bfloat16,),
59+
(torch.bool,),
60+
(torch.complex128,),
61+
(torch.complex64,),
62+
(torch.float16,),
63+
(torch.float32,),
64+
(torch.float64,),
65+
(torch.float8_e4m3fn,),
66+
(torch.float8_e4m3fnuz,),
67+
(torch.float8_e5m2,),
68+
(torch.float8_e5m2fnuz,),
69+
(torch.int16,),
70+
(torch.int32,),
71+
(torch.int64,),
72+
(torch.int8,),
73+
(torch.uint16,),
74+
(torch.uint32,),
75+
(torch.uint64,),
76+
(torch.uint8,),
7777
],
7878
)
7979
def test_tobytes(self, dtype: torch.dtype):

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -684,24 +684,38 @@ def sample_inputs__fft_r2c(self, device, dtype, requires_grad=False, **_):
684684

685685
def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_):
686686
del self # Unused
687-
oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad)
688-
687+
real_dtype = torch.float
688+
if dtype == torch.complex128:
689+
real_dtype = torch.double
690+
oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, real_dtype, requires_grad)
691+
oned_tensor_result = oned_tensor()
692+
nd_tensor_result = nd_tensor()
693+
complex_oned_tensor = torch.ops.aten._fft_r2c.default( # pylint: disable=protected-access
694+
oned_tensor_result, [0], normalization=0, onesided=False
695+
)
696+
# for normalization in (0, 1, 2):
689697
for normalization in (0, 1, 2):
690698
# 1-D
691699
yield opinfo_core.SampleInput(
692-
oned_tensor(), dim=(0,), normalization=normalization, last_dim_size=12
700+
complex_oned_tensor,
701+
dim=(0,),
702+
normalization=normalization,
703+
last_dim_size=31,
693704
)
694705
# N-D
695706
for dim in [
696707
(0,),
697708
(1,),
698709
(2,),
699-
(1, 2),
700-
(0, 1),
701-
(0, 1, 2),
702710
]:
711+
complex_nd_tensor = torch.ops.aten._fft_r2c.default( # pylint: disable=protected-access
712+
nd_tensor_result, dim, normalization=0, onesided=False
713+
)
703714
yield opinfo_core.SampleInput(
704-
nd_tensor(), dim=dim, normalization=normalization, last_dim_size=6
715+
complex_nd_tensor,
716+
dim=dim,
717+
normalization=normalization,
718+
last_dim_size=complex_nd_tensor.shape[dim[-1]],
705719
)
706720

707721

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -452,9 +452,6 @@ def _where_input_wrangler(
452452
fft_ops.aten__fft_c2r,
453453
tolerance={torch.complex64: (3e-3, 1.8e-4)},
454454
complex=True,
455-
).xfail(
456-
dtypes=(torch.complex64,),
457-
reason="fixme: the result is wrong: https://github.com/microsoft/onnxscript/pull/926",
458455
),
459456
TorchLibOpInfo(
460457
"ops.aten._fft_r2c", # Custom from extra_opinfo

0 commit comments

Comments
 (0)