From 46871892cf06c3cca85c7b0ab2e946cee45bfb47 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Wed, 28 Dec 2022 21:54:23 +0800 Subject: [PATCH 01/32] add reshape/view --- aaa.onnx | Bin 0 -> 144 bytes onnxscript/function_libs/torch_aten/ops/core.py | 7 ++++--- .../torch_aten/ops_correctness_test.py | 2 ++ onnxscript/values.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) create mode 100644 aaa.onnx diff --git a/aaa.onnx b/aaa.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4d3b611ebf83e288acc4c87b70e9e4b40854876f GIT binary patch literal 144 zcmd;Jw`yhNGUH;)%qu7@F@VyBLhSjaB_IJM_Mp__jKqReEdegh#GIV`@~YILd TensorType: # reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) - raise NotImplementedError() + return op.Reshape(self, shape) def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType: @@ -4840,11 +4841,11 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() +@torch_op("aten::view") def aten_view(self: TensorType, size: INT64) -> TensorType: # view(Tensor(a) self, SymInt[] size) -> Tensor(a) - raise NotImplementedError() - + return op.Reshape(self, size) def aten_view_as(self: TensorType, other: TensorType) -> TensorType: # view_as(Tensor(a) self, Tensor other) -> Tensor(a) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 48bd88b312..476772484d 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -191,6 +191,7 @@ def wrapped(fn): "ones_like": core_ops.aten_ones_like, "ones": core_ops.aten_ones, "repeat": core_ops.aten_repeat, + "reshape": core_ops.aten_reshape, "round": core_ops.aten_round, "sin": core_ops.aten_sin, "sinh": core_ops.aten_sinh, @@ -199,6 +200,7 @@ def wrapped(fn): "tan": core_ops.aten_tan, "tanh": core_ops.aten_tanh, # "transpose": core_ops.aten_transpose, # TODO(justinchuby): Enable when onnxscript errors are fixed, + "view": core_ops.aten_view, "zeros": core_ops.aten_zeros, "zeros_like": core_ops.aten_zeros_like, } diff --git a/onnxscript/values.py b/onnxscript/values.py index f115758df2..bd89f4f568 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -207,7 +207,8 @@ def adapt(input: ExtendedModeValue) -> EagerModeValue: elif input is None: return None elif isinstance(input, list): - return [adapt(elt) for elt in input] + return input + # return [adapt(elt) for elt in input] elif isinstance(input, tuple): return tuple(adapt(elt) for elt in input) raise TypeError(f"Unexpected input type {type(input)}.") From 8ba5b11bda5f84d3607a0ceed909f0cabf40f085 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Wed, 28 Dec 2022 21:57:33 +0800 Subject: [PATCH 02/32] Delete aaa.onnx --- aaa.onnx | Bin 144 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 aaa.onnx diff --git a/aaa.onnx b/aaa.onnx deleted file mode 100644 index 4d3b611ebf83e288acc4c87b70e9e4b40854876f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 144 zcmd;Jw`yhNGUH;)%qu7@F@VyBLhSjaB_IJM_Mp__jKqReEdegh#GIV`@~YILd Date: Wed, 28 Dec 2022 22:05:30 +0800 Subject: [PATCH 03/32] Update core.py --- onnxscript/function_libs/torch_aten/ops/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 7c2abf3eb7..cd572fba99 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -4009,7 +4009,7 @@ def aten_repeat_interleave( def aten_reshape(self: TensorType, shape: INT64) -> TensorType: # reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) - return op.Reshape(self, shape) + return op.Reshape(self, shape) # type: ignore[arg-type] def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType: @@ -4845,7 +4845,7 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: def aten_view(self: TensorType, size: INT64) -> TensorType: # view(Tensor(a) self, SymInt[] size) -> Tensor(a) - return op.Reshape(self, size) + return op.Reshape(self, size) # type: ignore[arg-type] def aten_view_as(self: TensorType, other: TensorType) -> TensorType: # view_as(Tensor(a) self, Tensor other) -> Tensor(a) From 87a7a82da4c9464a3ee27e4dcf03be400813ca5a Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Sat, 31 Dec 2022 15:34:04 +0800 Subject: [PATCH 04/32] add ops --- onnxscript/function_libs/torch_aten/ops/core.py | 6 ++++-- .../test/function_libs/torch_aten/ops_correctness_test.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index cd572fba99..85f825395f 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -1575,10 +1575,11 @@ def aten_exp2(self): return op.Pow(two, self) # type: ignore[arg-type] +@torch_op() def aten_expand(self: TensorType, size: INT64, implicit: bool = False) -> TensorType: # expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) - raise NotImplementedError() + return op.Expand(self, size) def aten_expand_as(self: TensorType, other: TensorType) -> TensorType: @@ -4439,10 +4440,11 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens raise NotImplementedError() +@torch_op() def aten_sum(self: TensorType, dtype: Optional[int] = None) -> TensorType: # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor - raise NotImplementedError() + return op.Sum(self) def aten_sum_to_size(self: TensorType, size: Sequence[int]) -> TensorType: diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 476772484d..4bad53ebc9 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -179,6 +179,7 @@ def wrapped(fn): "dot": core_ops.aten_dot, "exp": core_ops.aten_exp, "exp2": core_ops.aten_exp2, + "expand": core_ops.aten_expand, "gt": core_ops.aten_gt, "lt": core_ops.aten_lt, "matmul": core_ops.aten_matmul, @@ -196,6 +197,7 @@ def wrapped(fn): "sin": core_ops.aten_sin, "sinh": core_ops.aten_sinh, "sub": core_ops.aten_sub, + "sum": core_ops.aten_sum, "t": core_ops.aten_t, "tan": core_ops.aten_tan, "tanh": core_ops.aten_tanh, From f5956beb9751b926ccb456e79b5ee249b705f25f Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Sat, 31 Dec 2022 16:35:19 +0800 Subject: [PATCH 05/32] Update core.py --- onnxscript/function_libs/torch_aten/ops/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 85f825395f..78535c4db0 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -1575,7 +1575,7 @@ def aten_exp2(self): return op.Pow(two, self) # type: ignore[arg-type] -@torch_op() +@torch_op("aten::expand") def aten_expand(self: TensorType, size: INT64, implicit: bool = False) -> TensorType: # expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) @@ -4440,7 +4440,7 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens raise NotImplementedError() -@torch_op() +@torch_op("aten::sum") def aten_sum(self: TensorType, dtype: Optional[int] = None) -> TensorType: # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor From fa7fefcb07715d088d2233e45c38938bb2c4d004 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Sat, 31 Dec 2022 17:15:02 +0800 Subject: [PATCH 06/32] add clone --- onnxscript/function_libs/torch_aten/ops/core.py | 5 +++-- .../test/function_libs/torch_aten/ops_correctness_test.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 78535c4db0..ac4b19f01c 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -803,10 +803,11 @@ def aten_clamp_min(self, min_): return result +@torch_op("aten::clone") def aten_clone(self: TensorType, memory_format: Optional[str] = None) -> TensorType: # clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor - raise NotImplementedError() + return op.CastLike(self, self) def aten_coalesce(self: TensorType) -> TensorType: @@ -4444,7 +4445,7 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens def aten_sum(self: TensorType, dtype: Optional[int] = None) -> TensorType: # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor - return op.Sum(self) + return op.ReduceSum(self, keepdims=0) def aten_sum_to_size(self: TensorType, size: Sequence[int]) -> TensorType: diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 4bad53ebc9..014468e7b2 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -174,6 +174,7 @@ def wrapped(fn): "clamp_max": core_ops.aten_clamp_max, "clamp_min": core_ops.aten_clamp_min, "clamp": core_ops.aten_clamp, + "clone": core_ops.aten_clone, "cos": core_ops.aten_cos, "cosh": core_ops.aten_cosh, "dot": core_ops.aten_dot, From 252f05fddfd0e6d6fa36cb25587829976302cfbb Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Sat, 31 Dec 2022 20:26:27 +0800 Subject: [PATCH 07/32] add more ops --- onnxscript/function_libs/torch_aten/ops/core.py | 11 +++++++---- onnxscript/function_libs/torch_aten/ops/special.py | 7 +++++-- .../function_libs/torch_aten/ops_correctness_test.py | 2 ++ 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index ac4b19f01c..34c9440d88 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -804,7 +804,7 @@ def aten_clamp_min(self, min_): @torch_op("aten::clone") -def aten_clone(self: TensorType, memory_format: Optional[str] = None) -> TensorType: +def aten_clone(self: TensorType, memory_format: str = None) -> TensorType: # clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor return op.CastLike(self, self) @@ -1407,10 +1407,11 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2) -> TensorType: raise NotImplementedError() +@torch_op("aten::div") def aten_div(self: TensorType, other: TensorType) -> TensorType: # div.Tensor(Tensor self, Tensor other) -> Tensor - raise NotImplementedError() + return op.Div(self, other) def aten_divide(self: TensorType, other: TensorType) -> TensorType: @@ -4252,6 +4253,7 @@ def aten_sinh(self): return op.Sinh(self) +@torch_op("aten::slice") def aten_slice( self: TensorType, dim: int = 0, @@ -4261,7 +4263,8 @@ def aten_slice( ) -> TensorType: # slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) - raise NotImplementedError() + return op.Slice(self, start, end, dim, step) + def aten_slice_backward( @@ -4442,7 +4445,7 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens @torch_op("aten::sum") -def aten_sum(self: TensorType, dtype: Optional[int] = None) -> TensorType: +def aten_sum(self: TensorType, dtype: int = None) -> TensorType: # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor return op.ReduceSum(self, keepdims=0) diff --git a/onnxscript/function_libs/torch_aten/ops/special.py b/onnxscript/function_libs/torch_aten/ops/special.py index 29cb5c9387..2b15c7a8b9 100644 --- a/onnxscript/function_libs/torch_aten/ops/special.py +++ b/onnxscript/function_libs/torch_aten/ops/special.py @@ -16,6 +16,8 @@ from typing import Optional, Sequence +from onnxscript.function_libs.torch_aten.registration import torch_op +from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType @@ -205,12 +207,13 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType: raise NotImplementedError() +@torch_op("aten::log_softmax") def aten_special_log_softmax( - self: TensorType, dim: int, dtype: Optional[int] = None + self: TensorType, dim: int, dtype: int = None ) -> TensorType: # special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor - raise NotImplementedError() + return op.LogSoftmax(self, axis=dim) def aten_special_logit(self: TensorType, eps: Optional[float] = None) -> TensorType: diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 014468e7b2..ee6810e9fe 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -16,6 +16,7 @@ import onnxscript from onnxscript.function_libs.torch_aten.ops import core as core_ops from onnxscript.function_libs.torch_aten.ops import nn as nn_ops +from onnxscript.function_libs.torch_aten.ops import special as special_ops T = TypeVar("T") @@ -183,6 +184,7 @@ def wrapped(fn): "expand": core_ops.aten_expand, "gt": core_ops.aten_gt, "lt": core_ops.aten_lt, + "log_softmax": special_ops.aten_special_log_softmax, "matmul": core_ops.aten_matmul, "mm": core_ops.aten_mm, "mul": core_ops.aten_mul, From b5da543e751ac4ff284526dc42bc05333f3653e3 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Sun, 1 Jan 2023 18:50:33 +0800 Subject: [PATCH 08/32] add ops --- .../function_libs/torch_aten/ops/core.py | 22 ++++++++++++++----- .../torch_aten/ops_correctness_test.py | 1 + onnxscript/values.py | 4 ++++ 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 34c9440d88..917a52142b 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -655,10 +655,12 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() +@torch_op("aten::cat") def aten_cat(tensors: Sequence[TensorType], dim: int = 0) -> TensorType: # cat(Tensor[] tensors, int dim=0) -> Tensor + # TODO: onnxscript cannot support parsing correctly input as Tensor[] now - raise NotImplementedError() + return op.Concat(tensors, axis=dim) def aten_ccol_indices(self: TensorType) -> TensorType: @@ -1531,16 +1533,26 @@ def aten_empty_strided(size: INT64, stride: INT64) -> TensorType: raise NotImplementedError() +@torch_op("aten::eq") def aten_eq(self: TensorType, other: TensorType) -> TensorType: # eq.Tensor(Tensor self, Tensor other) -> Tensor - raise NotImplementedError() + return op.Equal(self, other) +@torch_op("aten::equal") def aten_equal(self: TensorType, other: TensorType) -> bool: # equal(Tensor self, Tensor other) -> bool - raise NotImplementedError() + sub_self_other = op.Sub(self, other) + abs_sub = op.Abs(sub_self_other) + sum_of_abs = op.ReduceSum(abs_sub) + result = True + if sum_of_abs == 0: + result = True + else: + result = False + return result def aten_erf(self: TensorType) -> TensorType: @@ -4257,8 +4269,8 @@ def aten_sinh(self): def aten_slice( self: TensorType, dim: int = 0, - start: Optional[INT64] = None, - end: Optional[INT64] = None, + start: INT64 = None, + end: INT64 = None, step: INT64 = 1, ) -> TensorType: # slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index ee6810e9fe..87b80d40ab 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -171,6 +171,7 @@ def wrapped(fn): "atan": core_ops.aten_atan, "atanh": core_ops.aten_atanh, "bmm": core_ops.aten_bmm, + # "cat": core_ops.aten_cat, # TODO(xiaowuhu): Enable when onnxscript errors are fixed, it cannot suport Sequence[tensor] parsing as input "ceil": core_ops.aten_ceil, "clamp_max": core_ops.aten_clamp_max, "clamp_min": core_ops.aten_clamp_min, diff --git a/onnxscript/values.py b/onnxscript/values.py index bd89f4f568..e6f6f01fa9 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -204,6 +204,8 @@ def adapt(input: ExtendedModeValue) -> EagerModeValue: return input elif isinstance(input, (bool, int, float)): return tensor.Tensor(np.array(input)) + elif isinstance(input, int): + return input elif input is None: return None elif isinstance(input, list): @@ -237,6 +239,8 @@ def _adapt_to_user_mode(output: ExtendedModeValue) -> UserModeValue: return tuple(_adapt_to_user_mode(elt) for elt in output) elif isinstance(output, np.ndarray): return output + elif isinstance(output, (bool, int, float)): + return output raise TypeError(f"Unexpected type {type(output)}.") From d21219d30b056a49103cd58f0b71f709fc63447d Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Sun, 1 Jan 2023 19:12:04 +0800 Subject: [PATCH 09/32] add ops --- onnxscript/function_libs/torch_aten/ops/core.py | 1 + .../test/function_libs/torch_aten/ops_correctness_test.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 917a52142b..590d6540b4 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -4459,6 +4459,7 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens @torch_op("aten::sum") def aten_sum(self: TensorType, dtype: int = None) -> TensorType: # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor + # since op.Sum() is element-wise sum, so we have to use op.ReduceSum() return op.ReduceSum(self, keepdims=0) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 87b80d40ab..5be9ae3dbf 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -179,7 +179,10 @@ def wrapped(fn): "clone": core_ops.aten_clone, "cos": core_ops.aten_cos, "cosh": core_ops.aten_cosh, + "div": core_ops.aten_div, "dot": core_ops.aten_dot, + "eq": core_ops.aten_eq, + "equal": core_ops.aten_equal, "exp": core_ops.aten_exp, "exp2": core_ops.aten_exp2, "expand": core_ops.aten_expand, @@ -200,6 +203,7 @@ def wrapped(fn): "round": core_ops.aten_round, "sin": core_ops.aten_sin, "sinh": core_ops.aten_sinh, + "slice": core_ops.aten_slice, "sub": core_ops.aten_sub, "sum": core_ops.aten_sum, "t": core_ops.aten_t, @@ -349,6 +353,10 @@ def wrapped(fn): reason="Sinh is not defined on bool or int tensors", ), xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"), + xfail( + "sum", + dtypes=[torch.bfloat16, torch.int32, torch.int64, torch.double, torch.float32, torch.float16], + reason="Sum is not defined on bool tensors"), xfail( "tan", dtypes=BOOL_TYPES + INT_TYPES, From c8ead64c78568dfe7b5ca900726427ee54cf15a7 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Sun, 1 Jan 2023 19:24:38 +0800 Subject: [PATCH 10/32] fix bug --- onnxscript/function_libs/torch_aten/ops/core.py | 7 +------ .../test/function_libs/torch_aten/ops_correctness_test.py | 4 ++-- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 590d6540b4..0a9fe1fb20 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -1547,12 +1547,7 @@ def aten_equal(self: TensorType, other: TensorType) -> bool: sub_self_other = op.Sub(self, other) abs_sub = op.Abs(sub_self_other) sum_of_abs = op.ReduceSum(abs_sub) - result = True - if sum_of_abs == 0: - result = True - else: - result = False - return result + return sum_of_abs == 0 def aten_erf(self: TensorType) -> TensorType: diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 5be9ae3dbf..c2a33235d5 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -171,7 +171,7 @@ def wrapped(fn): "atan": core_ops.aten_atan, "atanh": core_ops.aten_atanh, "bmm": core_ops.aten_bmm, - # "cat": core_ops.aten_cat, # TODO(xiaowuhu): Enable when onnxscript errors are fixed, it cannot suport Sequence[tensor] parsing as input + # "cat": core_ops.aten_cat, # TODO(xiaowuhu): Enable after fix: it cannot parse Sequence[tensor] as input "ceil": core_ops.aten_ceil, "clamp_max": core_ops.aten_clamp_max, "clamp_min": core_ops.aten_clamp_min, @@ -355,7 +355,7 @@ def wrapped(fn): xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"), xfail( "sum", - dtypes=[torch.bfloat16, torch.int32, torch.int64, torch.double, torch.float32, torch.float16], + dtypes=except(torch.bfloat16, torch.int32, torch.int64, torch.double, torch.float32, torch.float16), reason="Sum is not defined on bool tensors"), xfail( "tan", From 0a250fd2e41d10df042cd6afbedd23ab8fcb57ce Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Sun, 1 Jan 2023 19:31:18 +0800 Subject: [PATCH 11/32] Update ops_correctness_test.py --- .../test/function_libs/torch_aten/ops_correctness_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index c2a33235d5..5c9a95ecc5 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -355,8 +355,9 @@ def wrapped(fn): xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"), xfail( "sum", - dtypes=except(torch.bfloat16, torch.int32, torch.int64, torch.double, torch.float32, torch.float16), - reason="Sum is not defined on bool tensors"), + dtypes=dtypes_except(torch.bfloat16, torch.int32, torch.int64, torch.double, torch.float32, torch.float16), + reason="Sum is not defined on bool tensors", + ) xfail( "tan", dtypes=BOOL_TYPES + INT_TYPES, From d8b8e331233bf6c6cae222187e5837d20aa40c08 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Sun, 1 Jan 2023 19:33:14 +0800 Subject: [PATCH 12/32] Update ops_correctness_test.py --- .../test/function_libs/torch_aten/ops_correctness_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 5c9a95ecc5..8648d44881 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -357,7 +357,7 @@ def wrapped(fn): "sum", dtypes=dtypes_except(torch.bfloat16, torch.int32, torch.int64, torch.double, torch.float32, torch.float16), reason="Sum is not defined on bool tensors", - ) + ), xfail( "tan", dtypes=BOOL_TYPES + INT_TYPES, From 744f2f30d509941d56a155bf99e49846ffe65aa3 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Wed, 4 Jan 2023 11:09:32 +0800 Subject: [PATCH 13/32] update --- .../function_libs/torch_aten/ops_correctness_test.py | 9 +++++++-- onnxscript/values.py | 2 -- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 66efccbb30..9472d5758d 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -207,7 +207,7 @@ def wrapped(fn): "slice": core_ops.aten_slice, "sqrt": core_ops.aten_sqrt, "sub": core_ops.aten_sub, - "sum": core_ops.aten_sum, + # "sum": core_ops.aten_sum, #TODO: kwargs={dim, keepdims}, dim is invalid "t": core_ops.aten_t, "tan": core_ops.aten_tan, "tanh": core_ops.aten_tanh, @@ -257,10 +257,15 @@ def wrapped(fn): dtypes=BOOL_TYPES + INT_TYPES, reason="Sinh is not defined on bool or int tensors", ), + xfail( + "slice", + dtypes=dtypes_except(torch.float32), + reason="Sinh is not defined on bool or int tensors", + ), xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"), xfail( "sum", - dtypes=dtypes_except(torch.bfloat16, torch.int32, torch.int64, torch.double, torch.float32, torch.float16), + dtypes=dtypes_except(torch.bfloat16, torch.int64, torch.double, torch.float32, torch.float16), reason="Sum is not defined on bool tensors", ), xfail( diff --git a/onnxscript/values.py b/onnxscript/values.py index e6f6f01fa9..7c5867424e 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -204,8 +204,6 @@ def adapt(input: ExtendedModeValue) -> EagerModeValue: return input elif isinstance(input, (bool, int, float)): return tensor.Tensor(np.array(input)) - elif isinstance(input, int): - return input elif input is None: return None elif isinstance(input, list): From 62c2ebd100d2c9cc1ed2b57a4f39802ebebd5250 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 5 Jan 2023 14:18:11 +0800 Subject: [PATCH 14/32] fix comments --- onnxscript/function_libs/torch_aten/ops/core.py | 17 +++++++++-------- .../function_libs/torch_aten/ops/special.py | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index bc6a02485b..05c85dea30 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -820,10 +820,10 @@ def aten_clamp_min(self: TReal, min_: TReal) -> TReal: @torch_op("aten::clone") -def aten_clone(self: TensorType, memory_format: str = None) -> TensorType: +def aten_clone(self: TensorType, memory_format: str = None) -> TensorType: # pylint: disable=unused-argument # clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor - return op.CastLike(self, self) + return op.Identity(self) def aten_coalesce(self: TensorType) -> TensorType: @@ -1600,7 +1600,7 @@ def aten_exp2(self: TFloat) -> TFloat: @torch_op("aten::expand") -def aten_expand(self: TensorType, size: INT64, implicit: bool = False) -> TensorType: +def aten_expand(self: TensorType, size: INT64) -> TensorType: # expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) return op.Expand(self, size) @@ -4050,7 +4050,7 @@ def aten_repeat_interleave( @torch_op("aten::reshape") -def aten_reshape(self: TensorType, shape: INT64) -> TensorType: +def aten_reshape(self: TensorType, shape: INT64["M"]) -> TensorType: # reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) return op.Reshape(self, shape) # type: ignore[arg-type] @@ -4301,8 +4301,8 @@ def aten_sinh(self: TFloat) -> TFloat: def aten_slice( self: TensorType, dim: int = 0, - start: INT64 = None, - end: INT64 = None, + start: Optional[INT64] = None, + end: Optional[INT64] = None, step: INT64 = 1, ) -> TensorType: # slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) @@ -4490,7 +4490,7 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens @torch_op("aten::sum") -def aten_sum(self: TensorType, dtype: int = None) -> TensorType: +def aten_sum(self: TensorType, dtype: int = None) -> TensorType: # pylint: disable=unused-argument # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor # since op.Sum() is element-wise sum, so we have to use op.ReduceSum() @@ -4896,11 +4896,12 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: @torch_op("aten::view") -def aten_view(self: TensorType, size: INT64) -> TensorType: +def aten_view(self: TensorType, size: INT64["M"]) -> TensorType: # view(Tensor(a) self, SymInt[] size) -> Tensor(a) return op.Reshape(self, size) # type: ignore[arg-type] + def aten_view_as(self: TensorType, other: TensorType) -> TensorType: # view_as(Tensor(a) self, Tensor other) -> Tensor(a) diff --git a/onnxscript/function_libs/torch_aten/ops/special.py b/onnxscript/function_libs/torch_aten/ops/special.py index 60e3466f52..9d8995aedb 100644 --- a/onnxscript/function_libs/torch_aten/ops/special.py +++ b/onnxscript/function_libs/torch_aten/ops/special.py @@ -206,7 +206,7 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType: @torch_op("aten::log_softmax") def aten_special_log_softmax( - self: TensorType, dim: int, dtype: int = None + self: TensorType, dim: int, dtype: int = -1 # pylint: disable=unused-argument ) -> TensorType: # special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor From 81ae88630ec3bf84cf560e7952f909934fb03cca Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 5 Jan 2023 14:50:32 +0800 Subject: [PATCH 15/32] Update core.py --- onnxscript/function_libs/torch_aten/ops/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 05c85dea30..d67cc04631 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -820,7 +820,7 @@ def aten_clamp_min(self: TReal, min_: TReal) -> TReal: @torch_op("aten::clone") -def aten_clone(self: TensorType, memory_format: str = None) -> TensorType: # pylint: disable=unused-argument +def aten_clone(self: TensorType, memory_format: str = "") -> TensorType: # pylint: disable=unused-argument # clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor return op.Identity(self) @@ -4050,7 +4050,7 @@ def aten_repeat_interleave( @torch_op("aten::reshape") -def aten_reshape(self: TensorType, shape: INT64["M"]) -> TensorType: +def aten_reshape(self: TensorType, shape: INT64) -> TensorType: # reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) return op.Reshape(self, shape) # type: ignore[arg-type] @@ -4896,7 +4896,7 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: @torch_op("aten::view") -def aten_view(self: TensorType, size: INT64["M"]) -> TensorType: +def aten_view(self: TensorType, size: INT64) -> TensorType: # view(Tensor(a) self, SymInt[] size) -> Tensor(a) return op.Reshape(self, size) # type: ignore[arg-type] From e214859e36d3c901c6d3a692ead9da516826e49a Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 5 Jan 2023 14:58:59 +0800 Subject: [PATCH 16/32] Update core.py --- onnxscript/function_libs/torch_aten/ops/core.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index d67cc04631..63a45e701a 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -820,7 +820,9 @@ def aten_clamp_min(self: TReal, min_: TReal) -> TReal: @torch_op("aten::clone") -def aten_clone(self: TensorType, memory_format: str = "") -> TensorType: # pylint: disable=unused-argument +def aten_clone( + self: TensorType, memory_format: str = "" +) -> TensorType: # pylint: disable=unused-argument # clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor return op.Identity(self) @@ -4053,7 +4055,7 @@ def aten_repeat_interleave( def aten_reshape(self: TensorType, shape: INT64) -> TensorType: # reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) - return op.Reshape(self, shape) # type: ignore[arg-type] + return op.Reshape(self, shape) # type: ignore[arg-type] def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType: @@ -4310,7 +4312,6 @@ def aten_slice( return op.Slice(self, start, end, dim, step) - def aten_slice_backward( grad_output: TensorType, input_sizes: INT64, @@ -4490,7 +4491,9 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens @torch_op("aten::sum") -def aten_sum(self: TensorType, dtype: int = None) -> TensorType: # pylint: disable=unused-argument +def aten_sum( + self: TensorType, dtype: int = None +) -> TensorType: # pylint: disable=unused-argument # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor # since op.Sum() is element-wise sum, so we have to use op.ReduceSum() @@ -4899,7 +4902,7 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: def aten_view(self: TensorType, size: INT64) -> TensorType: # view(Tensor(a) self, SymInt[] size) -> Tensor(a) - return op.Reshape(self, size) # type: ignore[arg-type] + return op.Reshape(self, size) # type: ignore[arg-type] def aten_view_as(self: TensorType, other: TensorType) -> TensorType: From 59cf7f52133a715e05fb3699b3e3441007373afb Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 5 Jan 2023 15:11:36 +0800 Subject: [PATCH 17/32] fix issues --- onnxscript/function_libs/torch_aten/ops/core.py | 7 ++----- .../test/function_libs/torch_aten/ops_correctness_test.py | 4 +++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 63a45e701a..86dda9dcc2 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -821,8 +821,7 @@ def aten_clamp_min(self: TReal, min_: TReal) -> TReal: @torch_op("aten::clone") def aten_clone( - self: TensorType, memory_format: str = "" -) -> TensorType: # pylint: disable=unused-argument + self: TensorType, memory_format: str = "") -> TensorType: # pylint: disable=unused-argument # clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor return op.Identity(self) @@ -4491,9 +4490,7 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens @torch_op("aten::sum") -def aten_sum( - self: TensorType, dtype: int = None -) -> TensorType: # pylint: disable=unused-argument +def aten_sum(self: TensorType, dtype: int = -1) -> TensorType: # pylint: disable=unused-argument # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor # since op.Sum() is element-wise sum, so we have to use op.ReduceSum() diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 9472d5758d..1bc489cebf 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -265,7 +265,9 @@ def wrapped(fn): xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"), xfail( "sum", - dtypes=dtypes_except(torch.bfloat16, torch.int64, torch.double, torch.float32, torch.float16), + dtypes=dtypes_except( + torch.bfloat16, torch.int64, torch.double, torch.float32, torch.float16 + ), reason="Sum is not defined on bool tensors", ), xfail( From 85773d46fbfed61563d41d99f6779956f290e8cc Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 5 Jan 2023 15:26:12 +0800 Subject: [PATCH 18/32] fix pylint --- onnxscript/function_libs/torch_aten/ops/core.py | 7 +++++-- onnxscript/function_libs/torch_aten/ops/special.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 86dda9dcc2..51d2e726cc 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -821,7 +821,8 @@ def aten_clamp_min(self: TReal, min_: TReal) -> TReal: @torch_op("aten::clone") def aten_clone( - self: TensorType, memory_format: str = "") -> TensorType: # pylint: disable=unused-argument + self: TensorType, memory_format: str = "" +) -> TensorType: # pylint: disable=unused-argument # clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor return op.Identity(self) @@ -4490,7 +4491,9 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens @torch_op("aten::sum") -def aten_sum(self: TensorType, dtype: int = -1) -> TensorType: # pylint: disable=unused-argument +def aten_sum( + self: TensorType, dtype: int = -1 +) -> TensorType: # pylint: disable=unused-argument # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor # since op.Sum() is element-wise sum, so we have to use op.ReduceSum() diff --git a/onnxscript/function_libs/torch_aten/ops/special.py b/onnxscript/function_libs/torch_aten/ops/special.py index 9d8995aedb..7ca78ea8f1 100644 --- a/onnxscript/function_libs/torch_aten/ops/special.py +++ b/onnxscript/function_libs/torch_aten/ops/special.py @@ -206,7 +206,7 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType: @torch_op("aten::log_softmax") def aten_special_log_softmax( - self: TensorType, dim: int, dtype: int = -1 # pylint: disable=unused-argument + self: TensorType, dim: int, dtype: int = -1 # pylint: disable=unused-argument ) -> TensorType: # special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor From ceec613c2cabbb5efeb096951fa1177f240c7c5a Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 5 Jan 2023 16:13:26 +0800 Subject: [PATCH 19/32] fix lint --- onnxscript/function_libs/torch_aten/ops/core.py | 12 ++++++------ onnxscript/function_libs/torch_aten/ops/special.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 51d2e726cc..6d608ca32a 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -821,8 +821,8 @@ def aten_clamp_min(self: TReal, min_: TReal) -> TReal: @torch_op("aten::clone") def aten_clone( - self: TensorType, memory_format: str = "" -) -> TensorType: # pylint: disable=unused-argument + self: TensorType, memory_format: str = "" # pylint: disable=unused-argument +) -> TensorType: # clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor return op.Identity(self) @@ -4309,7 +4309,7 @@ def aten_slice( ) -> TensorType: # slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) - return op.Slice(self, start, end, dim, step) + return op.Slice(self, start, end, dim, step) # type: ignore[arg-type] def aten_slice_backward( @@ -4492,12 +4492,12 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens @torch_op("aten::sum") def aten_sum( - self: TensorType, dtype: int = -1 -) -> TensorType: # pylint: disable=unused-argument + self: TensorType, dtype: int = -1 # pylint: disable=unused-argument +) -> TensorType: # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor # since op.Sum() is element-wise sum, so we have to use op.ReduceSum() - return op.ReduceSum(self, keepdims=0) + return op.ReduceSum(self, keepdims=0) # type: ignore[arg-type] def aten_sum_to_size(self: TensorType, size: Sequence[int]) -> TensorType: diff --git a/onnxscript/function_libs/torch_aten/ops/special.py b/onnxscript/function_libs/torch_aten/ops/special.py index 7ca78ea8f1..1a4bf397cb 100644 --- a/onnxscript/function_libs/torch_aten/ops/special.py +++ b/onnxscript/function_libs/torch_aten/ops/special.py @@ -206,7 +206,7 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType: @torch_op("aten::log_softmax") def aten_special_log_softmax( - self: TensorType, dim: int, dtype: int = -1 # pylint: disable=unused-argument + self: TensorType, dim: int, dtype: int = -1 # pylint: disable=unused-argument ) -> TensorType: # special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor From 536b785cb746faec016f58e97303429084b9a95a Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 5 Jan 2023 16:22:45 +0800 Subject: [PATCH 20/32] Update core.py --- onnxscript/function_libs/torch_aten/ops/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 6d608ca32a..b073a34068 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -4309,7 +4309,7 @@ def aten_slice( ) -> TensorType: # slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) - return op.Slice(self, start, end, dim, step) # type: ignore[arg-type] + return op.Slice(self, start, end, dim, step)# type: ignore[arg-type] def aten_slice_backward( @@ -4497,7 +4497,7 @@ def aten_sum( # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor # since op.Sum() is element-wise sum, so we have to use op.ReduceSum() - return op.ReduceSum(self, keepdims=0) # type: ignore[arg-type] + return op.ReduceSum(self, keepdims=0)# type: ignore[arg-type] def aten_sum_to_size(self: TensorType, size: Sequence[int]) -> TensorType: From 2913eac6c9d57829c956ef18622b86cbfee664b4 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 5 Jan 2023 16:39:05 +0800 Subject: [PATCH 21/32] fix lint --- .../function_libs/torch_aten/ops/core.py | 4 +- .../torch_aten/ops_correctness_test.py | 54 ------------------- 2 files changed, 2 insertions(+), 56 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index b073a34068..5ecc7a03bc 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -4309,7 +4309,7 @@ def aten_slice( ) -> TensorType: # slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) - return op.Slice(self, start, end, dim, step)# type: ignore[arg-type] + return op.Slice(self, start, end, dim, step) # type: ignore[arg-type] def aten_slice_backward( @@ -4497,7 +4497,7 @@ def aten_sum( # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor # since op.Sum() is element-wise sum, so we have to use op.ReduceSum() - return op.ReduceSum(self, keepdims=0)# type: ignore[arg-type] + return op.ReduceSum(self, keepdims=0) # type: ignore[arg-type] def aten_sum_to_size(self: TensorType, size: Sequence[int]) -> TensorType: diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 1bc489cebf..b2f54192c2 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -226,60 +226,6 @@ def wrapped(fn): "nn.functional.linear", reason="ONNX Runtime thinks the graph is invalid", ), - xfail( - "nn.functional.relu6", - dtypes=dtypes_except(torch.float16, torch.float32), - reason="ONNX Runtime doesn't support float64 for Relu", - ), - xfail( - "nn.functional.selu", - dtypes=dtypes_except(torch.float16, torch.float32), - reason="ONNX Runtime doesn't support float64 for Selu", - ), - xfail( - "round", - variant_name="", - dtypes=dtypes_except(*FLOAT_TYPES), - reason="Round is not defined on non-float tensors", - ), - xfail("round", variant_name="decimals_0", reason="The ATen op does not support decimals"), - xfail("round", variant_name="decimals_3", reason="The ATen op does not support decimals"), - xfail( - "round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals" - ), - xfail( - "sin", - dtypes=BOOL_TYPES + INT_TYPES, - reason="Sin is not defined on bool or int tensors", - ), - xfail( - "sinh", - dtypes=BOOL_TYPES + INT_TYPES, - reason="Sinh is not defined on bool or int tensors", - ), - xfail( - "slice", - dtypes=dtypes_except(torch.float32), - reason="Sinh is not defined on bool or int tensors", - ), - xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"), - xfail( - "sum", - dtypes=dtypes_except( - torch.bfloat16, torch.int64, torch.double, torch.float32, torch.float16 - ), - reason="Sum is not defined on bool tensors", - ), - xfail( - "tan", - dtypes=BOOL_TYPES + INT_TYPES, - reason="Tan is not defined on bool or int tensors", - ), - xfail( - "tanh", - dtypes=BOOL_TYPES + INT_TYPES, - reason="Tanh is not defined on bool or int tensors", - ), xfail("round", variant_name="decimals_0", reason="The op does not support decimals"), xfail("round", variant_name="decimals_3", reason="The op does not support decimals"), xfail("round", variant_name="decimals_neg_3", reason="The op does not support decimals"), From 2fd768b0057c404125c7151f8bdd78d7c30ad56a Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 5 Jan 2023 17:57:49 +0800 Subject: [PATCH 22/32] fix failed case --- onnxscript/function_libs/torch_aten/ops/core.py | 2 +- .../test/function_libs/torch_aten/ops_correctness_test.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 5ecc7a03bc..5fc8a293e3 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -1429,7 +1429,7 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2) -> TensorType: def aten_div(self: TensorType, other: TensorType) -> TensorType: # div.Tensor(Tensor self, Tensor other) -> Tensor - return op.Div(self, other) + return op.Div(self, other) # type: ignore[arg-type] def aten_divide(self: TensorType, other: TensorType) -> TensorType: diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index b2f54192c2..eb4a0239ee 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -234,6 +234,11 @@ def wrapped(fn): SKIP_SUBTESTS: tuple[DecorateMeta, ...] = ( + skip( + "div", + matcher=lambda sample: sample.kwargs.get("rounding_mode") is True, + reason="rounding_mode=True is not supported", + ), skip( "nonzero", matcher=lambda sample: sample.kwargs.get("as_tuple") is True, From 632bce8cde4f04856affa6d825e1abfef3c84f4c Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 5 Jan 2023 17:58:56 +0800 Subject: [PATCH 23/32] Update ops_correctness_test.py --- .../test/function_libs/torch_aten/ops_correctness_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index eb4a0239ee..608eb1f43a 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -236,7 +236,7 @@ def wrapped(fn): SKIP_SUBTESTS: tuple[DecorateMeta, ...] = ( skip( "div", - matcher=lambda sample: sample.kwargs.get("rounding_mode") is True, + matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None, reason="rounding_mode=True is not supported", ), skip( From 4979616e227319918e3a35dd8b381d4fdbfaa7f8 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 5 Jan 2023 19:35:25 +0800 Subject: [PATCH 24/32] fix bug --- .../function_libs/torch_aten/ops/core.py | 2 +- .../torch_aten/ops_correctness_test.py | 21 +++++++++++-------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 5fc8a293e3..42b9e45feb 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -1562,7 +1562,7 @@ def aten_equal(self: TensorType, other: TensorType) -> bool: sub_self_other = op.Sub(self, other) abs_sub = op.Abs(sub_self_other) - sum_of_abs = op.ReduceSum(abs_sub) + sum_of_abs = op.ReduceSum(abs_sub, keepdims=0) return sum_of_abs == 0 diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 608eb1f43a..2c2124b828 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -159,7 +159,7 @@ def wrapped(fn): "atan": core_ops.aten_atan, "atanh": core_ops.aten_atanh, "bmm": core_ops.aten_bmm, - # "cat": core_ops.aten_cat, # TODO(xiaowuhu): Enable after fix: it cannot parse Sequence[tensor] as input + "cat": core_ops.aten_cat, # TODO(xiaowuhu): Enable after fix: it cannot parse Sequence[tensor] as input "ceil": core_ops.aten_ceil, "clamp_max": core_ops.aten_clamp_max, "clamp_min": core_ops.aten_clamp_min, @@ -207,7 +207,7 @@ def wrapped(fn): "slice": core_ops.aten_slice, "sqrt": core_ops.aten_sqrt, "sub": core_ops.aten_sub, - # "sum": core_ops.aten_sum, #TODO: kwargs={dim, keepdims}, dim is invalid + "sum": core_ops.aten_sum, #TODO: kwargs={dim, keepdims}, dim is invalid "t": core_ops.aten_t, "tan": core_ops.aten_tan, "tanh": core_ops.aten_tanh, @@ -382,13 +382,16 @@ def test_output_match(self, device: str, dtype: torch.dtype, op): rtol = None atol = None - # Use torch testing to ensure dtypes and shapes match - torch.testing.assert_close( - torch.tensor(function_output), - output_torch, - rtol=rtol, - atol=atol, - ) + if isinstance(output_torch, bool): + assert(output_torch == function_output) + else: + # Use torch testing to ensure dtypes and shapes match + torch.testing.assert_close( + torch.tensor(function_output), + output_torch, + rtol=rtol, + atol=atol, + ) common_device_type.instantiate_device_type_tests( From 0082b2345a60c43be6516d0b4b52f4374d186570 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Fri, 6 Jan 2023 18:04:25 +0800 Subject: [PATCH 25/32] fix issues --- .../function_libs/torch_aten/ops/core.py | 19 ++++++------ .../torch_aten/ops_correctness_test.py | 29 +++++++++++-------- onnxscript/values.py | 7 +++-- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 42b9e45feb..1f2cadeedd 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -672,7 +672,7 @@ def aten_cat(tensors: Sequence[TensorType], dim: int = 0) -> TensorType: # cat(Tensor[] tensors, int dim=0) -> Tensor # TODO: onnxscript cannot support parsing correctly input as Tensor[] now - return op.Concat(tensors, axis=dim) + return op.ConcatFromSequence(tensors, axis=dim) # type: ignore[arg-type] def aten_ccol_indices(self: TensorType) -> TensorType: @@ -1605,7 +1605,8 @@ def aten_exp2(self: TFloat) -> TFloat: def aten_expand(self: TensorType, size: INT64) -> TensorType: # expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) - return op.Expand(self, size) + size_int64 = op.Cast(size, to=7) # to INT64 + return op.Expand(self, size_int64) def aten_expand_as(self: TensorType, other: TensorType) -> TensorType: @@ -4055,7 +4056,8 @@ def aten_repeat_interleave( def aten_reshape(self: TensorType, shape: INT64) -> TensorType: # reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) - return op.Reshape(self, shape) # type: ignore[arg-type] + shape_int64 = op.Cast(shape, to=7) # Reshape only support INT64 as 'shape' + return op.Reshape(self, shape_int64) # type: ignore[arg-type] def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType: @@ -4299,7 +4301,6 @@ def aten_sinh(self: TFloat) -> TFloat: return op.Sinh(self) -@torch_op("aten::slice") def aten_slice( self: TensorType, dim: int = 0, @@ -4309,7 +4310,7 @@ def aten_slice( ) -> TensorType: # slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) - return op.Slice(self, start, end, dim, step) # type: ignore[arg-type] + raise NotImplementedError() def aten_slice_backward( @@ -4490,14 +4491,13 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens raise NotImplementedError() -@torch_op("aten::sum") def aten_sum( - self: TensorType, dtype: int = -1 # pylint: disable=unused-argument + self: TensorType, dim: Optional[int] = None, keepdim: bool = False, dtype: int = -1 # pylint: disable=unused-argument ) -> TensorType: # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor # since op.Sum() is element-wise sum, so we have to use op.ReduceSum() - return op.ReduceSum(self, keepdims=0) # type: ignore[arg-type] + raise NotImplementedError() def aten_sum_to_size(self: TensorType, size: Sequence[int]) -> TensorType: @@ -4902,7 +4902,8 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: def aten_view(self: TensorType, size: INT64) -> TensorType: # view(Tensor(a) self, SymInt[] size) -> Tensor(a) - return op.Reshape(self, size) # type: ignore[arg-type] + size_int64 = op.Cast(size, to=7) # Reshape only support INT64 as second input + return op.Reshape(self, size_int64) # type: ignore[arg-type] def aten_view_as(self: TensorType, other: TensorType) -> TensorType: diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 2c2124b828..05ef9163b1 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -159,7 +159,7 @@ def wrapped(fn): "atan": core_ops.aten_atan, "atanh": core_ops.aten_atanh, "bmm": core_ops.aten_bmm, - "cat": core_ops.aten_cat, # TODO(xiaowuhu): Enable after fix: it cannot parse Sequence[tensor] as input + "cat": core_ops.aten_cat, "ceil": core_ops.aten_ceil, "clamp_max": core_ops.aten_clamp_max, "clamp_min": core_ops.aten_clamp_min, @@ -204,10 +204,9 @@ def wrapped(fn): "sign": core_ops.aten_sign, "sin": core_ops.aten_sin, "sinh": core_ops.aten_sinh, - "slice": core_ops.aten_slice, "sqrt": core_ops.aten_sqrt, "sub": core_ops.aten_sub, - "sum": core_ops.aten_sum, #TODO: kwargs={dim, keepdims}, dim is invalid + # "sum": core_ops.aten_sum, # OptionalHasElement() function cannot work "t": core_ops.aten_t, "tan": core_ops.aten_tan, "tanh": core_ops.aten_tanh, @@ -222,6 +221,7 @@ def wrapped(fn): EXPECTED_SKIPS_OR_FAILS = ( skip("clamp", reason="Enable when onnxscript errors are fixed"), + skip("sum", reason="Enable when op.OptionalHasElement() function can work"), xfail( "nn.functional.linear", reason="ONNX Runtime thinks the graph is invalid", @@ -239,6 +239,11 @@ def wrapped(fn): matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None, reason="rounding_mode=True is not supported", ), + skip( + "expand", + matcher=lambda sample: (np.array(sample.args[0]) > 0).all() == False, + reason="Negative value is not supported", + ), skip( "nonzero", matcher=lambda sample: sample.kwargs.get("as_tuple") is True, @@ -382,16 +387,16 @@ def test_output_match(self, device: str, dtype: torch.dtype, op): rtol = None atol = None - if isinstance(output_torch, bool): - assert(output_torch == function_output) - else: + # if isinstance(output_torch, bool): + # assert(output_torch == function_output) + # else: # Use torch testing to ensure dtypes and shapes match - torch.testing.assert_close( - torch.tensor(function_output), - output_torch, - rtol=rtol, - atol=atol, - ) + torch.testing.assert_close( + torch.tensor(function_output), + torch.tensor(output_torch), + rtol=rtol, + atol=atol, + ) common_device_type.instantiate_device_type_tests( diff --git a/onnxscript/values.py b/onnxscript/values.py index b847ce77bd..91ac53f1a6 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -207,6 +207,9 @@ def adapt(input: ExtendedModeValue) -> EagerModeValue: elif input is None: return None elif isinstance(input, list): + assert (len(input) > 0) + if isinstance(input[0], np.ndarray): # this is for the case: list[array] + has_array = True return input elif isinstance(input, tuple): return tuple(adapt(elt) for elt in input) @@ -236,8 +239,8 @@ def _adapt_to_user_mode(output: ExtendedModeValue) -> UserModeValue: return tuple(_adapt_to_user_mode(elt) for elt in output) elif isinstance(output, np.ndarray): return output - elif isinstance(output, (bool, int, float)): - return output + # elif isinstance(output, (bool, int, float)): + # return output raise TypeError(f"Unexpected type {type(output)}.") From c0bbc626abf5f468cc111d75415607d2af0a4fcc Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Fri, 6 Jan 2023 18:50:51 +0800 Subject: [PATCH 26/32] fix comments --- onnxscript/function_libs/torch_aten/ops/core.py | 12 ++++++------ .../function_libs/torch_aten/ops_correctness_test.py | 4 ---- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 2b2e4a701b..0ddbfca2b3 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -1429,7 +1429,7 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2) -> TensorType: def aten_div(self: TensorType, other: TensorType) -> TensorType: # div.Tensor(Tensor self, Tensor other) -> Tensor - return op.Div(self, other) # type: ignore[arg-type] + return op.Div(self, other) def aten_divide(self: TensorType, other: TensorType) -> TensorType: @@ -1550,20 +1550,20 @@ def aten_empty_strided(size: INT64, stride: INT64) -> TensorType: @torch_op("aten::eq") -def aten_eq(self: TensorType, other: TensorType) -> TensorType: +def aten_eq(self: TTensor, other: TTensor) -> BOOL: # eq.Tensor(Tensor self, Tensor other) -> Tensor - return op.Equal(self, other) + return op.Equal(self, other) # type: ignore[arg-type] @torch_op("aten::equal") -def aten_equal(self: TensorType, other: TensorType) -> bool: +def aten_equal(self: TTensor, other: TTensor) -> bool: # equal(Tensor self, Tensor other) -> bool - sub_self_other = op.Sub(self, other) + sub_self_other = op.Sub(self, other) # type: ignore[arg-type] abs_sub = op.Abs(sub_self_other) sum_of_abs = op.ReduceSum(abs_sub, keepdims=0) - return sum_of_abs == 0 + return op.Equal(sum_of_abs, 0) @torch_op("aten::erf") diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index d1a0fad462..0d46db2ac5 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -391,10 +391,6 @@ def test_output_match(self, device: str, dtype: torch.dtype, op): rtol = None atol = None - # if isinstance(output_torch, bool): - # assert(output_torch == function_output) - # else: - # Use torch testing to ensure dtypes and shapes match torch.testing.assert_close( torch.tensor(function_output), torch.tensor(output_torch), From 31e8bb0de46aca8aa608c83d6c374ea3982fc211 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Fri, 6 Jan 2023 18:54:42 +0800 Subject: [PATCH 27/32] Update values.py --- onnxscript/values.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxscript/values.py b/onnxscript/values.py index af974ea473..f115758df2 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -236,8 +236,6 @@ def _adapt_to_user_mode(output: ExtendedModeValue) -> UserModeValue: return tuple(_adapt_to_user_mode(elt) for elt in output) elif isinstance(output, np.ndarray): return output - # elif isinstance(output, (bool, int, float)): - # return output raise TypeError(f"Unexpected type {type(output)}.") From 2a6c291e6ce89f14ebef9395fed9957c7249ede9 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Fri, 6 Jan 2023 19:05:46 +0800 Subject: [PATCH 28/32] Update core.py --- onnxscript/function_libs/torch_aten/ops/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 0ddbfca2b3..0fc9e8d099 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -1605,7 +1605,7 @@ def aten_exp2(self: TFloat) -> TFloat: def aten_expand(self: TensorType, size: INT64) -> TensorType: # expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) - size_int64 = op.Cast(size, to=7) # to INT64 + size_int64 = op.Cast(size, to=7) # to INT64 return op.Expand(self, size_int64) @@ -4912,7 +4912,7 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: def aten_view(self: TensorType, size: INT64) -> TensorType: # view(Tensor(a) self, SymInt[] size) -> Tensor(a) - size_int64 = op.Cast(size, to=7) # Reshape only support INT64 as second input + size_int64 = op.Cast(size, to=7) # Reshape only support INT64 as second input return op.Reshape(self, size_int64) # type: ignore[arg-type] From 5ba21fb16f6d1b161bab2f76267ffa5357c247fb Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Fri, 6 Jan 2023 19:22:19 +0800 Subject: [PATCH 29/32] remove logsoftmax --- onnxscript/function_libs/torch_aten/ops/special.py | 7 +++---- .../test/function_libs/torch_aten/ops_correctness_test.py | 2 -- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/special.py b/onnxscript/function_libs/torch_aten/ops/special.py index 1a4bf397cb..3e590400b6 100644 --- a/onnxscript/function_libs/torch_aten/ops/special.py +++ b/onnxscript/function_libs/torch_aten/ops/special.py @@ -14,7 +14,6 @@ from typing import Optional, Sequence from onnxscript.function_libs.torch_aten.registration import torch_op -from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType @@ -204,13 +203,13 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::log_softmax") def aten_special_log_softmax( - self: TensorType, dim: int, dtype: int = -1 # pylint: disable=unused-argument + self: TensorType, dim: int, dtype: int = -1 ) -> TensorType: # special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor - return op.LogSoftmax(self, axis=dim) + + raise NotImplementedError() def aten_special_logit(self: TensorType, eps: Optional[float] = None) -> TensorType: diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 0d46db2ac5..5b8d8ff68e 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -16,7 +16,6 @@ import onnxscript from onnxscript.function_libs.torch_aten.ops import core as core_ops from onnxscript.function_libs.torch_aten.ops import nn as nn_ops -from onnxscript.function_libs.torch_aten.ops import special as special_ops T = TypeVar("T") @@ -181,7 +180,6 @@ def wrapped(fn): "gt": core_ops.aten_gt, "isinf": core_ops.aten_isinf, "lt": core_ops.aten_lt, - "log_softmax": special_ops.aten_special_log_softmax, "matmul": core_ops.aten_matmul, "mm": core_ops.aten_mm, "mul": core_ops.aten_mul, From 2bdb0c0f20b37e8154cc78c7dd821223ded3926f Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Fri, 6 Jan 2023 19:45:18 +0800 Subject: [PATCH 30/32] remove import --- onnxscript/function_libs/torch_aten/ops/core.py | 7 +++++-- onnxscript/function_libs/torch_aten/ops/special.py | 6 +----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 0fc9e8d099..382f416e15 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -668,7 +668,7 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType: @torch_op("aten::cat") -def aten_cat(tensors: Sequence[TensorType], dim: int = 0) -> TensorType: +def aten_cat(tensors: Sequence[TensorType], dim: Optional[int] = 0) -> TensorType: # cat(Tensor[] tensors, int dim=0) -> Tensor # TODO: onnxscript cannot support parsing correctly input as Tensor[] now @@ -4502,7 +4502,10 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens def aten_sum( - self: TensorType, dim: Optional[int] = None, keepdim: bool = False, dtype: int = -1 # pylint: disable=unused-argument + self: TensorType, + dim: Optional[int] = None, + keepdim: bool = False, + dtype: int = -1 ) -> TensorType: # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor # since op.Sum() is element-wise sum, so we have to use op.ReduceSum() diff --git a/onnxscript/function_libs/torch_aten/ops/special.py b/onnxscript/function_libs/torch_aten/ops/special.py index 3e590400b6..0b88e81506 100644 --- a/onnxscript/function_libs/torch_aten/ops/special.py +++ b/onnxscript/function_libs/torch_aten/ops/special.py @@ -13,7 +13,6 @@ from typing import Optional, Sequence -from onnxscript.function_libs.torch_aten.registration import torch_op from onnxscript.onnx_types import TensorType @@ -203,12 +202,9 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType: raise NotImplementedError() -def aten_special_log_softmax( - self: TensorType, dim: int, dtype: int = -1 -) -> TensorType: +def aten_special_log_softmax(self: TensorType, dim: int, dtype: int = -1) -> TensorType: # special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor - raise NotImplementedError() From 6834a041219de9dfdca60047511fddfaf33bfd87 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Fri, 6 Jan 2023 21:23:57 +0800 Subject: [PATCH 31/32] fix lint --- onnxscript/function_libs/torch_aten/ops/core.py | 10 ++-------- .../function_libs/torch_aten/ops_correctness_test.py | 3 +-- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 382f416e15..024c005938 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -667,12 +667,10 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::cat") def aten_cat(tensors: Sequence[TensorType], dim: Optional[int] = 0) -> TensorType: # cat(Tensor[] tensors, int dim=0) -> Tensor - # TODO: onnxscript cannot support parsing correctly input as Tensor[] now - return op.ConcatFromSequence(tensors, axis=dim) # type: ignore[arg-type] + raise NotImplementedError() def aten_ccol_indices(self: TensorType) -> TensorType: @@ -4502,13 +4500,9 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens def aten_sum( - self: TensorType, - dim: Optional[int] = None, - keepdim: bool = False, - dtype: int = -1 + self: TensorType, dim: Optional[int] = None, keepdim: bool = False, dtype: int = -1 ) -> TensorType: # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor - # since op.Sum() is element-wise sum, so we have to use op.ReduceSum() raise NotImplementedError() diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 5b8d8ff68e..8521a656d0 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -158,7 +158,6 @@ def wrapped(fn): "atan": core_ops.aten_atan, "atanh": core_ops.aten_atanh, "bmm": core_ops.aten_bmm, - "cat": core_ops.aten_cat, "ceil": core_ops.aten_ceil, "clamp_max": core_ops.aten_clamp_max, "clamp_min": core_ops.aten_clamp_min, @@ -243,7 +242,7 @@ def wrapped(fn): ), skip( "expand", - matcher=lambda sample: (np.array(sample.args[0]) > 0).all() == False, + matcher=lambda sample: (np.array(sample.args[0]) > 0).all() is np.bool_(False), reason="Negative value is not supported", ), skip( From 72d638112ddb640e08f75eb1c6a8e6236b327c2e Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Sat, 7 Jan 2023 11:19:24 +0800 Subject: [PATCH 32/32] fix comments --- .../function_libs/torch_aten/ops/core.py | 32 +++++++++---------- .../torch_aten/ops_correctness_test.py | 3 +- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 024c005938..4725d28881 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -667,7 +667,7 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -def aten_cat(tensors: Sequence[TensorType], dim: Optional[int] = 0) -> TensorType: +def aten_cat(tensors: Sequence[TensorType], dim: int = 0) -> TensorType: # cat(Tensor[] tensors, int dim=0) -> Tensor raise NotImplementedError() @@ -819,8 +819,8 @@ def aten_clamp_min(self: TReal, min_: TReal) -> TReal: @torch_op("aten::clone") def aten_clone( - self: TensorType, memory_format: str = "" # pylint: disable=unused-argument -) -> TensorType: + self: TTensor, memory_format: str = "" # pylint: disable=unused-argument +) -> TTensor: # clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor return op.Identity(self) @@ -1424,7 +1424,7 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2) -> TensorType: @torch_op("aten::div") -def aten_div(self: TensorType, other: TensorType) -> TensorType: +def aten_div(self: TReal, other: TReal) -> TReal: # div.Tensor(Tensor self, Tensor other) -> Tensor return op.Div(self, other) @@ -1551,14 +1551,14 @@ def aten_empty_strided(size: INT64, stride: INT64) -> TensorType: def aten_eq(self: TTensor, other: TTensor) -> BOOL: # eq.Tensor(Tensor self, Tensor other) -> Tensor - return op.Equal(self, other) # type: ignore[arg-type] + return op.Equal(self, other) @torch_op("aten::equal") -def aten_equal(self: TTensor, other: TTensor) -> bool: +def aten_equal(self: TTensor, other: TTensor) -> BOOL: # equal(Tensor self, Tensor other) -> bool - sub_self_other = op.Sub(self, other) # type: ignore[arg-type] + sub_self_other = op.Sub(self, other) abs_sub = op.Abs(sub_self_other) sum_of_abs = op.ReduceSum(abs_sub, keepdims=0) return op.Equal(sum_of_abs, 0) @@ -1600,11 +1600,11 @@ def aten_exp2(self: TFloat) -> TFloat: @torch_op("aten::expand") -def aten_expand(self: TensorType, size: INT64) -> TensorType: +def aten_expand(self: TTensor, size: INT64) -> TTensor: # expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) - size_int64 = op.Cast(size, to=7) # to INT64 - return op.Expand(self, size_int64) + size = op.Cast(size, to=INT64.dtype) # to INT64 + return op.Expand(self, size) def aten_expand_as(self: TensorType, other: TensorType) -> TensorType: @@ -4061,11 +4061,11 @@ def aten_repeat_interleave( @torch_op("aten::reshape") -def aten_reshape(self: TensorType, shape: INT64) -> TensorType: +def aten_reshape(self: TTensor, shape: INT64) -> TTensor: # reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) - shape_int64 = op.Cast(shape, to=7) # Reshape only support INT64 as 'shape' - return op.Reshape(self, shape_int64) # type: ignore[arg-type] + shape = op.Cast(shape, to=INT64.dtype) # Reshape only support INT64 as 'shape' + return op.Reshape(self, shape) def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType: @@ -4906,11 +4906,11 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: @torch_op("aten::view") -def aten_view(self: TensorType, size: INT64) -> TensorType: +def aten_view(self: TTensor, size: INT64) -> TTensor: # view(Tensor(a) self, SymInt[] size) -> Tensor(a) - size_int64 = op.Cast(size, to=7) # Reshape only support INT64 as second input - return op.Reshape(self, size_int64) # type: ignore[arg-type] + size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input + return op.Reshape(self, size) def aten_view_as(self: TensorType, other: TensorType) -> TensorType: diff --git a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py index 8521a656d0..641c181421 100644 --- a/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/test/function_libs/torch_aten/ops_correctness_test.py @@ -238,7 +238,7 @@ def wrapped(fn): skip( "div", matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None, - reason="rounding_mode=True is not supported", + reason="rounding_mode is not yet supported", ), skip( "expand", @@ -388,6 +388,7 @@ def test_output_match(self, device: str, dtype: torch.dtype, op): rtol = None atol = None + # Use torch.testing as opposed to np.testing to ensure dtypes and shapes match torch.testing.assert_close( torch.tensor(function_output), torch.tensor(output_torch),