add reshape/view/expand/div/clone/eq/equal #265

Merged
38 commits merged on Jan 7, 2023
Changes from 10 commits

Commits (38)
4687189
add reshape/view
xiaowuhu Dec 28, 2022
8ba5b11
Delete aaa.onnx
xiaowuhu Dec 28, 2022
d6cd5b1
Update core.py
xiaowuhu Dec 28, 2022
87a7a82
add ops
xiaowuhu Dec 31, 2022
f5956be
Update core.py
xiaowuhu Dec 31, 2022
fa7fefc
add clone
xiaowuhu Dec 31, 2022
252f05f
add more ops
xiaowuhu Dec 31, 2022
b5da543
add ops
xiaowuhu Jan 1, 2023
d21219d
add ops
xiaowuhu Jan 1, 2023
c8ead64
fix bug
xiaowuhu Jan 1, 2023
0a250fd
Update ops_correctness_test.py
xiaowuhu Jan 1, 2023
d8b8e33
Update ops_correctness_test.py
xiaowuhu Jan 1, 2023
a7952d2
Merge branch 'main' into xiaowu/trySome2n
xiaowuhu Jan 4, 2023
744f2f3
update
xiaowuhu Jan 4, 2023
188690f
Merge remote-tracking branch 'upstream/main' into xiaowu/trySome2n
xiaowuhu Jan 5, 2023
62c2ebd
fix comments
xiaowuhu Jan 5, 2023
81ae886
Update core.py
xiaowuhu Jan 5, 2023
e214859
Update core.py
xiaowuhu Jan 5, 2023
59cf7f5
fix issues
xiaowuhu Jan 5, 2023
85773d4
fix pylint
xiaowuhu Jan 5, 2023
ceec613
fix lint
xiaowuhu Jan 5, 2023
536b785
Update core.py
xiaowuhu Jan 5, 2023
2913eac
fix lint
xiaowuhu Jan 5, 2023
2fd768b
fix failed case
xiaowuhu Jan 5, 2023
632bce8
Update ops_correctness_test.py
xiaowuhu Jan 5, 2023
4979616
fix bug
xiaowuhu Jan 5, 2023
e945690
Merge branch 'main' into xiaowu/trySome2n
xiaowuhu Jan 6, 2023
0082b23
fix issues
xiaowuhu Jan 6, 2023
43263a6
Merge branch 'xiaowu/trySome2n' of https://github.com/microsoft/onnx-…
xiaowuhu Jan 6, 2023
4d75d22
Merge remote-tracking branch 'upstream/main' into xiaowu/trySome2n
xiaowuhu Jan 6, 2023
c0bbc62
fix comments
xiaowuhu Jan 6, 2023
31e8bb0
Update values.py
xiaowuhu Jan 6, 2023
2a6c291
Update core.py
xiaowuhu Jan 6, 2023
5ba21fb
remove logsoftmax
xiaowuhu Jan 6, 2023
2bdb0c0
remove import
xiaowuhu Jan 6, 2023
6834a04
fix lint
xiaowuhu Jan 6, 2023
72d6381
fix comments
xiaowuhu Jan 7, 2023
5a3713e
Merge branch 'main' into xiaowu/trySome2n
xiaowuhu Jan 7, 2023
45 changes: 30 additions & 15 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -655,10 +655,12 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType:
raise NotImplementedError()


@torch_op("aten::cat")
def aten_cat(tensors: Sequence[TensorType], dim: int = 0) -> TensorType:
# cat(Tensor[] tensors, int dim=0) -> Tensor
# TODO: onnxscript cannot yet parse Tensor[] inputs correctly

raise NotImplementedError()
return op.Concat(tensors, axis=dim)


def aten_ccol_indices(self: TensorType) -> TensorType:
@@ -803,10 +805,11 @@ def aten_clamp_min(self, min_):
return result


def aten_clone(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
@torch_op("aten::clone")
def aten_clone(self: TensorType, memory_format: str = None) -> TensorType:
# clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor

raise NotImplementedError()
return op.CastLike(self, self)


def aten_coalesce(self: TensorType) -> TensorType:
@@ -1406,10 +1409,11 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2) -> TensorType:
raise NotImplementedError()


@torch_op("aten::div")
def aten_div(self: TensorType, other: TensorType) -> TensorType:
# div.Tensor(Tensor self, Tensor other) -> Tensor

raise NotImplementedError()
return op.Div(self, other)


def aten_divide(self: TensorType, other: TensorType) -> TensorType:
@@ -1529,16 +1533,21 @@ def aten_empty_strided(size: INT64, stride: INT64) -> TensorType:
raise NotImplementedError()


@torch_op("aten::eq")
def aten_eq(self: TensorType, other: TensorType) -> TensorType:
# eq.Tensor(Tensor self, Tensor other) -> Tensor

raise NotImplementedError()
return op.Equal(self, other)


@torch_op("aten::equal")
def aten_equal(self: TensorType, other: TensorType) -> bool:
# equal(Tensor self, Tensor other) -> bool

raise NotImplementedError()
sub_self_other = op.Sub(self, other)
abs_sub = op.Abs(sub_self_other)
sum_of_abs = op.ReduceSum(abs_sub)
return sum_of_abs == 0


def aten_erf(self: TensorType) -> TensorType:
@@ -1575,10 +1584,11 @@ def aten_exp2(self):
return op.Pow(two, self) # type: ignore[arg-type]


@torch_op("aten::expand")
def aten_expand(self: TensorType, size: INT64, implicit: bool = False) -> TensorType:
# expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)

raise NotImplementedError()
return op.Expand(self, size)


def aten_expand_as(self: TensorType, other: TensorType) -> TensorType:
@@ -4005,10 +4015,11 @@ def aten_repeat_interleave(
raise NotImplementedError()


@torch_op("aten::reshape")
def aten_reshape(self: TensorType, shape: INT64) -> TensorType:
# reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)

raise NotImplementedError()
return op.Reshape(self, shape) # type: ignore[arg-type]


def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType:
@@ -4249,16 +4260,18 @@ def aten_sinh(self):
return op.Sinh(self)


@torch_op("aten::slice")
def aten_slice(
self: TensorType,
dim: int = 0,
start: Optional[INT64] = None,
end: Optional[INT64] = None,
start: INT64 = None,
end: INT64 = None,
step: INT64 = 1,
) -> TensorType:
# slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)

raise NotImplementedError()
return op.Slice(self, start, end, dim, step)



def aten_slice_backward(
@@ -4438,10 +4451,12 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens
raise NotImplementedError()


def aten_sum(self: TensorType, dtype: Optional[int] = None) -> TensorType:
@torch_op("aten::sum")
def aten_sum(self: TensorType, dtype: int = None) -> TensorType:
# sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
# op.Sum() sums multiple inputs element-wise, so op.ReduceSum() is used here to reduce over all axes instead

raise NotImplementedError()
return op.ReduceSum(self, keepdims=0)


def aten_sum_to_size(self: TensorType, size: Sequence[int]) -> TensorType:
@@ -4840,11 +4855,11 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::view")
def aten_view(self: TensorType, size: INT64) -> TensorType:
# view(Tensor(a) self, SymInt[] size) -> Tensor(a)

raise NotImplementedError()

return op.Reshape(self, size) # type: ignore[arg-type]

def aten_view_as(self: TensorType, other: TensorType) -> TensorType:
# view_as(Tensor(a) self, Tensor other) -> Tensor(a)
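The core.py hunks above replace `raise NotImplementedError()` stubs with `@torch_op` implementations that the eager-mode test harness calls directly. Below is a minimal usage sketch (not part of this diff), assuming the decorated functions accept numpy inputs the same way ops_correctness_test.py passes them; the array values are made up for illustration.

```python
# Minimal eager-mode sketch (not part of this PR), assuming @torch_op functions
# accept numpy inputs the same way ops_correctness_test.py passes them.
import numpy as np

from onnxscript.function_libs.torch_aten.ops import core as core_ops

x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
y = np.array([[1.0, 2.0], [3.0, 5.0]], dtype=np.float32)

print(core_ops.aten_div(x, y))                                        # op.Div, element-wise
print(core_ops.aten_eq(x, y))                                         # op.Equal, element-wise
print(core_ops.aten_reshape(x, np.array([4, 1], dtype=np.int64)))     # op.Reshape
print(core_ops.aten_expand(y, np.array([2, 2, 2], dtype=np.int64)))   # op.Expand with broadcast
```

As in PyTorch, `eq` is element-wise and returns a tensor, while `equal` compares whole tensors and returns a single bool, which is why `aten_equal` reduces the absolute difference with ReduceSum.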
7 changes: 5 additions & 2 deletions onnxscript/function_libs/torch_aten/ops/special.py
@@ -16,6 +16,8 @@

from typing import Optional, Sequence

from onnxscript.function_libs.torch_aten.registration import torch_op
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType


@@ -205,12 +207,13 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::log_softmax")
def aten_special_log_softmax(
self: TensorType, dim: int, dtype: Optional[int] = None
self: TensorType, dim: int, dtype: int = None
) -> TensorType:
# special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor

raise NotImplementedError()
return op.LogSoftmax(self, axis=dim)


def aten_special_logit(self: TensorType, eps: Optional[float] = None) -> TensorType:
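The special.py change forwards aten_special_log_softmax to ONNX LogSoftmax along `dim` (the `dtype` argument is ignored in this snapshot). For reference, a plain-numpy sketch of what log-softmax computes; the helper and the values are illustrative only, not code from the PR.

```python
# Reference-only sketch of log-softmax along an axis; the PR itself just calls op.LogSoftmax.
import numpy as np

def log_softmax_reference(x: np.ndarray, dim: int) -> np.ndarray:
    # Subtract the per-axis max for numerical stability, then normalize in log space.
    shifted = x - x.max(axis=dim, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=dim, keepdims=True))

x = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
print(log_softmax_reference(x, dim=1))
```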
16 changes: 16 additions & 0 deletions onnxscript/test/function_libs/torch_aten/ops_correctness_test.py
@@ -16,6 +16,7 @@
import onnxscript
from onnxscript.function_libs.torch_aten.ops import core as core_ops
from onnxscript.function_libs.torch_aten.ops import nn as nn_ops
from onnxscript.function_libs.torch_aten.ops import special as special_ops

T = TypeVar("T")

@@ -170,17 +171,24 @@ def wrapped(fn):
"atan": core_ops.aten_atan,
"atanh": core_ops.aten_atanh,
"bmm": core_ops.aten_bmm,
# "cat": core_ops.aten_cat,  # TODO(xiaowuhu): Enable once onnxscript can parse Sequence[Tensor] inputs
"ceil": core_ops.aten_ceil,
"clamp_max": core_ops.aten_clamp_max,
"clamp_min": core_ops.aten_clamp_min,
"clamp": core_ops.aten_clamp,
"clone": core_ops.aten_clone,
"cos": core_ops.aten_cos,
"cosh": core_ops.aten_cosh,
"div": core_ops.aten_div,
"dot": core_ops.aten_dot,
"eq": core_ops.aten_eq,
"equal": core_ops.aten_equal,
"exp": core_ops.aten_exp,
"exp2": core_ops.aten_exp2,
"expand": core_ops.aten_expand,
"gt": core_ops.aten_gt,
"lt": core_ops.aten_lt,
"log_softmax": special_ops.aten_special_log_softmax,
"matmul": core_ops.aten_matmul,
"mm": core_ops.aten_mm,
"mul": core_ops.aten_mul,
@@ -191,14 +199,18 @@ def wrapped(fn):
"ones_like": core_ops.aten_ones_like,
"ones": core_ops.aten_ones,
"repeat": core_ops.aten_repeat,
"reshape": core_ops.aten_reshape,
"round": core_ops.aten_round,
"sin": core_ops.aten_sin,
"sinh": core_ops.aten_sinh,
"slice": core_ops.aten_slice,
"sub": core_ops.aten_sub,
"sum": core_ops.aten_sum,
"t": core_ops.aten_t,
"tan": core_ops.aten_tan,
"tanh": core_ops.aten_tanh,
# "transpose": core_ops.aten_transpose, # TODO(justinchuby): Enable when onnxscript errors are fixed,
"view": core_ops.aten_view,
"zeros": core_ops.aten_zeros,
"zeros_like": core_ops.aten_zeros_like,
}
@@ -341,6 +353,10 @@ def wrapped(fn):
reason="Sinh is not defined on bool or int tensors",
),
xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
xfail(
"sum",
dtypes=except(torch.bfloat16, torch.int32, torch.int64, torch.double, torch.float32, torch.float16),
reason="Sum is not defined on bool tensors"),
xfail(
"tan",
dtypes=BOOL_TYPES + INT_TYPES,
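The test changes map OpInfo names to the new implementations and use `xfail(...)` to mark dtypes an op cannot handle. Note that the `dtypes=except(...)` call in the new `sum` entry shadows a Python keyword; a hypothetical exclusion helper (not defined in this file, with a stand-in dtype set) might look like the sketch below.

```python
# Hypothetical helper, not in the PR: "except" is a reserved word, so a dtype
# exclusion list needs its own function. The dtype tuple below is a stand-in.
from typing import Tuple

import torch

TESTED_DTYPES: Tuple[torch.dtype, ...] = (
    torch.bfloat16, torch.float16, torch.float32, torch.float64,
    torch.int32, torch.int64, torch.bool,
)

def dtypes_except(*excluded: torch.dtype) -> Tuple[torch.dtype, ...]:
    """Return every tested dtype except the given ones."""
    return tuple(d for d in TESTED_DTYPES if d not in excluded)

# Usage mirroring the xfail("sum", ...) entry above:
# xfail("sum",
#       dtypes=dtypes_except(torch.bfloat16, torch.int32, torch.int64,
#                            torch.double, torch.float32, torch.float16),
#       reason="Sum is not defined on bool tensors")
```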
7 changes: 6 additions & 1 deletion onnxscript/values.py
@@ -204,10 +204,13 @@ def adapt(input: ExtendedModeValue) -> EagerModeValue:
return input
elif isinstance(input, (bool, int, float)):
return tensor.Tensor(np.array(input))
elif isinstance(input, int):
return input
elif input is None:
return None
elif isinstance(input, list):
return [adapt(elt) for elt in input]
return input
# return [adapt(elt) for elt in input]
elif isinstance(input, tuple):
return tuple(adapt(elt) for elt in input)
raise TypeError(f"Unexpected input type {type(input)}.")
@@ -236,6 +239,8 @@ def _adapt_to_user_mode(output: ExtendedModeValue) -> UserModeValue:
return tuple(_adapt_to_user_mode(elt) for elt in output)
elif isinstance(output, np.ndarray):
return output
elif isinstance(output, (bool, int, float)):
return output
raise TypeError(f"Unexpected type {type(output)}.")


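The values.py changes let eager mode stop adapting list inputs element-wise and return plain Python scalar outputs to the user unchanged. A standalone paraphrase of the adapted input path is sketched below; it is illustrative only, not the library code, and `np.array` stands in for `tensor.Tensor`.

```python
# Standalone paraphrase (not the library code) of adapt() after this change.
import numpy as np

def adapt_input_sketch(value):
    if isinstance(value, np.ndarray):
        return value                              # arrays pass straight through
    if isinstance(value, (bool, int, float)):
        return np.array(value)                    # scalars wrapped as 0-d tensors
    if value is None:
        return None
    if isinstance(value, list):
        return value                              # lists are now passed through unchanged
    if isinstance(value, tuple):
        return tuple(adapt_input_sketch(v) for v in value)
    raise TypeError(f"Unexpected input type {type(value)}.")
```

On the way back out, the new `(bool, int, float)` branch in `_adapt_to_user_mode` likewise returns plain scalar outputs as-is.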