Skip to content
31 changes: 26 additions & 5 deletions onnxscript/function_libs/torch_aten/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing import Optional, Sequence

from onnxscript import INT64
from onnxscript import FLOAT, INT64
from onnxscript.function_libs.torch_aten.registration import torch_op
from onnxscript.function_libs.torch_aten.typing import TFloat, TFloatOrBFloat16, TReal
from onnxscript.onnx_opset import opset18 as op
Expand Down Expand Up @@ -198,10 +198,11 @@ def aten_binary_cross_entropy_backward(
raise NotImplementedError()


@torch_op("aten::celu")
def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT:
    # celu(Tensor self, Scalar alpha=1.0) -> Tensor
    #
    # Forwards directly to the ONNX Celu operator.
    #   self:  input tensor; annotated FLOAT because op.Celu only supports float32.
    #   alpha: passed through unchanged as the operator's `alpha` attribute.
    # Returns the output of op.Celu.

    return op.Celu(self, alpha=alpha)


def aten_col2im(
Expand Down Expand Up @@ -323,10 +324,30 @@ def aten_fractional_max_pool3d_backward(
raise NotImplementedError()


@torch_op("aten::gelu")
def aten_gelu(self: TReal, approximate: str = "none") -> TReal:
    # gelu(Tensor self, *, str approximate='none') -> Tensor
    #
    #   self:        input tensor; it is cast to FLOAT before any arithmetic.
    #   approximate: "tanh" selects the tanh approximation; every other value
    #                (including the default "none") selects the erf form.
    # Returns the GELU activation computed on the float-cast input.
    # NOTE(review): the result is not cast back to the dtype of the input —
    # confirm this is what the TReal return annotation intends.

    self = op.Cast(self, to=FLOAT.dtype)

    if approximate == "tanh":
        # GELU(x) = 0.5 * x * {1 + Tanh[\sqrt(2/pi) * (x + 0.044715 * x^3)]}
        cubed = op.Pow(self, 3)
        inner = op.Mul(0.044715, cubed)
        inner = op.Add(self, inner)
        # Scale by sqrt(2/pi), built from graph ops on the two constants.
        inner = op.Mul(op.Sqrt(op.Div(2.0, 3.141592653589793)), inner)
        inner = op.Tanh(inner)
        inner = op.Add(inner, 1)
        inner = op.Mul(self, inner)
        result = op.Mul(0.5, inner)
    else:
        # GELU(x) = 0.5 * x * [1 + ERF(x/sqrt(2))]
        inner = op.Div(self, 1.4142135623730951)  # divisor is sqrt(2)
        erf = op.Erf(inner)
        inner = op.Add(erf, 1)
        inner = op.Mul(self, inner)
        result = op.Mul(0.5, inner)
    return result


def aten_gelu_backward(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,10 @@ def _topk_input_wrangler(
"nn.functional.adaptive_avg_pool1d": nn_ops.aten_adaptive_avg_pool1d,
"nn.functional.adaptive_avg_pool2d": nn_ops.aten_adaptive_avg_pool2d,
"nn.functional.adaptive_avg_pool3d": nn_ops.aten_adaptive_avg_pool3d,
"nn.functional.celu": nn_ops.aten_celu,
"nn.functional.elu": nn_ops.aten_elu,
"nn.functional.embedding": core_ops.aten_embedding,
"nn.functional.gelu": nn_ops.aten_gelu,
"nn.functional.leaky_relu": nn_ops.aten_leaky_relu,
"nn.functional.linear": nn_ops.aten_linear,
"nn.functional.logsigmoid": nn_ops.aten_log_sigmoid,
Expand Down