
feat(atenlib): ops 4/n #256

Merged
merged 45 commits into from
Dec 15, 2022
Changes from 44 commits
Commits
45 commits
55f8dd9
fix: annotate script()
justinchuby Dec 9, 2022
8a3a587
feat(atenlib): clamp, lt, gt
justinchuby Dec 9, 2022
f821b6a
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 9, 2022
aecc148
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
6555a55
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
f8385b0
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
468f86f
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
47b8380
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
00f1760
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
060f9db
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
497cb16
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
9bb4038
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
d24110a
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
cbfb867
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
875f235
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
27008e1
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
3a8737d
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
c5871c8
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
012905c
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
49be5ec
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
3a9c5f6
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
d4f09e8
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
ee3143e
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
691772b
feat(atenlib): ops 2/n
justinchuby Dec 13, 2022
f160dfa
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
c8e4a54
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
7cd967d
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
e8f07c9
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
c4c80a1
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
7c0e305
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
2db3170
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
b7b03ee
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
5216435
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
042912d
feat(atenlib): ops 3/n
justinchuby Dec 14, 2022
d1c47c4
Update on "feat(atenlib): ops 3/n"
justinchuby Dec 14, 2022
f5ff175
feat(atenlib): ops 3/n
justinchuby Dec 14, 2022
fbfe507
Update base for Update on "feat(atenlib): ops 4/n"
justinchuby Dec 14, 2022
1aa6703
Update on "feat(atenlib): ops 4/n"
justinchuby Dec 14, 2022
7b1797c
Update base for Update on "feat(atenlib): ops 4/n"
justinchuby Dec 14, 2022
543997d
Update on "feat(atenlib): ops 4/n"
justinchuby Dec 14, 2022
1e65d5b
Update base for Update on "feat(atenlib): ops 4/n"
justinchuby Dec 15, 2022
78b5161
Update on "feat(atenlib): ops 4/n"
justinchuby Dec 15, 2022
9e2f306
Update base for Update on "feat(atenlib): ops 4/n"
justinchuby Dec 15, 2022
d2d0522
Update on "feat(atenlib): ops 4/n"
justinchuby Dec 15, 2022
f796bb0
Merge branch 'main' into gh/justinchuby/14/head
justinchuby Dec 15, 2022
146 changes: 67 additions & 79 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -22,28 +22,22 @@
from onnxscript.onnx_types import TensorType


def aten_abs(self: TensorType) -> TensorType:
def aten_abs(self):
# abs(Tensor self) -> Tensor

raise NotImplementedError()

return op.Abs(self)

def aten_absolute(self: TensorType) -> TensorType:
# absolute(Tensor self) -> Tensor

raise NotImplementedError()


def aten_acos(self: TensorType) -> TensorType:
def aten_acos(self):
# acos(Tensor self) -> Tensor

raise NotImplementedError()
return op.Acos(self)


def aten_acosh(self: TensorType) -> TensorType:
def aten_acosh(self):
# acosh(Tensor self) -> Tensor

raise NotImplementedError()
return op.Acosh(self)


def aten_adaptive_avg_pool1d(self: TensorType, output_size: Sequence[int]) -> TensorType:
@@ -91,12 +85,13 @@ def aten_addcmul(
raise NotImplementedError()


def aten_addmm(
self: TensorType, mat1: TensorType, mat2: TensorType, beta: float = 1, alpha: float = 1
) -> TensorType:
def aten_addmm(self, mat1, mat2, beta: float = 1, alpha: float = 1):
# addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor

raise NotImplementedError()
mat1_mat2 = op.MatMul(mat1, mat2)
scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha) # type: ignore[arg-type]
scaled_self = op.Mul(self, beta) # type: ignore[arg-type]
return op.Add(scaled_self, scaled_mat1_mat2) # type: ignore[arg-type]


def aten_addmv(
@@ -337,22 +332,22 @@ def aten_as_strided_scatter(
raise NotImplementedError()


def aten_asin(self: TensorType) -> TensorType:
def aten_asin(self):
# asin(Tensor self) -> Tensor

raise NotImplementedError()
return op.Asin(self)


def aten_asinh(self: TensorType) -> TensorType:
def aten_asinh(self):
# asinh(Tensor self) -> Tensor

raise NotImplementedError()
return op.Asinh(self)


def aten_atan(self: TensorType) -> TensorType:
def aten_atan(self):
# atan(Tensor self) -> Tensor

raise NotImplementedError()
return op.Atan(self)


def aten_atan2(self: TensorType, other: TensorType) -> TensorType:
@@ -361,10 +356,10 @@ def aten_atan2(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


def aten_atanh(self: TensorType) -> TensorType:
def aten_atanh(self):
# atanh(Tensor self) -> Tensor

raise NotImplementedError()
return op.Atanh(self)


def aten_atleast_1d(self: TensorType) -> TensorType:
@@ -611,10 +606,10 @@ def aten_block_diag(tensors: Sequence[TensorType]) -> TensorType:
raise NotImplementedError()


def aten_bmm(self: TensorType, mat2: TensorType) -> TensorType:
def aten_bmm(self, mat2):
# bmm(Tensor self, Tensor mat2) -> Tensor

raise NotImplementedError()
return op.MatMul(self, mat2)


def aten_broadcast_tensors(tensors: Sequence[TensorType]) -> TensorType:
@@ -675,16 +670,10 @@ def aten_cdist(
raise NotImplementedError()


def aten_ceil(self: TensorType) -> TensorType:
def aten_ceil(self):
# ceil(Tensor self) -> Tensor

raise NotImplementedError()


Collaborator Author: Moved to nn.

def aten_celu(self: TensorType, alpha: float = 1.0) -> TensorType:
# celu(Tensor self, Scalar alpha=1.0) -> Tensor

raise NotImplementedError()
return op.Ceil(self)


def aten_chain_matmul(matrices: Sequence[TensorType]) -> TensorType:
@@ -790,14 +779,6 @@ def aten_clamp_min_tensor(self, min_):
return op.Max(self, min_)


Collaborator Author: Alias of clamp

def aten_clip(
self: TensorType, min: Optional[float] = None, max: Optional[float] = None
) -> TensorType:
# clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor

raise NotImplementedError()


def aten_clone(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
# clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor

@@ -1036,16 +1017,16 @@ def aten_corrcoef(self: TensorType) -> TensorType:
raise NotImplementedError()


def aten_cos(self: TensorType) -> TensorType:
def aten_cos(self):
# cos(Tensor self) -> Tensor

raise NotImplementedError()
return op.Cos(self)


def aten_cosh(self: TensorType) -> TensorType:
def aten_cosh(self):
# cosh(Tensor self) -> Tensor

raise NotImplementedError()
return op.Cosh(self)


def aten_cosine_embedding_loss(
@@ -1411,10 +1392,10 @@ def aten_divide(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


def aten_dot(self: TensorType, tensor: TensorType) -> TensorType:
def aten_dot(self, tensor):
# dot(Tensor self, Tensor tensor) -> Tensor

raise NotImplementedError()
return op.MatMul(self, tensor)


def aten_dropout(input: TensorType, p: float, train: bool) -> TensorType:
@@ -1551,16 +1532,18 @@ def aten_erfinv(self: TensorType) -> TensorType:
raise NotImplementedError()


def aten_exp(self: TensorType) -> TensorType:
def aten_exp(self):
# exp(Tensor self) -> Tensor

raise NotImplementedError()
return op.Exp(self)


def aten_exp2(self: TensorType) -> TensorType:
def aten_exp2(self):
# exp2(Tensor self) -> Tensor

raise NotImplementedError()
two = op.Constant(value_int=2)
two = op.CastLike(two, self) # type: ignore[arg-type]
return op.Pow(two, self) # type: ignore[arg-type]


def aten_expand(self: TensorType, size: INT64, implicit: bool = False) -> TensorType:
@@ -2680,10 +2663,10 @@ def aten_masked_select_backward(
raise NotImplementedError()


def aten_matmul(self: TensorType, other: TensorType) -> TensorType:
def aten_matmul(self, other):
# matmul(Tensor self, Tensor other) -> Tensor

raise NotImplementedError()
return op.MatMul(self, other)


def aten_matmul_backward(
@@ -3080,10 +3063,11 @@ def aten_mkldnn_max_pool3d_backward(
raise NotImplementedError()


def aten_mm(self: TensorType, mat2: TensorType) -> TensorType:
def aten_mm(self, mat2):
# mm(Tensor self, Tensor mat2) -> Tensor

raise NotImplementedError()
# TODO(justinchuby): Specify type conversion for uint8/int8/int16
return op.MatMul(self, mat2)


def aten_mode(
@@ -3463,16 +3447,13 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType:
raise NotImplementedError()


def aten_numpy_T(self: TensorType) -> TensorType:
# numpy_T(Tensor(a) self) -> Tensor(a)

raise NotImplementedError()


def aten_ones(size: INT64) -> TensorType:
def aten_ones(size: INT64, dtype: int = -1) -> TensorType:
# ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

raise NotImplementedError()
one = op.Constant(value_float=1)
if dtype != -1:
one = op.Cast(one, to=dtype) # type: ignore[arg-type]
return op.Expand(one, size) # type: ignore[arg-type]


def aten_ones_like(self, dtype: int = -1):
@@ -4212,22 +4193,16 @@ def aten_signbit(self: TensorType) -> TensorType:
raise NotImplementedError()


def aten_sin(self: TensorType) -> TensorType:
def aten_sin(self):
# sin(Tensor self) -> Tensor

raise NotImplementedError()
return op.Sin(self)


def aten_sinc(self: TensorType) -> TensorType:
# sinc(Tensor self) -> Tensor

raise NotImplementedError()


def aten_sinh(self: TensorType) -> TensorType:
def aten_sinh(self):
# sinh(Tensor self) -> Tensor

raise NotImplementedError()
return op.Sinh(self)


def aten_slice(
Expand Down Expand Up @@ -4461,7 +4436,13 @@ def aten_symeig(
def aten_t(self: TensorType) -> TensorType:
# t(Tensor(a) self) -> Tensor(a)

raise NotImplementedError()
# TODO(justinchuby): Make rank a function
rank = op.Size(op.Shape(self)) # type: ignore[arg-type]
if rank == 0 or rank == 1: # pylint: disable=consider-using-in
result = self
else:
result = op.Transpose(self, perm=[1, 0]) # type: ignore[arg-type]
return result


def aten_t_copy(self: TensorType) -> TensorType:
@@ -4484,16 +4465,16 @@ def aten_take_along_dim(
raise NotImplementedError()


def aten_tan(self: TensorType) -> TensorType:
def aten_tan(self):
# tan(Tensor self) -> Tensor

raise NotImplementedError()
return op.Tan(self)


def aten_tanh(self: TensorType) -> TensorType:
def aten_tanh(self):
# tanh(Tensor self) -> Tensor

raise NotImplementedError()
return op.Tanh(self)


def aten_tensordot(
@@ -4606,6 +4587,13 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
raise NotImplementedError()


def aten_transpose(self, dim0: int, dim1: int):
# transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)

# FIXME(justinchuby): onnxscript raises Unsupported expression type
return op.Transpose(self, [dim0, dim1])


def aten_triangular_solve(
self: TensorType,
A: TensorType,
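
For readers skimming the diff: a minimal NumPy sketch of the decomposition the new aten_addmm above uses (MatMul, two Muls, then an Add). This is an illustration only, not part of the PR; the helper name addmm_reference is made up.

import numpy as np

def addmm_reference(self_, mat1, mat2, beta=1.0, alpha=1.0):
    # Illustration only, not part of the PR diff.
    # Same structure as aten_addmm above: MatMul, scale both terms, then Add.
    mat1_mat2 = mat1 @ mat2
    return beta * self_ + alpha * mat1_mat2

self_ = np.ones((2, 4), dtype=np.float32)
mat1 = np.random.rand(2, 3).astype(np.float32)
mat2 = np.random.rand(3, 4).astype(np.float32)
print(addmm_reference(self_, mat1, mat2, beta=0.5, alpha=2.0).shape)  # (2, 4)
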
23 changes: 19 additions & 4 deletions onnxscript/function_libs/torch_aten/ops/nn.py
@@ -150,6 +150,12 @@ def aten_binary_cross_entropy_backward(
raise NotImplementedError()


def aten_celu(self, alpha: float = 1.0):
# celu(Tensor self, Scalar alpha=1.0) -> Tensor

raise NotImplementedError()


def aten_col2im(
self: TensorType,
output_size: INT64,
@@ -407,12 +413,21 @@ def aten_leaky_relu_backward(
raise NotImplementedError()


def aten_linear(
input: TensorType, weight: TensorType, bias: Optional[TensorType] = None
) -> TensorType:
def aten_linear(input, weight, bias=None) -> TensorType:
# linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor

raise NotImplementedError()
# FIXME(justinchuby): Enable the test
# INVALID_GRAPH : This is an invalid model.
# In Node, ("", OptionalHasElement, "", -1) : () -> ("output0",) ,
# Error Node () has input size 0 not in range [min=1, max=1]

# NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases
# Optimizers may consider this path and replace it with Gemm
result = op.MatMul(input, weight)
if op.OptionalHasElement(bias):
bias = op.OptionalGetElement(bias)
result = op.Add(result, bias) # type: ignore[arg-type]
return result


def aten_log_sigmoid(self: TensorType) -> TensorType:
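
A similarly hedged NumPy sketch of the path aten_linear takes above: MatMul first, then an Add only when a bias is present, mirroring the OptionalHasElement/OptionalGetElement branch. Illustration only; linear_reference is a made-up name, and the sketch does not model the disabled-test caveat noted in the FIXME.

import numpy as np

def linear_reference(input_, weight, bias=None):
    # Illustration only, not part of the PR diff.
    # MatMul first, then add the bias only when one is provided.
    result = input_ @ weight
    if bias is not None:
        result = result + bias
    return result

x = np.random.rand(2, 3).astype(np.float32)
w = np.random.rand(3, 4).astype(np.float32)
b = np.zeros(4, dtype=np.float32)
print(linear_reference(x, w, b).shape)  # (2, 4)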