|
13 | 13 |
|
14 | 14 | from typing import Optional, Sequence
|
15 | 15 |
|
| 16 | +from onnxscript import INT64 |
| 17 | +from onnxscript.function_libs.torch_lib.registration import torch_op |
| 18 | +from onnxscript.function_libs.torch_lib.tensor_typing import TFloat |
| 19 | +from onnxscript.onnx_opset import opset18 as op |
16 | 20 | from onnxscript.onnx_types import TensorType
|
17 | 21 |
|
18 | 22 |
|
| 23 | +@torch_op( |
| 24 | + ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), |
| 25 | + private=True, |
| 26 | + complex=True, |
| 27 | +) |
| 28 | +def _fftn_onnx_normalization( |
| 29 | + self, |
| 30 | + transformed: TFloat, |
| 31 | + normalization: int, |
| 32 | + forward: bool, |
| 33 | + dims: Sequence[int], |
| 34 | +) -> TFloat: |
| 35 | + # Obtain the total_sample_count (n) for normalization |
| 36 | + self_shape = op.Shape(self) |
| 37 | + total_sample_count = op.ReduceProd(self_shape[dims], keepdims=0) |
| 38 | + total_sample_count = op.CastLike(total_sample_count, transformed) |
| 39 | + |
| 40 | + # Normalize the result |
| 41 | + # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn |
| 42 | + # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42 |
| 43 | + if normalization == 1: |
| 44 | + # "forward" - normalize by 1/n |
| 45 | + if forward: |
| 46 | + result = op.Div(transformed, op.Sqrt(total_sample_count)) |
| 47 | + else: |
| 48 | + result = op.Mul(transformed, op.Sqrt(total_sample_count)) |
| 49 | + elif normalization == 2: |
| 50 | + # "ortho" - normalize by 1/sqrt(n) |
| 51 | + if forward: |
| 52 | + result = op.Div(transformed, total_sample_count) |
| 53 | + else: |
| 54 | + result = transformed |
| 55 | + else: |
| 56 | + # "backward" - no normalization |
| 57 | + if forward: |
| 58 | + result = transformed |
| 59 | + else: |
| 60 | + result = op.Mul(transformed, total_sample_count) |
| 61 | + |
| 62 | + return result |
| 63 | + |
| 64 | + |
| 65 | +@torch_op( |
| 66 | + ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), |
| 67 | + trace_only=True, |
| 68 | + private=True, |
| 69 | + complex=True, |
| 70 | +) |
| 71 | +def _fftn_onnx( |
| 72 | + self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool |
| 73 | +) -> TFloat: |
| 74 | + """Standard complex to complex or real to complex FFT (forward or backward). |
| 75 | +
|
| 76 | + This is a private shared function for implementing the various FFT functions. |
| 77 | +
|
| 78 | + Args: |
| 79 | + self: The input tensor. |
| 80 | + dims: The dimensions to apply FFT. |
| 81 | + normalization: The normalization mode. |
| 82 | + inverse: Whether to compute the inverse FFT. |
| 83 | + onesided: Whether to compute the one-sided FFT, which retains only the |
| 84 | + positive frequencies. |
| 85 | +
|
| 86 | + Returns: |
| 87 | + The transformed tensor. |
| 88 | + """ |
| 89 | + |
| 90 | + # NOTE: trace_only because we need to process each dimension in a loop |
| 91 | + # NOTE: SymInt dim is not support because DFT-17 needs a static axis |
| 92 | + # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support |
| 93 | + |
| 94 | + # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new |
| 95 | + # dimension at the beginning to represent the batch dimension. |
| 96 | + transformed = op.Unsqueeze(self, axes=[0]) |
| 97 | + |
| 98 | + for dim_ in dims: |
| 99 | + if dim_ >= 0: |
| 100 | + # Add 1 to account for the batch dimension when counting axes from the left |
| 101 | + dim_ = dim_ + 1 |
| 102 | + transformed = op.DFT(transformed, axis=dim_, inverse=inverse, onesided=onesided) |
| 103 | + # Remove the batch dimension |
| 104 | + transformed = op.Squeeze(transformed, axes=[0]) |
| 105 | + |
| 106 | + return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims) |
| 107 | + |
| 108 | + |
| 109 | +@torch_op("aten::_fft_c2c", trace_only=True, complex=True) |
| 110 | +def aten__fft_c2c( |
| 111 | + self: TFloat, dim: Sequence[int], normalization: int, forward: bool |
| 112 | +) -> TFloat: |
| 113 | + """_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor |
| 114 | +
|
| 115 | + Standard complex to complex FFT (forward or backward). |
| 116 | + """ |
| 117 | + |
| 118 | + # NOTE: trace_only because we need to negate forward |
| 119 | + # NOTE: SymInt dim is not support because DFT-17 needs a static axis |
| 120 | + # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support |
| 121 | + |
| 122 | + # ONNX DFT input assumes the last dimension is the complex dimension. |
| 123 | + # Thus dim=-1 in PyTorch is dim=-2 in ONNX. |
| 124 | + dim = [d - 1 if d < 0 else d for d in dim] |
| 125 | + return _fftn_onnx(self, dim, normalization, inverse=not forward, onesided=False) |
| 126 | + |
| 127 | + |
| 128 | +@torch_op("aten::_fft_c2r", trace_only=True, complex=True) |
| 129 | +def aten__fft_c2r( |
| 130 | + self: TFloat, |
| 131 | + dim: Sequence[int], |
| 132 | + normalization: int, |
| 133 | + last_dim_size: INT64, # pylint: disable=unused-argument |
| 134 | +) -> TFloat: |
| 135 | + """_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor |
| 136 | +
|
| 137 | + Complex to real inverse FFT. |
| 138 | + """ |
| 139 | + |
| 140 | + # TODO(justinchuby): Figure out what last_dim_size does |
| 141 | + |
| 142 | + self_rank = len(self.shape) |
| 143 | + # ONNX DFT input assumes the last dimension is the complex dimension. |
| 144 | + # Thus dim=-1 in PyTorch is dim=-2 in ONNX. |
| 145 | + dim = [(d - 1) + self_rank if d < 0 else d for d in dim] |
| 146 | + transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False) |
| 147 | + # Take only the real part |
| 148 | + real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1]) |
| 149 | + |
| 150 | + return op.Squeeze(real_part, axes=[-1]) |
| 151 | + |
| 152 | + |
| 153 | +@torch_op("aten::_fft_r2c", trace_only=True) |
| 154 | +def aten__fft_r2c( |
| 155 | + self: TFloat, dim: Sequence[int], normalization: int, onesided: bool |
| 156 | +) -> TFloat: |
| 157 | + """_fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor |
| 158 | +
|
| 159 | + Real to complex forward FFT. |
| 160 | + """ |
| 161 | + |
| 162 | + # Add a new dimension at the end |
| 163 | + signal = op.Unsqueeze(self, axes=[-1]) |
| 164 | + # No need to fill the imaginary part because ONNX DFT accepts real inputs |
| 165 | + # https://onnx.ai/onnx/operators/onnx__DFT.html#inputs |
| 166 | + |
| 167 | + self_rank = len(self.shape) |
| 168 | + # ONNX DFT input assumes the last dimension is the complex dimension. |
| 169 | + # Thus dim=-1 in PyTorch is dim=-2 in ONNX. |
| 170 | + dim = [(d - 1) + self_rank if d < 0 else d for d in dim] |
| 171 | + |
| 172 | + return _fftn_onnx(signal, dim, normalization, inverse=False, onesided=onesided) |
| 173 | + |
| 174 | + |
19 | 175 | def aten_fft_fft(
|
20 | 176 | self: TensorType, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None
|
21 | 177 | ) -> TensorType:
|
|
0 commit comments