@@ -479,33 +479,31 @@ def aten_gelu(self: TReal, approximate: str = "none") -> TReal:
479
479
return result
480
480
481
481
482
- @torch_op ("aten::gelu" , private = True )
483
482
def _aten_gelu_approximate_none (self : TReal ) -> TReal :
484
483
"""gelu(Tensor self, *, str approximate='none') -> Tensor"""
485
484
486
485
# GELU(x) = 0.5 * x * [1 + ERF(x/sqrt(2)]
487
- inner = op .Div (self , 1.4142135623730951 )
486
+ inner = op .Div (self , ir . tensor ( 1.4142135623730951 , dtype = self . dtype ) )
488
487
erf = op .Erf (inner )
489
- inner = op .Add (erf , 1 )
490
- inner = op .Mul (0.5 , inner )
488
+ inner = op .Add (erf , ir . tensor ( 1 , dtype = self . dtype ) )
489
+ inner = op .Mul (ir . tensor ( 0.5 , dtype = self . dtype ) , inner )
491
490
result = op .Mul (self , inner )
492
491
return result
493
492
494
493
495
- @torch_op ("aten::gelu" , private = True )
496
494
def _aten_gelu_approximate_tanh (self : TReal ) -> TReal :
497
495
"""gelu(Tensor self, *, str approximate='none') -> Tensor"""
498
496
499
497
# GELU(x) = 0.5 * x * {1 + Tanh[\sqrt(2/pi) * (x + 0.044715 * x^3)]}
500
- cubed = op .Pow (self , 3 )
501
- inner = op .Mul (0.044715 , cubed )
498
+ cubed = op .Pow (self , ir . tensor ( 3 , dtype = self . dtype ) )
499
+ inner = op .Mul (ir . tensor ( 0.044715 , dtype = self . dtype ) , cubed )
502
500
inner = op .Add (self , inner )
503
- # Prefer explicit graph construction over precomputed constants for clarity.
504
- two_over_pi = op . CastLike ( op . Div ( 2.0 , _MATH_PI ), self )
505
- inner = op .Mul (op . Sqrt ( two_over_pi ) , inner )
501
+ # math.sqrt(2.0/math.pi) = 0.7978845608028654
502
+ sqrt_two_over_pi = ir . tensor ( 0.7978845608028654 , dtype = self . dtype )
503
+ inner = op .Mul (sqrt_two_over_pi , inner )
506
504
inner = op .Tanh (inner )
507
- inner = op .Add (inner , 1 )
508
- inner = op .Mul (0.5 , inner )
505
+ inner = op .Add (inner , ir . tensor ( 1 , dtype = self . dtype ) )
506
+ inner = op .Mul (ir . tensor ( 0.5 , dtype = self . dtype ) , inner )
509
507
result = op .Mul (self , inner )
510
508
return result
511
509
0 commit comments