From 320e822001b7a7916da54108f047ae862f514e1c Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 10 Apr 2025 10:38:55 -0700
Subject: [PATCH 1/2] [torchlib] Precompute the constant for gelu to avoid
 precision loss

---
 onnxscript/function_libs/torch_lib/ops/nn.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 4c32f975d5..441ae462df 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -479,7 +479,6 @@ def aten_gelu(self: TReal, approximate: str = "none") -> TReal:
     return result
 
 
-@torch_op("aten::gelu", private=True)
 def _aten_gelu_approximate_none(self: TReal) -> TReal:
     """gelu(Tensor self, *, str approximate='none') -> Tensor"""
 
@@ -492,7 +491,6 @@ def _aten_gelu_approximate_none(self: TReal) -> TReal:
     return result
 
 
-@torch_op("aten::gelu", private=True)
 def _aten_gelu_approximate_tanh(self: TReal) -> TReal:
     """gelu(Tensor self, *, str approximate='none') -> Tensor"""
 
@@ -500,9 +498,9 @@ def _aten_gelu_approximate_tanh(self: TReal) -> TReal:
     cubed = op.Pow(self, 3)
     inner = op.Mul(0.044715, cubed)
     inner = op.Add(self, inner)
-    # Prefer explicit graph construction over precomputed constants for clarity.
-    two_over_pi = op.CastLike(op.Div(2.0, _MATH_PI), self)
-    inner = op.Mul(op.Sqrt(two_over_pi), inner)
+    # math.sqrt(2.0/math.pi) = 0.7978845608028654
+    sqrt_two_over_pi = op.CastLike(0.7978845608028654, self)
+    inner = op.Mul(sqrt_two_over_pi, inner)
     inner = op.Tanh(inner)
     inner = op.Add(inner, 1)
     inner = op.Mul(0.5, inner)

From 779c2ac24802d35bae0a38980d1cdb8a51bc804a Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 10 Apr 2025 10:44:51 -0700
Subject: [PATCH 2/2] test

---
 onnxscript/function_libs/torch_lib/ops/nn.py    | 16 ++++++++--------
 tests/function_libs/torch_lib/ops_test_data.py |  6 +-----
 2 files changed, 9 insertions(+), 13 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 441ae462df..20127cec88 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -483,10 +483,10 @@ def _aten_gelu_approximate_none(self: TReal) -> TReal:
     """gelu(Tensor self, *, str approximate='none') -> Tensor"""
 
     # GELU(x) = 0.5 * x * [1 + ERF(x/sqrt(2)]
-    inner = op.Div(self, 1.4142135623730951)
+    inner = op.Div(self, ir.tensor(1.4142135623730951, dtype=self.dtype))
     erf = op.Erf(inner)
-    inner = op.Add(erf, 1)
-    inner = op.Mul(0.5, inner)
+    inner = op.Add(erf, ir.tensor(1, dtype=self.dtype))
+    inner = op.Mul(ir.tensor(0.5, dtype=self.dtype), inner)
     result = op.Mul(self, inner)
     return result
 
@@ -495,15 +495,15 @@ def _aten_gelu_approximate_tanh(self: TReal) -> TReal:
     """gelu(Tensor self, *, str approximate='none') -> Tensor"""
 
    # GELU(x) = 0.5 * x * {1 + Tanh[\sqrt(2/pi) * (x + 0.044715 * x^3)]}
-    cubed = op.Pow(self, 3)
-    inner = op.Mul(0.044715, cubed)
+    cubed = op.Pow(self, ir.tensor(3, dtype=self.dtype))
+    inner = op.Mul(ir.tensor(0.044715, dtype=self.dtype), cubed)
     inner = op.Add(self, inner)
     # math.sqrt(2.0/math.pi) = 0.7978845608028654
-    sqrt_two_over_pi = op.CastLike(0.7978845608028654, self)
+    sqrt_two_over_pi = ir.tensor(0.7978845608028654, dtype=self.dtype)
     inner = op.Mul(sqrt_two_over_pi, inner)
     inner = op.Tanh(inner)
-    inner = op.Add(inner, 1)
-    inner = op.Mul(0.5, inner)
+    inner = op.Add(inner, ir.tensor(1, dtype=self.dtype))
+    inner = op.Mul(ir.tensor(0.5, dtype=self.dtype), inner)
     result = op.Mul(self, inner)
     return result
 
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 4066cb12f1..54e1e8cceb 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1790,11 +1790,7 @@ def _where_input_wrangler(
         core_ops.aten_conv3d,
         tolerance={torch.float32: (3.7e-5, 1.8e-4)},
     ),
-    TorchLibOpInfo(
-        "nn.functional.gelu",
-        nn_ops.aten_gelu,
-        tolerance={torch.float16: (8e-2, 1e-4)},
-    ),
+    TorchLibOpInfo("nn.functional.gelu", nn_ops.aten_gelu),
     TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu),
     TorchLibOpInfo(
         "nn.functional.linear", nn_ops.aten_linear, tolerance={torch.float16: (1e-2, 1e-3)}
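
Note (not part of the patches): a minimal sketch checking the constants these commits bake into the gelu graphs, and illustrating the dtype-matching that ir.tensor(..., dtype=self.dtype) expresses. It assumes only the Python standard library and numpy; the numpy usage is illustrative and not taken from the patch.

    import math

    import numpy as np

    # Constant used by _aten_gelu_approximate_none: x / sqrt(2).
    assert math.sqrt(2.0) == 1.4142135623730951

    # Constant used by _aten_gelu_approximate_tanh. Precomputing rounds
    # the exact value once, instead of rounding after each of the
    # Div(2.0, pi) and Sqrt ops at the tensor's dtype, which is the
    # precision loss the first commit's subject refers to.
    assert math.sqrt(2.0 / math.pi) == 0.7978845608028654

    # The second commit materializes each constant directly at the
    # input's dtype rather than relying on CastLike at runtime; e.g. the
    # float16 rounding of the sqrt(2/pi) constant:
    print(np.float16(0.7978845608028654))  # ~0.798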