From 774b989b2f54612ab26888dda8735d16ae395aa2 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Mon, 16 Jan 2023 16:49:30 +0800 Subject: [PATCH 01/13] add gelu --- onnxscript/function_libs/torch_aten/ops/nn.py | 27 +++++++++++++++++-- onnxscript/values.py | 2 ++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py index 887123c5c3..5822192c6c 100644 --- a/onnxscript/function_libs/torch_aten/ops/nn.py +++ b/onnxscript/function_libs/torch_aten/ops/nn.py @@ -16,7 +16,7 @@ from typing import Optional, Sequence -from onnxscript import INT64 +from onnxscript import INT64, DOUBLE, FLOAT from onnxscript.function_libs.torch_aten.registration import torch_op from onnxscript.function_libs.torch_aten.typing import TFloat, TFloatOrBFloat16, TReal from onnxscript.onnx_opset import opset18 as op @@ -323,10 +323,33 @@ def aten_fractional_max_pool3d_backward( raise NotImplementedError() +@torch_op("aten::gelu") def aten_gelu(self: TensorType, approximate: str = "none") -> TensorType: # gelu(Tensor self, *, str approximate='none') -> Tensor - raise NotImplementedError() + self = op.Cast(self, to=FLOAT.dtype) + + if approximate == "tanh": + # GELU(x) = 0.5 * x * {1 + Tanh[\sqrt(2/pi) * (x + 0.044715 * x^3)]} + inner1 = op.Div(2.0, 3.141592653589793) + inner1 = op.Sqrt(inner1) + #inner1 = op.Cast(inner1, to=DOUBLE.dtype) + self_cube = op.Pow(self, 3) + inner = op.Mul(0.044715, self_cube) + inner = op.Add(self, inner) + inner = op.Mul(inner1, inner) + inner = op.Tanh(inner) + inner = op.Add(inner, 1) + inner = op.Mul(self, inner) + result = op.Mul(0.5, inner) + else: + # GELU(x) = 0.5 * x * [1 + ERF(x/sqrt(2)] + inner = op.Div(self, 1.4142135623730951) + erf = op.Erf(inner) + inner = op.Add(erf, 1) + inner = op.Mul(self, inner) + result = op.Mul(0.5, inner) + return result def aten_gelu_backward( diff --git a/onnxscript/values.py b/onnxscript/values.py index 87e55dfae6..27d3c956f9 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -213,6 +213,8 @@ def adapt(input: ExtendedModeValue) -> EagerModeValue: return input elif isinstance(input, (bool, int, float)): return tensor.Tensor(np.array(input)) + elif isinstance(input, str): + return input elif input is None: return None elif isinstance(input, list): From 8480e10e1983e783bccb724fcc098f3910d0ab39 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Mon, 16 Jan 2023 16:52:27 +0800 Subject: [PATCH 02/13] Update ops_correctness_test.py --- onnxscript/test/function_libs/torch_aten/ops_correctness_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 48980f578a..06ef045722 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -287,6 +287,7 @@ def _topk_input_wrangler( "full": (core_ops.aten_full, _full_input_wrangler), "full_like": core_ops.aten_full_like, "ge": core_ops.aten_ge, + "gelu": nn_ops.aten_gelu, "gt": core_ops.aten_gt, "isinf": core_ops.aten_isinf, "log": core_ops.aten_log, From e64ced36f17ae7c42b3a59ab27d4603cfd483505 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Mon, 16 Jan 2023 17:16:08 +0800 Subject: [PATCH 03/13] fix lint --- onnxscript/function_libs/torch_aten/ops/nn.py | 4 +--- .../test/function_libs/torch_aten/ops_correctness_test.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py index 5822192c6c..68e2bac265 100644 --- a/onnxscript/function_libs/torch_aten/ops/nn.py +++ b/onnxscript/function_libs/torch_aten/ops/nn.py @@ -16,7 +16,7 @@ from typing import Optional, Sequence -from onnxscript import INT64, DOUBLE, FLOAT +from onnxscript import FLOAT, INT64 from onnxscript.function_libs.torch_aten.registration import torch_op from onnxscript.function_libs.torch_aten.typing import TFloat, TFloatOrBFloat16, TReal from onnxscript.onnx_opset import opset18 as op @@ -333,7 +333,6 @@ def aten_gelu(self: TensorType, approximate: str = "none") -> TensorType: # GELU(x) = 0.5 * x * {1 + Tanh[\sqrt(2/pi) * (x + 0.044715 * x^3)]} inner1 = op.Div(2.0, 3.141592653589793) inner1 = op.Sqrt(inner1) - #inner1 = op.Cast(inner1, to=DOUBLE.dtype) self_cube = op.Pow(self, 3) inner = op.Mul(0.044715, self_cube) inner = op.Add(self, inner) @@ -343,7 +342,6 @@ def aten_gelu(self: TensorType, approximate: str = "none") -> TensorType: inner = op.Mul(self, inner) result = op.Mul(0.5, inner) else: - # GELU(x) = 0.5 * x * [1 + ERF(x/sqrt(2)] inner = op.Div(self, 1.4142135623730951) erf = op.Erf(inner) inner = op.Add(erf, 1) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 06ef045722..8d0663197e 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -287,7 +287,6 @@ def _topk_input_wrangler( "full": (core_ops.aten_full, _full_input_wrangler), "full_like": core_ops.aten_full_like, "ge": core_ops.aten_ge, - "gelu": nn_ops.aten_gelu, "gt": core_ops.aten_gt, "isinf": core_ops.aten_isinf, "log": core_ops.aten_log, @@ -315,6 +314,7 @@ def _topk_input_wrangler( "nn.functional.adaptive_avg_pool3d": nn_ops.aten_adaptive_avg_pool3d, "nn.functional.elu": nn_ops.aten_elu, "nn.functional.embedding": core_ops.aten_embedding, + "nn.functional.gelu": nn_ops.aten_gelu, "nn.functional.leaky_relu": nn_ops.aten_leaky_relu, "nn.functional.linear": nn_ops.aten_linear, "nn.functional.logsigmoid": nn_ops.aten_log_sigmoid, From b952cd47151763a992e94342056d33b03c41ff11 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Mon, 16 Jan 2023 17:17:29 +0800 Subject: [PATCH 04/13] Update nn.py --- onnxscript/function_libs/torch_aten/ops/nn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py index 68e2bac265..90f449481f 100644 --- a/onnxscript/function_libs/torch_aten/ops/nn.py +++ b/onnxscript/function_libs/torch_aten/ops/nn.py @@ -342,6 +342,7 @@ def aten_gelu(self: TensorType, approximate: str = "none") -> TensorType: inner = op.Mul(self, inner) result = op.Mul(0.5, inner) else: + # GELU(x) = 0.5 * x * [1 + ERF(x/sqrt(2)] inner = op.Div(self, 1.4142135623730951) erf = op.Erf(inner) inner = op.Add(erf, 1) From 0b03b525a6f287316f38997f35c5bcae9a5a453a Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Mon, 16 Jan 2023 17:25:54 +0800 Subject: [PATCH 05/13] add celu --- onnxscript/function_libs/torch_aten/ops/nn.py | 7 ++++--- .../test/function_libs/torch_aten/ops_correctness_test.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py index 90f449481f..629389e2ca 100644 --- a/onnxscript/function_libs/torch_aten/ops/nn.py +++ b/onnxscript/function_libs/torch_aten/ops/nn.py @@ -198,10 +198,11 @@ def aten_binary_cross_entropy_backward( raise NotImplementedError() -def aten_celu(self, alpha: float = 1.0): +@torch_op("aten::celu") +def aten_celu(self: TReal, alpha: float = 1.0) -> TReal: # celu(Tensor self, Scalar alpha=1.0) -> Tensor - raise NotImplementedError() + return op.Celu(self, alpha=alpha) def aten_col2im( @@ -324,7 +325,7 @@ def aten_fractional_max_pool3d_backward( @torch_op("aten::gelu") -def aten_gelu(self: TensorType, approximate: str = "none") -> TensorType: +def aten_gelu(self: TReal, approximate: str = "none") -> TReal: # gelu(Tensor self, *, str approximate='none') -> Tensor self = op.Cast(self, to=FLOAT.dtype) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 8d0663197e..acb023e64d 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -312,6 +312,7 @@ def _topk_input_wrangler( "nn.functional.adaptive_avg_pool1d": nn_ops.aten_adaptive_avg_pool1d, "nn.functional.adaptive_avg_pool2d": nn_ops.aten_adaptive_avg_pool2d, "nn.functional.adaptive_avg_pool3d": nn_ops.aten_adaptive_avg_pool3d, + "nn.functional.celu": nn_ops.aten_celu, "nn.functional.elu": nn_ops.aten_elu, "nn.functional.embedding": core_ops.aten_embedding, "nn.functional.gelu": nn_ops.aten_gelu, From 07223f11c179fc128f69a3e440ca10635334097f Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Mon, 16 Jan 2023 17:26:41 +0800 Subject: [PATCH 06/13] Update nn.py --- onnxscript/function_libs/torch_aten/ops/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py index 629389e2ca..90a05feebe 100644 --- a/onnxscript/function_libs/torch_aten/ops/nn.py +++ b/onnxscript/function_libs/torch_aten/ops/nn.py @@ -199,7 +199,7 @@ def aten_binary_cross_entropy_backward( @torch_op("aten::celu") -def aten_celu(self: TReal, alpha: float = 1.0) -> TReal: +def aten_celu(self: TFloat, alpha: float = 1.0) -> TFloat: # celu(Tensor self, Scalar alpha=1.0) -> Tensor return op.Celu(self, alpha=alpha) From 5bcc86d1a03ab31ca73393c022826d33b29514cf Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Mon, 16 Jan 2023 18:51:38 +0800 Subject: [PATCH 07/13] Update nn.py --- onnxscript/function_libs/torch_aten/ops/nn.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py index 90a05feebe..1063071f9f 100644 --- a/onnxscript/function_libs/torch_aten/ops/nn.py +++ b/onnxscript/function_libs/torch_aten/ops/nn.py @@ -199,10 +199,13 @@ def aten_binary_cross_entropy_backward( @torch_op("aten::celu") -def aten_celu(self: TFloat, alpha: float = 1.0) -> TFloat: +def aten_celu(self: TFloat, alpha: float = 1.0, dtype = FLOAT.dtype) -> TFloat: # celu(Tensor self, Scalar alpha=1.0) -> Tensor - return op.Celu(self, alpha=alpha) + self = op.Cast(self, to=FLOAT.dtype) + result = op.Celu(self, alpha=alpha) + result = op.Cast(result, to=dtype) + return result def aten_col2im( From 78d4c39c08b2814f0ed95e19820ca33928ac92e7 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Mon, 16 Jan 2023 18:52:00 +0800 Subject: [PATCH 08/13] Update nn.py --- onnxscript/function_libs/torch_aten/ops/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py index 1063071f9f..f882915968 100644 --- a/onnxscript/function_libs/torch_aten/ops/nn.py +++ b/onnxscript/function_libs/torch_aten/ops/nn.py @@ -203,7 +203,7 @@ def aten_celu(self: TFloat, alpha: float = 1.0, dtype = FLOAT.dtype) -> TFloat: # celu(Tensor self, Scalar alpha=1.0) -> Tensor self = op.Cast(self, to=FLOAT.dtype) - result = op.Celu(self, alpha=alpha) + result = op.Celu(self, alpha=alpha) # op.Celu only support float32 result = op.Cast(result, to=dtype) return result From 086d01a55e0363f34518070800463c73b67f6770 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Mon, 16 Jan 2023 19:13:25 +0800 Subject: [PATCH 09/13] Update nn.py --- onnxscript/function_libs/torch_aten/ops/nn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py index f882915968..3cec97d928 100644 --- a/onnxscript/function_libs/torch_aten/ops/nn.py +++ b/onnxscript/function_libs/torch_aten/ops/nn.py @@ -199,12 +199,13 @@ def aten_binary_cross_entropy_backward( @torch_op("aten::celu") -def aten_celu(self: TFloat, alpha: float = 1.0, dtype = FLOAT.dtype) -> TFloat: +def aten_celu(self: TFloat, alpha: float = 1.0, dtype: int = -1) -> TFloat: # celu(Tensor self, Scalar alpha=1.0) -> Tensor self = op.Cast(self, to=FLOAT.dtype) result = op.Celu(self, alpha=alpha) # op.Celu only support float32 - result = op.Cast(result, to=dtype) + if dtype != -1: + result = op.Cast(result, to=dtype) return result From cb978df094c80fd00762f4ff006b23eee9ac3987 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Tue, 17 Jan 2023 15:00:16 +0800 Subject: [PATCH 10/13] Update nn.py --- onnxscript/function_libs/torch_aten/ops/nn.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py index 3cec97d928..511218e2b2 100644 --- a/onnxscript/function_libs/torch_aten/ops/nn.py +++ b/onnxscript/function_libs/torch_aten/ops/nn.py @@ -199,14 +199,10 @@ def aten_binary_cross_entropy_backward( @torch_op("aten::celu") -def aten_celu(self: TFloat, alpha: float = 1.0, dtype: int = -1) -> TFloat: +def aten_celu(self: TFloat, alpha: float = 1.0) -> TFloat: # celu(Tensor self, Scalar alpha=1.0) -> Tensor - self = op.Cast(self, to=FLOAT.dtype) - result = op.Celu(self, alpha=alpha) # op.Celu only support float32 - if dtype != -1: - result = op.Cast(result, to=dtype) - return result + return op.Celu(self, alpha=alpha) # op.Celu only support float32 def aten_col2im( From ccdfe2e6c9b9b56d7a97ad1ca2efebac24bdcc51 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Tue, 17 Jan 2023 15:33:38 +0800 Subject: [PATCH 11/13] Update nn.py --- onnxscript/function_libs/torch_aten/ops/nn.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py index 511218e2b2..ff720a0ce2 100644 --- a/onnxscript/function_libs/torch_aten/ops/nn.py +++ b/onnxscript/function_libs/torch_aten/ops/nn.py @@ -332,12 +332,10 @@ def aten_gelu(self: TReal, approximate: str = "none") -> TReal: if approximate == "tanh": # GELU(x) = 0.5 * x * {1 + Tanh[\sqrt(2/pi) * (x + 0.044715 * x^3)]} - inner1 = op.Div(2.0, 3.141592653589793) - inner1 = op.Sqrt(inner1) self_cube = op.Pow(self, 3) inner = op.Mul(0.044715, self_cube) inner = op.Add(self, inner) - inner = op.Mul(inner1, inner) + inner = op.Mul(op.Sqrt(op.Div(2.0, 3.141592653589793)), inner) inner = op.Tanh(inner) inner = op.Add(inner, 1) inner = op.Mul(self, inner) From 4a5da7bb3f3b6b8d5c1b61cbef1db250674ab245 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Wed, 18 Jan 2023 08:44:00 +0800 Subject: [PATCH 12/13] Update nn.py --- onnxscript/function_libs/torch_aten/ops/nn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py index ff720a0ce2..088cfec3b0 100644 --- a/onnxscript/function_libs/torch_aten/ops/nn.py +++ b/onnxscript/function_libs/torch_aten/ops/nn.py @@ -199,7 +199,7 @@ def aten_binary_cross_entropy_backward( @torch_op("aten::celu") -def aten_celu(self: TFloat, alpha: float = 1.0) -> TFloat: +def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT: # celu(Tensor self, Scalar alpha=1.0) -> Tensor return op.Celu(self, alpha=alpha) # op.Celu only support float32 @@ -332,8 +332,8 @@ def aten_gelu(self: TReal, approximate: str = "none") -> TReal: if approximate == "tanh": # GELU(x) = 0.5 * x * {1 + Tanh[\sqrt(2/pi) * (x + 0.044715 * x^3)]} - self_cube = op.Pow(self, 3) - inner = op.Mul(0.044715, self_cube) + cubed = op.Pow(self, 3) + inner = op.Mul(0.044715, cubed) inner = op.Add(self, inner) inner = op.Mul(op.Sqrt(op.Div(2.0, 3.141592653589793)), inner) inner = op.Tanh(inner) From 18619e238233951be1843f5adcfff13bca5703a8 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Wed, 18 Jan 2023 13:32:30 +0800 Subject: [PATCH 13/13] Update values.py --- onnxscript/values.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxscript/values.py b/onnxscript/values.py index 27d3c956f9..87e55dfae6 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -213,8 +213,6 @@ def adapt(input: ExtendedModeValue) -> EagerModeValue: return input elif isinstance(input, (bool, int, float)): return tensor.Tensor(np.array(input)) - elif isinstance(input, str): - return input elif input is None: return None elif isinstance(input, list):