From a5cc251d15169dcb4f173868bf5ed4dc84e70c5e Mon Sep 17 00:00:00 2001 From: Jay Zhang Date: Thu, 5 Jan 2023 13:35:08 +0800 Subject: [PATCH 1/8] Add a placeholder for op assignment. --- onnxscript/function_libs/torch_aten/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index a9d3ea7d51..f5fdf12371 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -3683,7 +3683,7 @@ def aten_pdist(self: TensorType, p: float = 2) -> TensorType: def aten_permute(self: TensorType, dims: Sequence[int]) -> TensorType: # permute(Tensor(a) self, int[] dims) -> Tensor(a) - + # TODO(jiz): Start implementation raise NotImplementedError() From 4c45d0d9a93617ba92f30a4653d0ac4eb0aba4f5 Mon Sep 17 00:00:00 2001 From: Jay Zhang Date: Mon, 9 Jan 2023 15:33:14 +0800 Subject: [PATCH 2/8] Fix a tensor issue and add 2 new ops. Signed-off-by: Jay Zhang --- onnxscript/function_libs/torch_aten/ops/core.py | 14 +++++++++++--- .../torch_aten/ops_correctness_test.py | 2 ++ onnxscript/values.py | 10 +++++++++- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index f5fdf12371..2ddb8a3b1f 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -3681,10 +3681,11 @@ def aten_pdist(self: TensorType, p: float = 2) -> TensorType: raise NotImplementedError() -def aten_permute(self: TensorType, dims: Sequence[int]) -> TensorType: +@torch_op("aten::permute") +def aten_permute(self: TTensor, dims: Sequence[int]) -> TTensor: # permute(Tensor(a) self, int[] dims) -> Tensor(a) - # TODO(jiz): Start implementation - raise NotImplementedError() + + return op.Transpose(self, perm=dims) def aten_permute_copy(self: TensorType, dims: Sequence[int]) -> TensorType: @@ -3754,6 +3755,13 @@ def aten_positive(self: TensorType) -> TensorType: raise NotImplementedError() +@torch_op("aten::pow") +def aten_pow(self: TReal, exponent: TTensor) -> TReal: + # pow(Tensor self, Tensor exponent) -> Tensor + + return op.Pow(self, exponent) + + def aten_prelu(self: TensorType, weight: TensorType) -> TensorType: # prelu(Tensor self, Tensor weight) -> Tensor diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index d5e7e4be79..d84434075c 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -292,6 +292,8 @@ def _topk_input_wrangler( "nonzero": core_ops.aten_nonzero, "ones_like": core_ops.aten_ones_like, "ones": core_ops.aten_ones, + "permute": core_ops.aten_permute, + "pow": core_ops.aten_pow, "reciprocal": core_ops.aten_reciprocal, "remainder": core_ops.aten_remainder, "repeat": core_ops.aten_repeat, diff --git a/onnxscript/values.py b/onnxscript/values.py index f115758df2..7ca25fd0e7 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -133,10 +133,18 @@ def adapt_kwargs(self, kwargs): ) return kwargs, closure + def _convert_kwargs_to_numpy(self, kwargs: dict[str, Any]) -> dict[str, Any]: + new_kwargs = {} + for k, v in kwargs.items(): + if isinstance(v, tensor.Tensor): + new_kwargs[k] = v.value + + return new_kwargs + def __call__(self, *args, **kwargs): from onnxscript import evaluator # pylint: disable=import-outside-toplevel - return evaluator.eval(self.opschema, args, kwargs) + return evaluator.eval(self.opschema, args, self._convert_kwargs_to_numpy(kwargs)) @dataclasses.dataclass(repr=False, eq=False) From 11ce3537ccbc1144562e2b7e592f9a22a2b3b47c Mon Sep 17 00:00:00 2001 From: Jay Zhang Date: Tue, 10 Jan 2023 13:45:20 +0800 Subject: [PATCH 3/8] Copy all of attributes in the beginning. Signed-off-by: Jay Zhang --- onnxscript/values.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/values.py b/onnxscript/values.py index 7ca25fd0e7..c700f44c56 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- from __future__ import annotations +import copy import dataclasses import logging import types @@ -134,7 +135,7 @@ def adapt_kwargs(self, kwargs): return kwargs, closure def _convert_kwargs_to_numpy(self, kwargs: dict[str, Any]) -> dict[str, Any]: - new_kwargs = {} + new_kwargs = copy.deepcopy(kwargs) for k, v in kwargs.items(): if isinstance(v, tensor.Tensor): new_kwargs[k] = v.value From a8cee9427a548815566c6390b4090b58d787735d Mon Sep 17 00:00:00 2001 From: Jay Zhang Date: Tue, 10 Jan 2023 14:26:46 +0800 Subject: [PATCH 4/8] Deepcopy is not supported in some case. Signed-off-by: Jay Zhang --- onnxscript/values.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/values.py b/onnxscript/values.py index c700f44c56..87e55dfae6 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from __future__ import annotations -import copy import dataclasses import logging import types @@ -135,8 +134,9 @@ def adapt_kwargs(self, kwargs): return kwargs, closure def _convert_kwargs_to_numpy(self, kwargs: dict[str, Any]) -> dict[str, Any]: - new_kwargs = copy.deepcopy(kwargs) + new_kwargs = {} for k, v in kwargs.items(): + new_kwargs[k] = v if isinstance(v, tensor.Tensor): new_kwargs[k] = v.value From 6a6f74106017a695417cb8a0ad9b903f2240ebe9 Mon Sep 17 00:00:00 2001 From: Jay Zhang Date: Tue, 10 Jan 2023 16:12:00 +0800 Subject: [PATCH 5/8] Disable 'full' test for a moment. Signed-off-by: Jay Zhang --- .../test/function_libs/torch_aten/ops_correctness_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index d84434075c..abe8ce9a5e 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -254,7 +254,7 @@ def _topk_input_wrangler( "expand": core_ops.aten_expand, "erf": core_ops.aten_erf, "fmod": core_ops.aten_fmod, - "full": (core_ops.aten_full, _full_input_wrangler), + # "full": (core_ops.aten_full, _full_input_wrangler), "full_like": core_ops.aten_full_like, "gt": core_ops.aten_gt, "index_select": core_ops.aten_index_select, From 1b6e01700a5ec89eb18874f57e0704a6040ce80b Mon Sep 17 00:00:00 2001 From: Jay Zhang Date: Tue, 10 Jan 2023 19:41:15 +0800 Subject: [PATCH 6/8] Skip some tests for permute op. Signed-off-by: Jay Zhang --- .../function_libs/torch_aten/ops_correctness_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index abe8ce9a5e..65ea48d775 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -408,6 +408,16 @@ def _topk_input_wrangler( matcher=lambda sample: "scale_factor" in sample.kwargs, reason="fixme: the scale_factor tests", ), + skip( + "permute", + matcher=lambda sample: len(list(filter(lambda v : v < 0, sample.args[0]))) > 0, + reason="Negative value in perm is not supported", + ), + skip( + "permute", + matcher=lambda sample: len(sample.args[0]) == 0, + reason="Empty perm is not supported", + ), ) duplicate_opinfo( From 8625341a0fd0a238f0a1ec28c7a740ad31b53108 Mon Sep 17 00:00:00 2001 From: Jay Zhang Date: Tue, 10 Jan 2023 19:54:39 +0800 Subject: [PATCH 7/8] Fix style issue. --- .../test/function_libs/torch_aten/ops_correctness_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 65ea48d775..9e89addd29 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -410,7 +410,7 @@ def _topk_input_wrangler( ), skip( "permute", - matcher=lambda sample: len(list(filter(lambda v : v < 0, sample.args[0]))) > 0, + matcher=lambda sample: len(list(filter(lambda v: v < 0, sample.args[0]))) > 0, reason="Negative value in perm is not supported", ), skip( From 37e1f11bb503410059da0c66d0f25db86242489e Mon Sep 17 00:00:00 2001 From: Jay Zhang Date: Wed, 11 Jan 2023 09:24:16 +0800 Subject: [PATCH 8/8] Enable the test of full. --- .../test/function_libs/torch_aten/ops_correctness_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 9e89addd29..6d6c2be714 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -254,7 +254,7 @@ def _topk_input_wrangler( "expand": core_ops.aten_expand, "erf": core_ops.aten_erf, "fmod": core_ops.aten_fmod, - # "full": (core_ops.aten_full, _full_input_wrangler), + "full": (core_ops.aten_full, _full_input_wrangler), "full_like": core_ops.aten_full_like, "gt": core_ops.aten_gt, "index_select": core_ops.aten_index_select,