
Commit 6323b65

Merge branch 'main' into titaiwang/add_op_native_batch_norm

2 parents a5cad9f + 70843ef

File tree: 5 files changed, +236 -1 lines

onnxscript/function_libs/torch_lib/ops/fft.py

Lines changed: 156 additions & 0 deletions

@@ -13,9 +13,165 @@
 
 from typing import Optional, Sequence
 
+from onnxscript import INT64
+from onnxscript.function_libs.torch_lib.registration import torch_op
+from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
+from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 
 
+@torch_op(
+    ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
+    private=True,
+    complex=True,
+)
+def _fftn_onnx_normalization(
+    self,
+    transformed: TFloat,
+    normalization: int,
+    forward: bool,
+    dims: Sequence[int],
+) -> TFloat:
+    # Obtain the total_sample_count (n) for normalization
+    self_shape = op.Shape(self)
+    total_sample_count = op.ReduceProd(self_shape[dims], keepdims=0)
+    total_sample_count = op.CastLike(total_sample_count, transformed)
+
+    # Normalize the result.
+    # The normalization integer follows ATen's fft_norm_mode: 0 = no scaling,
+    # 1 = divide by sqrt(n), 2 = divide by n. ONNX DFT already scales the
+    # inverse transform by 1/n, so the inverse branches compensate for that.
+    # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
+    # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
+    if normalization == 1:
+        # Divide by sqrt(n) ("ortho")
+        if forward:
+            result = op.Div(transformed, op.Sqrt(total_sample_count))
+        else:
+            result = op.Mul(transformed, op.Sqrt(total_sample_count))
+    elif normalization == 2:
+        # Divide by n
+        if forward:
+            result = op.Div(transformed, total_sample_count)
+        else:
+            result = transformed
+    else:
+        # No normalization (mode 0)
+        if forward:
+            result = transformed
+        else:
+            result = op.Mul(transformed, total_sample_count)
+
+    return result
+
+
+@torch_op(
+    ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
+    trace_only=True,
+    private=True,
+    complex=True,
+)
+def _fftn_onnx(
+    self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool
+) -> TFloat:
+    """Standard complex to complex or real to complex FFT (forward or backward).
+
+    This is a private shared function for implementing the various FFT functions.
+
+    Args:
+        self: The input tensor.
+        dims: The dimensions to apply FFT.
+        normalization: The normalization mode.
+        inverse: Whether to compute the inverse FFT.
+        onesided: Whether to compute the one-sided FFT, which retains only the
+            positive frequencies.
+
+    Returns:
+        The transformed tensor.
+    """
+
+    # NOTE: trace_only because we need to process each dimension in a loop
+    # NOTE: SymInt dim is not supported because DFT-17 needs a static axis
+    # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
+
+    # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
+    # dimension at the beginning to represent the batch dimension.
+    transformed = op.Unsqueeze(self, axes=[0])
+
+    for dim_ in dims:
+        if dim_ >= 0:
+            # Add 1 to account for the batch dimension when counting axes from the left
+            dim_ = dim_ + 1
+        transformed = op.DFT(transformed, axis=dim_, inverse=inverse, onesided=onesided)
+    # Remove the batch dimension
+    transformed = op.Squeeze(transformed, axes=[0])
+
+    return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims)
+
+
+@torch_op("aten::_fft_c2c", trace_only=True, complex=True)
+def aten__fft_c2c(
+    self: TFloat, dim: Sequence[int], normalization: int, forward: bool
+) -> TFloat:
+    """_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
+
+    Standard complex to complex FFT (forward or backward).
+    """
+
+    # NOTE: trace_only because we need to negate forward
+    # NOTE: SymInt dim is not supported because DFT-17 needs a static axis
+    # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
+
+    # ONNX DFT input assumes the last dimension is the complex dimension.
+    # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
+    dim = [d - 1 if d < 0 else d for d in dim]
+    return _fftn_onnx(self, dim, normalization, inverse=not forward, onesided=False)
+
+
+@torch_op("aten::_fft_c2r", trace_only=True, complex=True)
+def aten__fft_c2r(
+    self: TFloat,
+    dim: Sequence[int],
+    normalization: int,
+    last_dim_size: INT64,  # pylint: disable=unused-argument
+) -> TFloat:
+    """_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
+
+    Complex to real inverse FFT.
+    """
+
+    # TODO(justinchuby): Figure out what last_dim_size does
+
+    self_rank = len(self.shape)
+    # ONNX DFT input assumes the last dimension is the complex dimension.
+    # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
+    dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
+    transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False)
+    # Take only the real part
+    real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1])
+
+    return op.Squeeze(real_part, axes=[-1])
+
+
+@torch_op("aten::_fft_r2c", trace_only=True)
+def aten__fft_r2c(
+    self: TFloat, dim: Sequence[int], normalization: int, onesided: bool
+) -> TFloat:
+    """_fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
+
+    Real to complex forward FFT.
+    """
+
+    # Add a new dimension at the end
+    signal = op.Unsqueeze(self, axes=[-1])
+    # No need to fill the imaginary part because ONNX DFT accepts real inputs
+    # https://onnx.ai/onnx/operators/onnx__DFT.html#inputs
+
+    self_rank = len(self.shape)
+    # ONNX DFT input assumes the last dimension is the complex dimension.
+    # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
+    dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
+
+    return _fftn_onnx(signal, dim, normalization, inverse=False, onesided=onesided)
+
+
 def aten_fft_fft(
     self: TensorType, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None
 ) -> TensorType:
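The three normalization branches above are easy to misread, so here is a minimal
cross-check (not part of the commit) of the scaling they implement, assuming the
integer follows ATen's fft_norm_mode (0: no scaling, 1: divide by sqrt(n),
2: divide by n). numpy.fft exposes the same three scalings under its norm names:

import numpy as np

x = np.random.default_rng(0).standard_normal(8).astype(np.complex64)
n = x.size
unscaled = np.fft.fft(x, norm="backward")  # no scaling on the forward transform
# Mode 1: divide by sqrt(n) -- what norm="ortho" resolves to
assert np.allclose(np.fft.fft(x, norm="ortho"), unscaled / np.sqrt(n))
# Mode 2: divide by n -- what norm="forward" resolves to on a forward transform
assert np.allclose(np.fft.fft(x, norm="forward"), unscaled / n)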

onnxscript/function_libs/torch_lib/registration.py

Lines changed: 1 addition & 1 deletion

@@ -111,7 +111,7 @@ def torch_op(
         trace_only: Whether the function should only be traced and not compiled.
         private: Whether the function is private (not directly exposed). It should
             be true for all functions with names starting with "_".
-        complex: Whether the function supports complex.
+        complex: Whether the function expects complex-valued inputs.
     """
     if registry is None:
         registry = default_registry
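To make the complex flag concrete, here is a hypothetical registration sketch
(the op name aten::_example_complex_identity is made up; the decorator and
module paths are the ones used elsewhere in this commit):

from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.onnx_opset import opset18 as op

# complex=True marks this function as the variant to dispatch to when the ATen
# op receives complex inputs, which torchlib packs into real tensors with a
# trailing size-2 [real, imaginary] axis.
@torch_op("aten::_example_complex_identity", private=True, complex=True)
def _example_complex_identity(self: TFloat) -> TFloat:
    # Identity over the packed real/imaginary representation.
    return op.Identity(self)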

onnxscript/tests/function_libs/torch_lib/error_reproduction.py

Lines changed: 2 additions & 0 deletions

@@ -212,6 +212,8 @@ def create_mismatch_report(
     expected,
     error: Exception,
 ) -> None:
+    torch.set_printoptions(threshold=sys.maxsize)
+
     error_text = str(error)
     error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__))
     short_test_name = test_name.split(".")[-1]
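For context, torch summarizes large tensors with an ellipsis when printing;
raising the threshold makes repr() emit every element, so the mismatch report
captures the tensors in full. An illustration (not from the commit):

import sys
import torch

t = torch.arange(2000)
assert "..." in str(t)  # summarized under the default threshold of 1000 elements
torch.set_printoptions(threshold=sys.maxsize)
assert "..." not in str(t)  # now printed in full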

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 69 additions & 0 deletions

@@ -190,6 +190,68 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
     )
 
 
+def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_):
+    del self  # Unused
+    # Adapted from https://github.com/pytorch/pytorch/blob/01069ad4be449f376cf88a56d842b8eb50f6e9b6/torch/testing/_internal/opinfo/core.py#L2448C1-L2541C79
+    is_fp16_or_chalf = dtype in (torch.complex32, torch.half)
+    if not is_fp16_or_chalf:
+        nd_tensor = functools.partial(
+            opinfo_core.make_tensor,
+            (S, S + 1, S + 2),
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+        oned_tensor = functools.partial(
+            opinfo_core.make_tensor,
+            (31,),
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+    else:
+        low = None
+        high = None
+        shapes = ((2, 8, 9), (33,))
+
+        nd_tensor = functools.partial(
+            opinfo_core.make_tensor,
+            shapes[0],
+            device=device,
+            low=low,
+            high=high,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+        oned_tensor = functools.partial(
+            opinfo_core.make_tensor,
+            shapes[1],
+            device=device,
+            low=low,
+            high=high,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+
+    for normalization, forward in itertools.product((0, 1, 2), (True, False)):
+        # 1-D
+        yield opinfo_core.SampleInput(
+            oned_tensor(), dim=(0,), normalization=normalization, forward=forward
+        )
+        # N-D
+        for dim in [
+            (0,),
+            (1,),
+            (2,),
+            (1, 2),
+            (0, 1),
+            (0, 1, 2),
+        ]:
+            yield opinfo_core.SampleInput(
+                nd_tensor(), dim=dim, normalization=normalization, forward=forward
+            )
+
+
 def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs):
     del op_info  # unused
     del kwargs
@@ -1242,6 +1304,13 @@ def sample_inputs_scaled_dot_product_flash_attention(
 # To avoid name duplication, it is possible to rename the OpInfo and specify
 # the `op` field explicitly.
 OP_DB: List[opinfo_core.OpInfo] = [
+    opinfo_core.OpInfo(
+        "ops.aten._fft_c2c",
+        aten_name="_fft_c2c",
+        dtypes=common_dtype.complex_types(),
+        sample_inputs_func=sample_inputs__fft_c2c,
+        supports_out=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten._local_scalar_dense",
         aten_name="_local_scalar_dense",
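A sketch of how these samples are consumed (assuming onnxscript's test modules
are importable and a torch version whose SampleInput accepts op kwargs directly,
which the generator above relies on):

import torch

from onnxscript.tests.function_libs.torch_lib import extra_opinfo

fft_info = next(op for op in extra_opinfo.OP_DB if op.name == "ops.aten._fft_c2c")
samples = list(
    fft_info.sample_inputs_func(fft_info, "cpu", torch.complex64, requires_grad=False)
)
# One 1-D sample plus six N-D dim choices, for each (normalization, forward) pair.
assert len(samples) == (1 + 6) * 3 * 2
first = samples[0]
result = torch.ops.aten._fft_c2c(
    first.input,
    first.kwargs["dim"],
    first.kwargs["normalization"],
    first.kwargs["forward"],
)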

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 8 additions & 0 deletions

@@ -47,6 +47,7 @@
 
 from onnxscript._internal import version_utils
 from onnxscript.function_libs.torch_lib.ops import core as core_ops
+from onnxscript.function_libs.torch_lib.ops import fft as fft_ops
 from onnxscript.function_libs.torch_lib.ops import linalg as linalg_ops
 from onnxscript.function_libs.torch_lib.ops import nn as nn_ops
 from onnxscript.function_libs.torch_lib.ops import special as special_ops
@@ -450,6 +451,13 @@ def _where_input_wrangler(
 # Ops to be tested for numerical consistency between onnx and pytorch
 # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
 TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
+    TorchLibOpInfo(
+        "ops.aten._fft_c2c",  # Custom from extra_opinfo
+        fft_ops.aten__fft_c2c,
+        tolerance={torch.complex64: (3e-3, 1.8e-4)},
+        trace_only=True,
+        complex=True,
+    ),
     TorchLibOpInfo(
         "ops.aten._local_scalar_dense",
         core_ops.aten__local_scalar_dense,
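The tolerance entry is keyed by dtype. A small sketch of what it means (an
assumption about the harness, not code from this commit), reading the tuple as
(rtol, atol) in the order torch.testing.assert_close takes them:

import torch

rtol, atol = 3e-3, 1.8e-4  # the complex64 tolerance registered above
expected = torch.tensor([1.0 + 1.0j], dtype=torch.complex64)
actual = expected * (1 + 1e-3)  # 0.1% relative error, inside rtol
torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)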
