|
16 | 16 |
|
17 | 17 | from typing import Optional, Sequence
|
18 | 18 |
|
19 |
| -from onnxscript import INT64 |
| 19 | +from onnxscript import FLOAT, INT64 |
20 | 20 | from onnxscript.function_libs.torch_aten.registration import torch_op
|
21 | 21 | from onnxscript.function_libs.torch_aten.typing import TFloat, TFloatOrBFloat16, TReal
|
22 | 22 | from onnxscript.onnx_opset import opset18 as op
|
@@ -198,10 +198,11 @@ def aten_binary_cross_entropy_backward(
|
198 | 198 | raise NotImplementedError()
|
199 | 199 |
|
200 | 200 |
|
@torch_op("aten::celu")
def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT:
    # celu(Tensor self, Scalar alpha=1.0) -> Tensor

    # CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)).
    # The input/output are annotated as FLOAT because the ONNX Celu
    # operator is only defined for float32 tensors.
    result = op.Celu(self, alpha=alpha)
    return result
205 | 206 |
|
206 | 207 |
|
207 | 208 | def aten_col2im(
|
@@ -323,10 +324,30 @@ def aten_fractional_max_pool3d_backward(
|
323 | 324 | raise NotImplementedError()
|
324 | 325 |
|
325 | 326 |
|
@torch_op("aten::gelu")
def aten_gelu(self: TReal, approximate: str = "none") -> TReal:
    # gelu(Tensor self, *, str approximate='none') -> Tensor

    # op.Erf / op.Tanh require a floating-point input, so compute in float32.
    # NOTE(review): this means the result is float32 regardless of the input
    # dtype (e.g. float64 input comes back as float32) — confirm callers
    # expect this, or cast back to the input dtype.
    self = op.Cast(self, to=FLOAT.dtype)

    if approximate == "tanh":
        # Tanh approximation:
        # GELU(x) = 0.5 * x * {1 + Tanh[sqrt(2/pi) * (x + 0.044715 * x^3)]}
        cubed = op.Pow(self, 3)
        inner = op.Mul(0.044715, cubed)
        inner = op.Add(self, inner)
        # sqrt(2/pi) folded to a literal at trace time; the previous
        # op.Sqrt(op.Div(2.0, pi)) emitted two extra runtime graph nodes
        # just to compute this constant.
        inner = op.Mul(0.7978845608028654, inner)
        inner = op.Tanh(inner)
        inner = op.Add(inner, 1)
        inner = op.Mul(self, inner)
        result = op.Mul(0.5, inner)
    else:
        # Exact form:
        # GELU(x) = 0.5 * x * [1 + Erf(x / sqrt(2))]
        inner = op.Div(self, 1.4142135623730951)  # x / sqrt(2)
        erf = op.Erf(inner)
        inner = op.Add(erf, 1)
        inner = op.Mul(self, inner)
        result = op.Mul(0.5, inner)
    return result
330 | 351 |
|
331 | 352 |
|
332 | 353 | def aten_gelu_backward(
|
|
0 commit comments