From 00071ac64c04a814eed00924ff4ccbf891f390f9 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Fri, 10 Feb 2023 18:57:13 +0800
Subject: [PATCH 01/10] add native_code

---
 onnxscript/function_libs/torch_aten/ops/core.py        | 9 +++++++--
 .../function_libs/torch_aten/ops_correctness_test.py   | 6 +++++-
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 25b001be8e..e622c858d3 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -3636,6 +3636,7 @@ def aten_narrow_copy(self: TensorType, dim: int, start: INT64, length: INT64) ->
     raise NotImplementedError()
 
 
+@torch_op("aten::native_batch_norm", trace_only=True)
 def aten_native_batch_norm(
     input: TensorType,
     weight: Optional[TensorType],
@@ -3648,8 +3649,12 @@ def aten_native_batch_norm(
 ) -> tuple[TensorType, TensorType, TensorType]:
     # native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
 
-    raise NotImplementedError()
-
+    mode = 0 if not training else 1
+    a = opset17.ReduceMean(input, axes=[1,2])
+    result = op.BatchNormalization(input, weight, bias, running_mean, running_var, epsilon=eps, momentum=momentum, training_mode=mode)
+    if not training:
+        result = result, [], []
+    return result
 
 def aten_native_batch_norm_backward(
     grad_out: TensorType,
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index 52116197d1..34d4846ecf 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -304,6 +304,7 @@ def _where_input_wrangler(
     "minimum": core_ops.aten_minimum,
     "mm": core_ops.aten_mm,
     "mul": core_ops.aten_mul,
+    "native_batch_norm": core_ops.aten_native_batch_norm,
     "ne": core_ops.aten_ne,
     "neg": core_ops.aten_neg,
     "new_full": core_ops.aten_new_full,
@@ -692,7 +693,10 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
         flattened_function_outputs, _ = pytree.tree_flatten(function_output)
 
         assert flattened_torch_outputs
-        assert len(flattened_torch_outputs) == len(flattened_function_outputs)
+        if op.name == 'native_batch_norm' and not cpu_sample.args[4]:
+            assert(len(flattened_function_outputs) == 1)
+        else:
+            assert len(flattened_torch_outputs) == len(flattened_function_outputs)
 
         for torch_output, function_output in zip(
             flattened_torch_outputs, flattened_function_outputs
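
(Illustration, not part of the patch series.) The test tweak above exists because ONNX BatchNormalization produces a single output in inference mode, while torch's native_batch_norm always returns three tensors. A rough NumPy sketch of the inference-mode formula that the BatchNormalization call evaluates, assuming the usual (N, C, ...) layout:

    import numpy as np

    def batch_norm_inference(x, weight, bias, running_mean, running_var, eps=1e-5):
        # Broadcast the per-channel statistics over axis 1 and normalize.
        shape = (1, -1) + (1,) * (x.ndim - 2)
        x_hat = (x - running_mean.reshape(shape)) / np.sqrt(running_var.reshape(shape) + eps)
        return x_hat * weight.reshape(shape) + bias.reshape(shape)

    x = np.random.randn(2, 3, 4, 4).astype(np.float32)
    ones, zeros = np.ones(3, np.float32), np.zeros(3, np.float32)
    print(batch_norm_inference(x, ones, zeros, zeros, ones).shape)  # (2, 3, 4, 4)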

From 38e84835ec51ef007ff9c9f03d06e434eec0c6ba Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Mon, 13 Feb 2023 17:40:21 +0800
Subject: [PATCH 02/10] add ops

---
 .../function_libs/torch_aten/ops/core.py       | 24 +++++++++++++++----
 .../torch_aten/ops_correctness_test.py         |  3 +++
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 87091a4059..1214bf21d2 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -1330,10 +1330,12 @@ def aten_convolution_overrideable(
     raise NotImplementedError()
 
 
+@torch_op("aten::copy")
 def aten_copy(self: TensorType, src: TensorType, non_blocking: bool = False) -> TensorType:
     # copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
 
-    raise NotImplementedError()
+    self = op.Identity(src)
+    return self
 
 
 def aten_copysign(self: TensorType, other: TensorType) -> TensorType:
@@ -1880,10 +1882,15 @@ def aten_empty_quantized(
     raise NotImplementedError()
 
 
+@torch_op("aten::empty_strided")
 def aten_empty_strided(size: INT64, stride: INT64) -> TensorType:
     # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-    raise NotImplementedError()
+    # using Zeros to simulate empty()
+    size = op.Cast(size, to=INT64.dtype)
+    zero = op.Constant(value_float=0.0)
+
+    return op.Expand(zero, size)
 
 
 @torch_op("aten::eq")
@@ -2104,10 +2111,14 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType
     raise NotImplementedError()
 
 
+@torch_op("aten::fill")
 def aten_fill(self: TensorType, value: TensorType) -> TensorType:
     # fill.Tensor(Tensor self, Tensor value) -> Tensor
 
-    raise NotImplementedError()
+    shape = op.Shape(self)
+    value = op.Cast(value, to=FLOAT.dtype)  # the value might be bool type
+    result = op.Expand(value, shape)
+    return result
 
 
 def aten_fix(self: TensorType) -> TensorType:
@@ -3735,12 +3746,17 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::native_dropout", trace_only=True)
 def aten_native_dropout(
     input: TensorType, p: float, train: Optional[bool]
 ) -> tuple[TensorType, TensorType]:
     # native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
 
-    raise NotImplementedError()
+    if not train:
+        result = input, None
+    else:
+        result = op.Dropout(input, p, train)
+    return result
 
 
 def aten_native_dropout_backward(
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index 48258d8959..6aa32b1224 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -264,18 +264,21 @@ def _where_input_wrangler(
     "clamp_max": core_ops.aten_clamp_max,
     "clamp_min": core_ops.aten_clamp_min,
     "clone": core_ops.aten_clone,
+    # "copy": core_ops.aten_copy,  # copy is not in OPS_DB
     "cos": core_ops.aten_cos,
     "cosh": core_ops.aten_cosh,
     # "detach": core_ops.aten_detach,  # detach is not in OP-TEST-DB
     "div": core_ops.aten_div,
     "dot": core_ops.aten_dot,
     "empty": core_ops.aten_empty,
+    # "empty_strided": core_ops.aten_empty_strided,  # empty_strided is not in OPS_DB
     "eq": core_ops.aten_eq,
     "equal": core_ops.aten_equal,
     "exp": core_ops.aten_exp,
     "exp2": core_ops.aten_exp2,
     "expand": core_ops.aten_expand,
     "erf": core_ops.aten_erf,
+    "fill": core_ops.aten_fill,
     "fmod": core_ops.aten_fmod,
     "full": (core_ops.aten_full, _full_input_wrangler),
     "full_like": core_ops.aten_full_like,
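
(Illustration, not part of the patch series.) Both aten_fill and aten_empty_strided above reduce to an Expand of a scalar: fill broadcasts the requested value to the shape of self, and empty_strided only has to produce a tensor of the right shape, so expanding 0.0 is sufficient. A small NumPy sketch of that idea (the helper names are made up for the example):

    import numpy as np

    def fill_like(self_tensor, value):
        # aten::fill semantics: every element becomes `value`, shape follows `self_tensor`.
        return np.broadcast_to(np.asarray(value), self_tensor.shape).copy()

    def simulated_empty(size):
        # "empty" only promises a tensor of the given shape; zeros are a valid stand-in.
        return np.zeros(size, dtype=np.float32)

    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    print(fill_like(x, 7.0))              # all sevens, shape (2, 3)
    print(simulated_empty((2, 3)).shape)  # (2, 3)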

From e5cab2ea9263b7ccf3659e0fa3f57e3c85b22970 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Mon, 13 Feb 2023 17:45:06 +0800
Subject: [PATCH 03/10] update

---
 onnxscript/function_libs/torch_aten/ops/core.py        | 9 ++-------
 .../function_libs/torch_aten/ops_correctness_test.py   | 2 +-
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 1214bf21d2..258426df48 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -3703,7 +3703,6 @@ def aten_narrow_copy(self: TensorType, dim: int, start: INT64, length: INT64) ->
     raise NotImplementedError()
 
 
-@torch_op("aten::native_batch_norm", trace_only=True)
 def aten_native_batch_norm(
     input: TensorType,
     weight: Optional[TensorType],
@@ -3716,12 +3715,8 @@ def aten_native_batch_norm(
 ) -> tuple[TensorType, TensorType, TensorType]:
     # native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
 
-    mode = 0 if not training else 1
-    a = opset17.ReduceMean(input, axes=[1,2])
-    result = op.BatchNormalization(input, weight, bias, running_mean, running_var, epsilon=eps, momentum=momentum, training_mode=mode)
-    if not training:
-        result = result, [], []
-    return result
+    raise NotImplementedError()
+
 
 def aten_native_batch_norm_backward(
     grad_out: TensorType,
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index 6aa32b1224..9da2c95162 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -303,7 +303,7 @@ def _where_input_wrangler(
     "minimum": core_ops.aten_minimum,
     "mm": core_ops.aten_mm,
     "mul": core_ops.aten_mul,
-    "native_batch_norm": core_ops.aten_native_batch_norm,
+    # "native_dropout": core_ops.aten_native_dropout,  # native_dropout is not in OPS_DB
     "ne": core_ops.aten_ne,
     "neg": core_ops.aten_neg,
     "new_full": core_ops.aten_new_full,

From d604aa03a06f522ce5f6c199b33734f038b9aee0 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Mon, 13 Feb 2023 17:46:48 +0800
Subject: [PATCH 04/10] Update ops_correctness_test.py

---
 .../tests/function_libs/torch_aten/ops_correctness_test.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index 9da2c95162..32b6a694ec 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -696,10 +696,7 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
         flattened_function_outputs, _ = pytree.tree_flatten(function_output)
 
         assert flattened_torch_outputs
-        if op.name == 'native_batch_norm' and not cpu_sample.args[4]:
-            assert(len(flattened_function_outputs) == 1)
-        else:
-            assert len(flattened_torch_outputs) == len(flattened_function_outputs)
+        assert len(flattened_torch_outputs) == len(flattened_function_outputs)
 
         for torch_output, function_output in zip(
             flattened_torch_outputs, flattened_function_outputs

From 527274d08250f8b2542fb81b16017d8e08b3102e Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Mon, 13 Feb 2023 18:01:12 +0800
Subject: [PATCH 05/10] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 258426df48..5fd04b9ffa 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -1331,7 +1331,9 @@ def aten_convolution_overrideable(
 
 
 @torch_op("aten::copy")
-def aten_copy(self: TensorType, src: TensorType, non_blocking: bool = False) -> TensorType:
+def aten_copy(
+    self: TTensor, src: TTensor, non_blocking: bool = False  # pylint: disable=unused-argument
+) -> TTensor:
     # copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
 
     self = op.Identity(src)
@@ -1883,7 +1885,9 @@ def aten_empty_quantized(
 
 
 @torch_op("aten::empty_strided")
-def aten_empty_strided(size: INT64, stride: INT64) -> TensorType:
+def aten_empty_strided(
+    size: INT64, stride: INT64  # pylint: disable=unused-argument
+) -> TTensor:
     # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
     # using Zeros to simulate empty()
@@ -2116,7 +2116,7 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType
 
 
 @torch_op("aten::fill")
-def aten_fill(self: TensorType, value: TensorType) -> TensorType:
+def aten_fill(self: TTensor, value: TTensor) -> TFloat:
     # fill.Tensor(Tensor self, Tensor value) -> Tensor
 
     shape = op.Shape(self)
@@ -3743,8 +3747,8 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
     raise NotImplementedError()
 
 
 @torch_op("aten::native_dropout", trace_only=True)
 def aten_native_dropout(
-    input: TensorType, p: float, train: Optional[bool]
-) -> tuple[TensorType, TensorType]:
+    input: TTensor, p: float, train: Optional[bool]
+) -> tuple[TTensor, TTensor]:
     # native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
 
     if not train:

From 0b82867969589ec84fa647b0c036717cc140a1ab Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Mon, 13 Feb 2023 18:36:08 +0800
Subject: [PATCH 06/10] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 5fd04b9ffa..2174ee92eb 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -1887,7 +1887,7 @@ def aten_empty_quantized(
 @torch_op("aten::empty_strided")
 def aten_empty_strided(
     size: INT64, stride: INT64  # pylint: disable=unused-argument
-) -> TTensor:
+) -> TTensor:  # type: ignore[type-var]
     # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
     # using Zeros to simulate empty()
@@ -2116,7 +2116,7 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType
 
 
 @torch_op("aten::fill")
-def aten_fill(self: TTensor, value: TTensor) -> TFloat:
+def aten_fill(self: TTensor, value: TTensor) -> TTensor:
     # fill.Tensor(Tensor self, Tensor value) -> Tensor
 
     shape = op.Shape(self)

From 0618043f75c456492f8c2410d5e28f862308d912 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Tue, 14 Feb 2023 13:50:41 +0800
Subject: [PATCH 07/10] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 2174ee92eb..0d158da280 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -2118,10 +2118,10 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType
 @torch_op("aten::fill")
 def aten_fill(self: TTensor, value: TTensor) -> TTensor:
     # fill.Tensor(Tensor self, Tensor value) -> Tensor
-
+    # after fill, the self Tensor should keep original type
     shape = op.Shape(self)
-    value = op.Cast(value, to=FLOAT.dtype)  # the value might be bool type
     result = op.Expand(value, shape)
+    result = op.CastLike(result, self)
     return result
 
 
@@ -3745,17 +3745,14 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op("aten::native_dropout", trace_only=True)
+@torch_op("aten::native_dropout")
 def aten_native_dropout(
-    input: TTensor, p: float, train: Optional[bool]
+    input: TTensor, p: float, train: bool = True
 ) -> tuple[TTensor, TTensor]:
     # native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
 
-    if not train:
-        result = input, None
-    else:
-        result = op.Dropout(input, p, train)
-    return result
+    result, mask = op.Dropout(input, p, train)
+    return result, mask
 
 
 def aten_native_dropout_backward(

From b6c8b21e08fc715f7fbc6062cccd544a041db140 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Wed, 15 Feb 2023 12:26:48 +0800
Subject: [PATCH 08/10] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index d334137961..f25b084b19 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -2180,10 +2180,11 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType
 @torch_op("aten::fill")
 def aten_fill(self: TTensor, value: TTensor) -> TTensor:
     # fill.Tensor(Tensor self, Tensor value) -> Tensor
+
     # after fill, the self Tensor should keep original type
     shape = op.Shape(self)
-    result = op.Expand(value, shape)
-    result = op.CastLike(result, self)
+    expanded = op.Expand(value, shape)
+    result = op.CastLike(expanded, self)
     return result
 
 
@@ -3810,7 +3811,7 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
 @torch_op("aten::native_dropout")
 def aten_native_dropout(
     input: TTensor, p: float, train: bool = True
-) -> tuple[TTensor, TTensor]:
+) -> Tuple[TTensor, BOOL]:
     # native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
 
     result, mask = op.Dropout(input, p, train)
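
(Illustration, not part of the patch series.) Patches 07 and 08 drop the unconditional Cast to FLOAT and use CastLike instead, so the filled tensor keeps the dtype of self even when value arrives as a different type (a bool, for instance). A rough NumPy sketch of the behaviour the Expand + CastLike pair is after (the function name is hypothetical):

    import numpy as np

    def fill_keep_dtype(self_tensor, value):
        # Expand value to self's shape, then cast the result back to self's dtype,
        # mirroring op.Expand followed by op.CastLike in the patch.
        expanded = np.broadcast_to(np.asarray(value), self_tensor.shape)
        return expanded.astype(self_tensor.dtype)

    x = np.zeros((2, 2), dtype=np.int64)
    print(fill_keep_dtype(x, True).dtype)  # int64, not bool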

From d56ea5168ec701a3fd76e773e02abcd6c5624432 Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Wed, 15 Feb 2023 12:53:43 +0800
Subject: [PATCH 09/10] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index f25b084b19..dd67af3183 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -3809,9 +3809,7 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
 
 
 @torch_op("aten::native_dropout")
-def aten_native_dropout(
-    input: TTensor, p: float, train: bool = True
-) -> Tuple[TTensor, BOOL]:
+def aten_native_dropout(input: TTensor, p: float, train: bool = True) -> Tuple[TTensor, BOOL]:
     # native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
 
     result, mask = op.Dropout(input, p, train)

From aa6df76d03cd85ed29aa919f57c598194661c80c Mon Sep 17 00:00:00 2001
From: xiaowuhu
Date: Thu, 16 Feb 2023 12:17:47 +0800
Subject: [PATCH 10/10] Update core.py

---
 onnxscript/function_libs/torch_aten/ops/core.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index dd67af3183..0358bb0656 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -3809,7 +3809,9 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:
 
 
 @torch_op("aten::native_dropout")
-def aten_native_dropout(input: TTensor, p: float, train: bool = True) -> Tuple[TTensor, BOOL]:
+def aten_native_dropout(
+    input: TFloatOrBFloat16, p: float, train: bool = True
+) -> Tuple[TFloatOrBFloat16, BOOL]:
     # native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
 
     result, mask = op.Dropout(input, p, train)
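
(Illustration, not part of the patch series.) After patches 09 and 10, aten_native_dropout simply forwards to ONNX Dropout and returns both the scaled output and the boolean keep-mask, with the input narrowed to float/bfloat16 types. A NumPy sketch of that two-output contract (illustrative only):

    import numpy as np

    def dropout_with_mask(x, p, train=True, seed=0):
        # Training mode: kept elements are scaled by 1/(1-p); the second output
        # is the boolean mask of kept positions, as with ONNX Dropout.
        if not train or p == 0.0:
            return x, np.ones_like(x, dtype=bool)
        rng = np.random.default_rng(seed)
        mask = rng.random(x.shape) >= p
        return np.where(mask, x / (1.0 - p), 0.0), mask

    out, mask = dropout_with_mask(np.ones((2, 3), dtype=np.float32), p=0.5)
    print(out.shape, mask.dtype)  # (2, 3) bool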