Skip to content

Commit 005568a

Browse files
authored
[torchlib] Precompute the constant for gelu to avoid precision loss (#2179)
I think this improves accuracy for gelu under float16.
1 parent 8f71f1a commit 005568a

File tree

2 files changed

+11
-17
lines changed

2 files changed

+11
-17
lines changed

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -479,33 +479,31 @@ def aten_gelu(self: TReal, approximate: str = "none") -> TReal:
479479
return result
480480

481481

482-
@torch_op("aten::gelu", private=True)
483482
def _aten_gelu_approximate_none(self: TReal) -> TReal:
484483
"""gelu(Tensor self, *, str approximate='none') -> Tensor"""
485484

486485
# GELU(x) = 0.5 * x * [1 + ERF(x/sqrt(2))]
487-
inner = op.Div(self, 1.4142135623730951)
486+
inner = op.Div(self, ir.tensor(1.4142135623730951, dtype=self.dtype))
488487
erf = op.Erf(inner)
489-
inner = op.Add(erf, 1)
490-
inner = op.Mul(0.5, inner)
488+
inner = op.Add(erf, ir.tensor(1, dtype=self.dtype))
489+
inner = op.Mul(ir.tensor(0.5, dtype=self.dtype), inner)
491490
result = op.Mul(self, inner)
492491
return result
493492

494493

495-
@torch_op("aten::gelu", private=True)
496494
def _aten_gelu_approximate_tanh(self: TReal) -> TReal:
497495
"""gelu(Tensor self, *, str approximate='none') -> Tensor"""
498496

499497
# GELU(x) = 0.5 * x * {1 + Tanh[\sqrt(2/pi) * (x + 0.044715 * x^3)]}
500-
cubed = op.Pow(self, 3)
501-
inner = op.Mul(0.044715, cubed)
498+
cubed = op.Pow(self, ir.tensor(3, dtype=self.dtype))
499+
inner = op.Mul(ir.tensor(0.044715, dtype=self.dtype), cubed)
502500
inner = op.Add(self, inner)
503-
# Prefer explicit graph construction over precomputed constants for clarity.
504-
two_over_pi = op.CastLike(op.Div(2.0, _MATH_PI), self)
505-
inner = op.Mul(op.Sqrt(two_over_pi), inner)
501+
# math.sqrt(2.0/math.pi) = 0.7978845608028654
502+
sqrt_two_over_pi = ir.tensor(0.7978845608028654, dtype=self.dtype)
503+
inner = op.Mul(sqrt_two_over_pi, inner)
506504
inner = op.Tanh(inner)
507-
inner = op.Add(inner, 1)
508-
inner = op.Mul(0.5, inner)
505+
inner = op.Add(inner, ir.tensor(1, dtype=self.dtype))
506+
inner = op.Mul(ir.tensor(0.5, dtype=self.dtype), inner)
509507
result = op.Mul(self, inner)
510508
return result
511509

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1790,11 +1790,7 @@ def _where_input_wrangler(
17901790
core_ops.aten_conv3d,
17911791
tolerance={torch.float32: (3.7e-5, 1.8e-4)},
17921792
),
1793-
TorchLibOpInfo(
1794-
"nn.functional.gelu",
1795-
nn_ops.aten_gelu,
1796-
tolerance={torch.float16: (8e-2, 1e-4)},
1797-
),
1793+
TorchLibOpInfo("nn.functional.gelu", nn_ops.aten_gelu),
17981794
TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu),
17991795
TorchLibOpInfo(
18001796
"nn.functional.linear", nn_ops.aten_linear, tolerance={torch.float16: (1e-2, 1e-3)}

0 commit comments

Comments
 (0)