Skip to content

Implement _fft_* ops | feat(torchlib) #926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 33 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
1d8c640
c2c
justinchuby Jul 27, 2023
4403ab8
fft
justinchuby Jul 27, 2023
362c637
_fft_c2r
justinchuby Jul 27, 2023
43c963d
complex
justinchuby Jul 27, 2023
e05b56c
FFT
justinchuby Jul 27, 2023
253f511
docs
justinchuby Jul 27, 2023
ff87572
update
justinchuby Jul 27, 2023
8999327
link
justinchuby Jul 27, 2023
3b2a9dd
"aten::_fft_c2r"
justinchuby Jul 27, 2023
8dc00a4
Update onnxscript/function_libs/torch_lib/ops/fft.py
justinchuby Jul 27, 2023
a86acfd
Update fft.py
justinchuby Jul 27, 2023
ba62df4
Format
justinchuby Jul 27, 2023
acfc859
Test
justinchuby Jul 28, 2023
936ced4
todo
justinchuby Jul 28, 2023
dfc775c
syntax
justinchuby Jul 28, 2023
a458ed6
test
justinchuby Jul 28, 2023
7403c7e
Merge branch 'main' into justinchu/fft
justinchuby Jul 28, 2023
84b5c79
Merge branch 'main' into justinchu/fft
justinchuby Aug 1, 2023
a72b212
Merge branch 'main' into justinchu/fft
justinchuby Aug 10, 2023
8d394f0
show full torch output
justinchuby Aug 10, 2023
9e4224a
print all tensor
justinchuby Aug 10, 2023
68f8871
revert
justinchuby Aug 10, 2023
0f0fe17
lint
justinchuby Aug 10, 2023
1adb4fa
fix
justinchuby Aug 10, 2023
681a693
batch
justinchuby Aug 10, 2023
946e37d
Merge branch 'main' into justinchu/fft
justinchuby Sep 21, 2023
35797be
Some fixes on axis
justinchuby Sep 21, 2023
2c88d06
Correct the tests
fatcat-z Oct 26, 2023
0844b90
Merge branch 'main' into justinchu/fft
fatcat-z Oct 26, 2023
45522c0
Remove unnecessary code.
fatcat-z Oct 26, 2023
bd6f792
Merge branch 'justinchu/fft' of https://github.com/microsoft/onnxscri…
fatcat-z Oct 26, 2023
de1a8f0
Update the tolerance for fft test cases.
fatcat-z Oct 26, 2023
0bd5688
Add comments for normalization type.
fatcat-z Oct 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,118 @@

from typing import Optional, Sequence

from onnxscript import INT64
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType


@torch_op(
    # FIX: the third name was registered as "_fft_r2c" without the "aten::"
    # prefix, inconsistent with the other two overloads in this tuple.
    ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
    trace_only=True,
    private=True,
    complex=True,
)
def _fftn_onnx(
    self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool
) -> TFloat:
    """Standard complex to complex or real to complex FFT (forward or backward).

    This is a private shared function for implementing the various FFT functions.

    Args:
        self: The input tensor.
        dims: The dimensions to apply FFT.
        normalization: The normalization mode.
        inverse: Whether to compute the inverse FFT.
        onesided: Whether to compute the one-sided FFT, which retains only the
            positive frequencies.

    Returns:
        The transformed tensor.
    """

    # NOTE: trace_only because we need to process each dimension in a loop
    # NOTE: SymInt dim is not supported because DFT-17 needs a static axis
    # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support

    # Apply a single-axis DFT once per requested dimension.
    transformed = self
    for dim_ in dims:
        transformed = op.DFT(transformed, axis=dim_, inverse=inverse, onesided=onesided)

    # Obtain the total_sample_count (n) for normalization: the product of the
    # sizes of all transformed dimensions.
    total_sample_count = op.Constant(value_int=1)
    self_shape = op.Shape(self)
    for dim_ in dims:
        total_sample_count = op.Mul(total_sample_count, self_shape[dim_])

    total_sample_count = op.CastLike(total_sample_count, transformed)
    # Normalize the result
    # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
    # NOTE(review): confirm this mapping of the integer `normalization` codes to
    # the forward/backward/ortho modes against PyTorch's fft_norm_mode enum.
    if normalization == 1:
        # forward - normalize by 1/n
        # (The original comment said "1/2", but the code divides by the total
        # sample count n.)
        result = op.Div(transformed, total_sample_count)
    elif normalization == 2:
        # backward - no normalization
        result = transformed
    else:  # normalization == 3:
        # ortho - normalize by 1/sqrt(n)
        result = op.Div(transformed, op.Sqrt(total_sample_count))

    return result


@torch_op("aten::_fft_c2c", trace_only=True, complex=True)
def aten__fft_c2c(
    self: TFloat, dim: Sequence[int], normalization: int, forward: bool
) -> TFloat:
    """_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor

    Standard complex to complex FFT (forward or backward).
    """

    # NOTE: trace_only because the `forward` flag must be negated at trace time,
    # and DFT-17 needs a static axis so a SymInt dim cannot be supported.
    # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
    is_inverse = not forward
    return _fftn_onnx(self, dim, normalization, inverse=is_inverse, onesided=False)


@torch_op("aten::_fft_c2r", trace_only=True, complex=True)
def aten__fft_c2r(
    self: TFloat, dim: Sequence[int], normalization: int, last_dim_size: INT64
) -> TFloat:
    """_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor

    Complex to real inverse FFT.
    """

    # TODO(justinchuby): Figure out what last_dim_size does

    complex_result = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False)
    # Keep only the real component, which lives at index 0 of the trailing
    # axis, then drop that axis so the output is a real-valued tensor.
    real_component = op.Slice(complex_result, axes=[-1], starts=[0], ends=[1])
    return op.Squeeze(real_component, axes=[-1])


@torch_op("aten::_fft_r2c", trace_only=True)
def aten__fft_r2c(
    self: TFloat, dim: Sequence[int], normalization: int, onesided: bool
) -> TFloat:
    """_fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor

    Real to complex forward FFT.
    """

    # Add a new dimension at the end to act as the complex component axis.
    self = op.Unsqueeze(self, axes=[-1])
    # No need to fill the imaginary part because ONNX DFT accepts real inputs
    # https://onnx.ai/onnx/operators/onnx__DFT.html#inputs

    # BUG FIX: the original returned `_fftn_onnx(signal, ...)`, but `signal` is
    # undefined in this scope (NameError at trace time). The unsqueezed input
    # `self` is the tensor that must be transformed.
    return _fftn_onnx(self, dim, normalization, inverse=False, onesided=onesided)


def aten_fft_fft(
self: TensorType, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None
) -> TensorType:
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def torch_op(
trace_only: Whether the function should only be traced and not compiled.
private: Whether the function is private (not directly exposed). It should
be true for all functions with names starting with "_".
complex: Whether the function supports complex.
complex: Whether the function expects complex-valued inputs.
"""
if registry is None:
registry = default_registry
Expand Down