
Commit 34359e8

feat(atenlib): add ops(gelu, celu) (#323)
Add GELU and CELU ops.
1 parent: 000d657

2 files changed (+28, -5 lines)


onnxscript/function_libs/torch_aten/ops/nn.py

Lines changed: 26 additions & 5 deletions
@@ -16,7 +16,7 @@

 from typing import Optional, Sequence

-from onnxscript import INT64
+from onnxscript import FLOAT, INT64
 from onnxscript.function_libs.torch_aten.registration import torch_op
 from onnxscript.function_libs.torch_aten.typing import TFloat, TFloatOrBFloat16, TReal
 from onnxscript.onnx_opset import opset18 as op
@@ -198,10 +198,11 @@ def aten_binary_cross_entropy_backward(
     raise NotImplementedError()


-def aten_celu(self, alpha: float = 1.0):
+@torch_op("aten::celu")
+def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT:
     # celu(Tensor self, Scalar alpha=1.0) -> Tensor

-    raise NotImplementedError()
+    return op.Celu(self, alpha=alpha)  # op.Celu only supports float32


 def aten_col2im(
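
For reference, ONNX Celu computes CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)), which is why the ATen alpha attribute can be forwarded directly. A NumPy sketch of the expected behavior (illustrative only, not part of this commit):

import numpy as np

def celu_reference(x: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    # CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1));
    # cast to float32 to mirror op.Celu's float32-only support.
    x = x.astype(np.float32)
    return np.maximum(0.0, x) + np.minimum(0.0, alpha * (np.exp(x / alpha) - 1.0))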
@@ -323,10 +324,30 @@ def aten_fractional_max_pool3d_backward(
     raise NotImplementedError()


-def aten_gelu(self: TensorType, approximate: str = "none") -> TensorType:
+@torch_op("aten::gelu")
+def aten_gelu(self: TReal, approximate: str = "none") -> TReal:
     # gelu(Tensor self, *, str approximate='none') -> Tensor

-    raise NotImplementedError()
+    self = op.Cast(self, to=FLOAT.dtype)
+
+    if approximate == "tanh":
+        # GELU(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
+        cubed = op.Pow(self, 3)
+        inner = op.Mul(0.044715, cubed)
+        inner = op.Add(self, inner)
+        inner = op.Mul(op.Sqrt(op.Div(2.0, 3.141592653589793)), inner)
+        inner = op.Tanh(inner)
+        inner = op.Add(inner, 1)
+        inner = op.Mul(self, inner)
+        result = op.Mul(0.5, inner)
+    else:
+        # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+        inner = op.Div(self, 1.4142135623730951)
+        erf = op.Erf(inner)
+        inner = op.Add(erf, 1)
+        inner = op.Mul(self, inner)
+        result = op.Mul(0.5, inner)
+    return result


 def aten_gelu_backward(
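
The two branches mirror PyTorch's two GELU modes: the exact form 0.5 * x * (1 + erf(x / sqrt(2))) and the tanh approximation. A NumPy reference sketch of both (illustrative only; assumes SciPy is available for erf, which the commit itself does not use):

import numpy as np
from scipy.special import erf  # assumption: SciPy supplies the reference erf

def gelu_reference(x: np.ndarray, approximate: str = "none") -> np.ndarray:
    x = x.astype(np.float32)
    if approximate == "tanh":
        # 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
        return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))
    # Exact form: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))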

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 2 additions & 0 deletions
@@ -313,8 +313,10 @@ def _topk_input_wrangler(
     "nn.functional.adaptive_avg_pool1d": nn_ops.aten_adaptive_avg_pool1d,
     "nn.functional.adaptive_avg_pool2d": nn_ops.aten_adaptive_avg_pool2d,
     "nn.functional.adaptive_avg_pool3d": nn_ops.aten_adaptive_avg_pool3d,
+    "nn.functional.celu": nn_ops.aten_celu,
     "nn.functional.elu": nn_ops.aten_elu,
     "nn.functional.embedding": core_ops.aten_embedding,
+    "nn.functional.gelu": nn_ops.aten_gelu,
     "nn.functional.leaky_relu": nn_ops.aten_leaky_relu,
     "nn.functional.linear": nn_ops.aten_linear,
     "nn.functional.logsigmoid": nn_ops.aten_log_sigmoid,
