21
21
from onnxscript .onnx_types import TensorType
22
22
23
23
24
- @torch_op (
25
- ("aten::_fft_c2c" , "aten::_fft_c2r" , "aten::_fft_r2c" ),
26
- private = True ,
27
- complex = True ,
28
- trace_only = True ,
29
- )
30
24
def _fftn_onnx_normalization (
31
- self ,
32
- transformed : TFloat ,
25
+ self : TFloat ,
33
26
normalization : int ,
34
- forward : bool ,
35
- dims : Sequence [int ],
36
- ) -> TFloat :
37
- # Obtain the total_sample_count (n) for normalization
38
- self_shape = op .Shape (self )
39
- total_sample_count = op .ReduceProd (op .Gather (self_shape , dims ), keepdims = 0 )
40
- total_sample_count = op .CastLike (total_sample_count , transformed )
41
-
42
- # Normalize the result
43
- # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
44
- # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
45
- if normalization == 1 :
46
- # "forward" - normalize by 1/n
47
- if forward :
48
- result = op .Div (transformed , op .Sqrt (total_sample_count ))
49
- else :
50
- result = op .Mul (transformed , op .Sqrt (total_sample_count ))
51
- elif normalization == 2 :
52
- # "ortho" - normalize by 1/sqrt(n)
53
- if forward :
54
- result = op .Div (transformed , total_sample_count )
55
- else :
56
- result = transformed
57
- else :
58
- # "backward" - no normalization
59
- if forward :
60
- result = transformed
61
- else :
62
- result = op .Mul (transformed , total_sample_count )
63
-
64
- return result
65
-
66
-
67
- @torch_op (
68
- ("aten::_fft_c2c" , "aten::_fft_c2r" , "aten::_fft_r2c" ),
69
- trace_only = True ,
70
- private = True ,
71
- complex = True ,
72
- )
73
- def _fftn_onnx (
74
- self : TFloat , dims : Sequence [int ], normalization : int , inverse : bool , onesided : bool
27
+ signal_size : INT64 ,
28
+ inverse : bool = False ,
75
29
) -> TFloat :
76
- """Standard complex to complex or real to complex FFT (forward or backward).
77
-
78
- This is a private shared function for implementing the various FFT functions.
79
-
80
- Args:
81
- self: The input tensor.
82
- dims: The dimensions to apply FFT.
83
- normalization: The normalization mode.
84
- inverse: Whether to compute the inverse FFT.
85
- onesided: Whether to compute the one-sided FFT, which retains only the
86
- positive frequencies.
87
-
88
- Returns:
89
- The transformed tensor.
90
- """
91
-
92
- # NOTE: trace_only because we need to process each dimension in a loop
93
- # NOTE: SymInt dim is not support because DFT-17 needs a static axis
94
- # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
95
-
96
- # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
97
- # dimension at the beginning to represent the batch dimension.
98
- transformed = op .Unsqueeze (self , axes = [0 ])
99
-
100
- # Add 1 to account for the batch dimension when counting axes from the left
101
- new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims ]
102
-
103
- for dim in new_dims [:- 1 ]:
104
- transformed = op .DFT (transformed , axis = dim , inverse = inverse , onesided = False )
105
-
106
- # Torch computers one-sided FFT on the last dimension only.
107
- if onesided :
108
- transformed = op .DFT (transformed , axis = new_dims [- 1 ], inverse = inverse , onesided = True )
30
+ """Normalize in forward or backward direction."""
31
+ # Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131
32
+ # Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19
33
+ # Modes:
34
+ # 0: no normalization (backward)
35
+ # 1: "ortho" - divide by 1/sqrt(signal_size) (ortho)
36
+ # 2: divide by signal_size (forward)
37
+ signal_size = op .CastLike (signal_size , self )
38
+ if not inverse :
39
+ # Forward normalization
40
+ if normalization == 1 :
41
+ self = op .Div (self , op .Sqrt (signal_size ))
42
+ elif normalization == 2 :
43
+ self = op .Div (self , signal_size )
109
44
else :
110
- transformed = op .DFT ( transformed , axis = new_dims [ - 1 ], inverse = inverse , onesided = False )
111
-
112
- # Remove the batch dimension
113
- transformed = op . Squeeze ( transformed , axes = [ 0 ])
114
-
115
- return _fftn_onnx_normalization ( self , transformed , normalization , not inverse , dims )
45
+ # Backward normalization, accounting for op.DFT already dividing by signal_size
46
+ if normalization == 0 :
47
+ self = op . Mul ( self , signal_size )
48
+ elif normalization == 1 :
49
+ self = op . Mul ( self , op . Sqrt ( signal_size ))
50
+ return self
116
51
117
52
118
53
@torch_op ("aten::_fft_c2c" , trace_only = True , complex = True )
@@ -124,39 +59,87 @@ def aten__fft_c2c(
124
59
Standard complex to complex FFT (forward or backward).
125
60
"""
126
61
127
- # NOTE: trace_only because we need to negate forward
128
- # NOTE: SymInt dim is not support because DFT-17 needs a static axis
129
- # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
62
+ # NOTE: SymInt dim is not supported because DFT-17 needs a static axis
130
63
131
64
# ONNX DFT input assumes the last dimension is the complex dimension.
132
- # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
133
- dim = [d - 1 if d < 0 else d for d in dim ]
134
- return _fftn_onnx (self , dim , normalization , inverse = not forward , onesided = False )
65
+
66
+ unsqueeze_first_dim = 0 in dim
67
+ # 1. Add a new dimension for the end and batch dimension, if needed
68
+ # 2. ONNX DFT input assumes the last dimension is the complex dimension.
69
+ # If needed, add 1 to account for the batch dimension.
70
+
71
+ if unsqueeze_first_dim :
72
+ transformed = op .Unsqueeze (self , axes = [0 ])
73
+ dim = [d + 1 for d in dim ]
74
+ else :
75
+ transformed = self
76
+
77
+ for dimension in reversed (dim ):
78
+ transformed = op .DFT (transformed , axis = dimension , inverse = not forward , onesided = False )
79
+ transformed = _fftn_onnx_normalization (
80
+ transformed ,
81
+ normalization ,
82
+ op .Shape (transformed , start = dimension , end = dimension + 1 ),
83
+ not forward ,
84
+ )
85
+
86
+ if unsqueeze_first_dim :
87
+ transformed = op .Squeeze (transformed , axes = [0 ])
88
+
89
+ return transformed
135
90
136
91
137
92
@torch_op ("aten::_fft_c2r" , trace_only = True , complex = True )
138
93
def aten__fft_c2r (
139
94
self : TFloat ,
140
95
dim : Sequence [int ],
141
96
normalization : int ,
142
- last_dim_size : INT64 , # pylint: disable=unused-argument
97
+ last_dim_size : INT64 ,
143
98
) -> TFloat :
144
99
"""_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
145
100
146
- Complex to real inverse FFT.
101
+ Complex to real inverse FFT. Assumes that input tensor is output of previous FFT operation.
147
102
"""
148
-
149
- # TODO(justinchuby): Figure out what last_dim_size does
150
-
151
- self_rank = len (self .shape )
152
- # ONNX DFT input assumes the last dimension is the complex dimension.
153
- # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
154
- dim = [(d - 1 ) + self_rank if d < 0 else d for d in dim ]
155
- transformed = _fftn_onnx (self , dim , normalization , inverse = True , onesided = False )
156
- # Take only the real part
157
- real_part = op .Slice (transformed , axes = [- 1 ], starts = [0 ], ends = [1 ])
158
-
159
- return op .Squeeze (real_part , axes = [- 1 ])
103
+ if len (dim ) != 1 :
104
+ raise NotImplementedError ("Only one dimension is supported for inverse FFT" )
105
+
106
+ dimension = dim [0 ]
107
+ unsqueeze_first_dim = dimension == 0
108
+ # 1. Add a new dimension for batch dimension, if needed
109
+ # 2. ONNX DFT input assumes the last dimension is the complex dimension.
110
+ # If needed, add 1 to account for the batch dimension.
111
+
112
+ if unsqueeze_first_dim :
113
+ transformed = op .Unsqueeze (self , axes = [0 ])
114
+ dimension = 1
115
+ else :
116
+ transformed = self
117
+
118
+ # Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed
119
+ # into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we
120
+ # place no such restriction on the ONNX side.
121
+ transformed = op .DFT (
122
+ transformed ,
123
+ dft_length = last_dim_size ,
124
+ axis = dimension ,
125
+ inverse = True ,
126
+ onesided = False ,
127
+ )
128
+ transformed = _fftn_onnx_normalization (
129
+ transformed ,
130
+ normalization ,
131
+ op .Shape (transformed , start = dimension , end = dimension + 1 ),
132
+ inverse = True ,
133
+ )
134
+
135
+ if unsqueeze_first_dim :
136
+ transformed = op .Squeeze (transformed , axes = [0 ])
137
+
138
+ # Remove the imaginary part
139
+ transformed = op .Slice (transformed , [0 ], [1 ], [- 1 ])
140
+ transformed = op .Squeeze (transformed , axes = [- 1 ])
141
+
142
+ return transformed
160
143
161
144
162
145
@torch_op ("aten::_fft_r2c" , trace_only = True )
@@ -168,17 +151,37 @@ def aten__fft_r2c(
168
151
Real to complex forward FFT.
169
152
"""
170
153
171
- # Add a new dimension at the end
172
- signal = op .Unsqueeze (self , axes = [- 1 ])
173
154
# No need to fill the imaginary part because ONNX DFT accepts real inputs
174
155
# https://onnx.ai/onnx/operators/onnx__DFT.html#inputs
175
156
176
- self_rank = len (self .shape )
177
- # ONNX DFT input assumes the last dimension is the complex dimension.
178
- # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
179
- dim = [(d - 1 ) + self_rank if d < 0 else d for d in dim ]
157
+ unsqueeze_first_dim = 0 in dim
158
+ # 1. Add a new dimension for the end and batch dimension, if needed
159
+ # 2. ONNX DFT input assumes the last dimension is the complex dimension.
160
+ # If needed, add 1 to account for the batch dimension.
161
+
162
+ if unsqueeze_first_dim :
163
+ transformed = op .Unsqueeze (self , axes = [0 , - 1 ])
164
+ dim = [d + 1 for d in dim ]
165
+ else :
166
+ transformed = op .Unsqueeze (self , axes = [- 1 ])
167
+
168
+ for idx , dimension in enumerate (reversed (dim )):
169
+ transformed = _fftn_onnx_normalization (
170
+ transformed ,
171
+ normalization ,
172
+ op .Shape (transformed , start = dimension , end = dimension + 1 ),
173
+ inverse = False ,
174
+ )
175
+ if idx > 0 :
176
+ transformed = op .DFT (transformed , axis = dimension , inverse = False , onesided = False )
177
+ else :
178
+ # Torch computes one-sided FFT on the last dimension only.
179
+ transformed = op .DFT (transformed , axis = dimension , inverse = False , onesided = onesided )
180
+
181
+ if unsqueeze_first_dim :
182
+ transformed = op .Squeeze (transformed , axes = [0 ])
180
183
181
- return _fftn_onnx ( signal , dim , normalization , inverse = False , onesided = onesided )
184
+ return transformed
182
185
183
186
184
187
def aten_fft_fft (
0 commit comments