add reshape/view/expand/div/clone/eq/equal #265

Merged
merged 38 commits into main from xiaowu/trySome2n on Jan 7, 2023
Commits (38)
4687189  add reshape/view (xiaowuhu, Dec 28, 2022)
8ba5b11  Delete aaa.onnx (xiaowuhu, Dec 28, 2022)
d6cd5b1  Update core.py (xiaowuhu, Dec 28, 2022)
87a7a82  add ops (xiaowuhu, Dec 31, 2022)
f5956be  Update core.py (xiaowuhu, Dec 31, 2022)
fa7fefc  add clone (xiaowuhu, Dec 31, 2022)
252f05f  add more ops (xiaowuhu, Dec 31, 2022)
b5da543  add ops (xiaowuhu, Jan 1, 2023)
d21219d  add ops (xiaowuhu, Jan 1, 2023)
c8ead64  fix bug (xiaowuhu, Jan 1, 2023)
0a250fd  Update ops_correctness_test.py (xiaowuhu, Jan 1, 2023)
d8b8e33  Update ops_correctness_test.py (xiaowuhu, Jan 1, 2023)
a7952d2  Merge branch 'main' into xiaowu/trySome2n (xiaowuhu, Jan 4, 2023)
744f2f3  update (xiaowuhu, Jan 4, 2023)
188690f  Merge remote-tracking branch 'upstream/main' into xiaowu/trySome2n (xiaowuhu, Jan 5, 2023)
62c2ebd  fix comments (xiaowuhu, Jan 5, 2023)
81ae886  Update core.py (xiaowuhu, Jan 5, 2023)
e214859  Update core.py (xiaowuhu, Jan 5, 2023)
59cf7f5  fix issues (xiaowuhu, Jan 5, 2023)
85773d4  fix pylint (xiaowuhu, Jan 5, 2023)
ceec613  fix lint (xiaowuhu, Jan 5, 2023)
536b785  Update core.py (xiaowuhu, Jan 5, 2023)
2913eac  fix lint (xiaowuhu, Jan 5, 2023)
2fd768b  fix failed case (xiaowuhu, Jan 5, 2023)
632bce8  Update ops_correctness_test.py (xiaowuhu, Jan 5, 2023)
4979616  fix bug (xiaowuhu, Jan 5, 2023)
e945690  Merge branch 'main' into xiaowu/trySome2n (xiaowuhu, Jan 6, 2023)
0082b23  fix issues (xiaowuhu, Jan 6, 2023)
43263a6  Merge branch 'xiaowu/trySome2n' of https://github.com/microsoft/onnx-… (xiaowuhu, Jan 6, 2023)
4d75d22  Merge remote-tracking branch 'upstream/main' into xiaowu/trySome2n (xiaowuhu, Jan 6, 2023)
c0bbc62  fix comments (xiaowuhu, Jan 6, 2023)
31e8bb0  Update values.py (xiaowuhu, Jan 6, 2023)
2a6c291  Update core.py (xiaowuhu, Jan 6, 2023)
5ba21fb  remove logsoftmax (xiaowuhu, Jan 6, 2023)
2bdb0c0  remove import (xiaowuhu, Jan 6, 2023)
6834a04  fix lint (xiaowuhu, Jan 6, 2023)
72d6381  fix comments (xiaowuhu, Jan 7, 2023)
5a3713e  Merge branch 'main' into xiaowu/trySome2n (xiaowuhu, Jan 7, 2023)
47 changes: 32 additions & 15 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -803,10 +803,13 @@ def aten_clamp_min(self: TReal, min_: TReal) -> TReal:
     return result
 
 
-def aten_clone(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
+@torch_op("aten::clone")
+def aten_clone(
+    self: TTensor, memory_format: str = ""  # pylint: disable=unused-argument
+) -> TTensor:
     # clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
 
-    raise NotImplementedError()
+    return op.Identity(self)
 
 
 def aten_coalesce(self: TensorType) -> TensorType:
@@ -1406,10 +1409,11 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_div(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("aten::div")
+def aten_div(self: TReal, other: TReal) -> TReal:
     # div.Tensor(Tensor self, Tensor other) -> Tensor
 
-    raise NotImplementedError()
+    return op.Div(self, other)
 
 
 def aten_divide(self: TensorType, other: TensorType) -> TensorType:
@@ -1529,16 +1533,21 @@ def aten_empty_strided(size: INT64, stride: INT64) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_eq(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("aten::eq")
+def aten_eq(self: TTensor, other: TTensor) -> BOOL:
     # eq.Tensor(Tensor self, Tensor other) -> Tensor
 
-    raise NotImplementedError()
+    return op.Equal(self, other)
 
 
-def aten_equal(self: TensorType, other: TensorType) -> bool:
+@torch_op("aten::equal")
+def aten_equal(self: TTensor, other: TTensor) -> BOOL:
     # equal(Tensor self, Tensor other) -> bool
 
-    raise NotImplementedError()
+    sub_self_other = op.Sub(self, other)
+    abs_sub = op.Abs(sub_self_other)
+    sum_of_abs = op.ReduceSum(abs_sub, keepdims=0)
+    return op.Equal(sum_of_abs, 0)
 
 
 @torch_op("aten::erf")
@@ -1576,10 +1585,12 @@ def aten_exp2(self: TFloat) -> TFloat:
     return op.Pow(two, self)
 
 
-def aten_expand(self: TensorType, size: INT64, implicit: bool = False) -> TensorType:
+@torch_op("aten::expand")
+def aten_expand(self: TTensor, size: INT64) -> TTensor:
     # expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
 
-    raise NotImplementedError()
+    size = op.Cast(size, to=INT64.dtype)  # to INT64
+    return op.Expand(self, size)
 
 
 def aten_expand_as(self: TensorType, other: TensorType) -> TensorType:
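
Note: ONNX Expand broadcasts the input to the target shape using numpy-style broadcasting, which is why the test file below skips samples whose requested size contains non-positive entries (PyTorch's -1 "keep this dimension" has no direct Expand equivalent). A rough numpy analogue of the lowering, assuming all-positive sizes:

import numpy as np

x = np.arange(3).reshape(3, 1)               # shape (3, 1)
size = np.array([3, 4], dtype=np.int64)      # analogous to the INT64 'size' input
expanded = np.broadcast_to(x, tuple(size))   # roughly what op.Expand produces
print(expanded.shape)                        # (3, 4)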
@@ -4046,10 +4057,12 @@ def aten_repeat_interleave(
     raise NotImplementedError()
 
 
-def aten_reshape(self: TensorType, shape: INT64) -> TensorType:
+@torch_op("aten::reshape")
+def aten_reshape(self: TTensor, shape: INT64) -> TTensor:
     # reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
 
-    raise NotImplementedError()
+    shape = op.Cast(shape, to=INT64.dtype)  # Reshape only support INT64 as 'shape'
+    return op.Reshape(self, shape)
 
 
 def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType:
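
Note: the Cast in aten_reshape (and the identical one in aten_view further down) is there because ONNX Reshape only accepts an int64 tensor as its shape input, so any narrower integer shape is widened first. A small numpy sketch of the resulting behaviour (illustration only):

import numpy as np

x = np.arange(6, dtype=np.float32)
shape = np.array([2, 3], dtype=np.int32)  # e.g. an int32 shape from the caller
shape = shape.astype(np.int64)            # analogous to op.Cast(shape, to=INT64.dtype)
print(x.reshape(tuple(shape)).shape)      # (2, 3), what op.Reshape would return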
@@ -4484,7 +4497,9 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens
     raise NotImplementedError()
 
 
-def aten_sum(self: TensorType, dtype: Optional[int] = None) -> TensorType:
+def aten_sum(
+    self: TensorType, dim: Optional[int] = None, keepdim: bool = False, dtype: int = -1
+) -> TensorType:
     # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
 
     raise NotImplementedError()
@@ -4903,10 +4918,12 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_view(self: TensorType, size: INT64) -> TensorType:
+@torch_op("aten::view")
+def aten_view(self: TTensor, size: INT64) -> TTensor:
     # view(Tensor(a) self, SymInt[] size) -> Tensor(a)
 
-    raise NotImplementedError()
+    size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
+    return op.Reshape(self, size)
 
 
 def aten_view_as(self: TensorType, other: TensorType) -> TensorType:
4 changes: 1 addition & 3 deletions onnxscript/function_libs/torch_aten/ops/special.py
@@ -202,9 +202,7 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_special_log_softmax(
-    self: TensorType, dim: int, dtype: Optional[int] = None
-) -> TensorType:
+def aten_special_log_softmax(self: TensorType, dim: int, dtype: int = -1) -> TensorType:
     # special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
 
     raise NotImplementedError()
19 changes: 18 additions & 1 deletion onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -204,12 +204,17 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
     "clamp_max": core_ops.aten_clamp_max,
     "clamp_min": core_ops.aten_clamp_min,
     "clamp": core_ops.aten_clamp,
+    "clone": core_ops.aten_clone,
     "cos": core_ops.aten_cos,
     "cosh": core_ops.aten_cosh,
+    "div": core_ops.aten_div,
     "dot": core_ops.aten_dot,
-    "erf": core_ops.aten_erf,
+    "eq": core_ops.aten_eq,
+    "equal": core_ops.aten_equal,
     "exp": core_ops.aten_exp,
     "exp2": core_ops.aten_exp2,
+    "expand": core_ops.aten_expand,
+    "erf": core_ops.aten_erf,
     "fmod": core_ops.aten_fmod,
     # TODO(justinchuby): Test aten::full
     "full_like": core_ops.aten_full_like,
@@ -242,6 +247,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
     "reciprocal": core_ops.aten_reciprocal,
     "remainder": core_ops.aten_remainder,
     "repeat": core_ops.aten_repeat,
+    "reshape": core_ops.aten_reshape,
     "round": core_ops.aten_round,
     "rsqrt": core_ops.aten_rsqrt,
     "rsub": core_ops.aten_rsub,
@@ -256,6 +262,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
     "tanh": core_ops.aten_tanh,
     "transpose": core_ops.aten_transpose,
     "unsqueeze": core_ops.aten_unsqueeze,
+    "view": core_ops.aten_view,
     "where": core_ops.aten_where,
     "zeros": core_ops.aten_zeros,
     "zeros_like": core_ops.aten_zeros_like,
@@ -282,6 +289,16 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
 
 
 SKIP_SUBTESTS: tuple[DecorateMeta, ...] = (
+    skip(
+        "div",
+        matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None,
+        reason="rounding_mode is not yet supported",
+    ),
+    skip(
+        "expand",
+        matcher=lambda sample: (np.array(sample.args[0]) > 0).all() is np.bool_(False),
+        reason="Negative value is not supported",
+    ),
     skip(
         "nonzero",
         matcher=lambda sample: sample.kwargs.get("as_tuple") is not None,
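
Note: each skip() entry filters OpInfo samples before they reach the traced function. A standalone sketch of the new "expand" matcher (not the test harness itself; the argument here is a plain tuple standing in for an OpInfo sample's args): any requested size containing a non-positive value is skipped, since the op.Expand lowering above cannot express PyTorch's -1.

import numpy as np

def expand_matcher(sample_args):
    # Same expression as the skip() above: True when any requested dim is <= 0.
    # Relies on NumPy returning its interned np.False_ singleton from .all().
    return (np.array(sample_args[0]) > 0).all() is np.bool_(False)

print(expand_matcher(([2, 3, 4],)))   # False -> sample is exercised
print(expand_matcher(([2, -1, 3],)))  # True  -> sample is skipped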