Merged
Changes from 25 commits
Commits
37 commits
55f8dd9
fix: annotate script()
justinchuby Dec 9, 2022
8a3a587
feat(atenlib): clamp, lt, gt
justinchuby Dec 9, 2022
f821b6a
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 9, 2022
aecc148
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
6555a55
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
f8385b0
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
468f86f
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
47b8380
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
00f1760
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
060f9db
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
497cb16
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
9bb4038
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
d24110a
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
cbfb867
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
875f235
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
27008e1
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
3a8737d
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
c5871c8
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
012905c
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
49be5ec
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
3a9c5f6
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
d4f09e8
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
ee3143e
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
691772b
feat(atenlib): ops 2/n
justinchuby Dec 13, 2022
f160dfa
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
c8e4a54
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
7cd967d
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
e8f07c9
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
c4c80a1
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 13, 2022
7c0e305
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
2db3170
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
b7b03ee
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
5216435
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
93ed77b
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
a5c9629
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 14, 2022
57676c8
Update base for Update on "feat(atenlib): ops 2/n"
justinchuby Dec 15, 2022
3967967
Update on "feat(atenlib): ops 2/n"
justinchuby Dec 15, 2022
118 changes: 91 additions & 27 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -739,24 +739,55 @@ def aten_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType:
    raise NotImplementedError()


def aten_clamp(
    self: TensorType, min: Optional[float] = None, max: Optional[float] = None
) -> TensorType:
def aten_clamp(self: TensorType, min_=None, max_=None) -> TensorType:
    # clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor

    raise NotImplementedError()
    # TODO(justinchuby): Handle integer inputs
    # FIXME(justinchuby): Enable test for this after None values are supported
    # TODO(justinchuby): If min is greater than max torch.clamp(..., min, max)
    # sets all elements in input to the value of max.
    if op.OptionalHasElement(min_):
        min_ = op.OptionalGetElement(min_)
        min_clamp = op.CastLike(min_, self)  # type: ignore[arg-type]
    else:
        min_clamp = op.Constant(value_float=float("-inf"))

    if op.OptionalHasElement(max_):
        max_ = op.OptionalGetElement(max_)
        max_clamp = op.CastLike(max_, self)  # type: ignore[arg-type]
    else:
        max_clamp = op.Constant(value_float=float("inf"))

    # Enforce the lower and upper bounds
    clamped = op.Max(op.Min(self, max_clamp), min_clamp)  # type: ignore[arg-type]
    return clamped
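
A quick check of the min > max case flagged in the TODO above (illustrative sketch, not part of this change; torch stands in for the ONNX Min/Max ops):

import torch

# torch.clamp documents that when min > max every element becomes max.
# Min(Max(x, min), max) reproduces that; the Max(Min(x, max), min) order
# used above yields min instead.
x = torch.tensor([1.0, 5.0, -3.0])
lo, hi = torch.tensor(4.0), torch.tensor(2.0)  # min > max on purpose
print(torch.maximum(torch.minimum(x, hi), lo))  # tensor([4., 4., 4.])
print(torch.minimum(torch.maximum(x, lo), hi))  # tensor([2., 2., 2.]), the documented clamp result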

def aten_clamp_max(self: TensorType, max: float) -> TensorType:

def aten_clamp_max_scalar(self, max_):
    # clamp_max(Tensor self, Scalar max) -> Tensor

    raise NotImplementedError()
    max_ = op.CastLike(max_, self)
    return op.Clip(self, None, max_)


def aten_clamp_max_tensor(self, max_):
    # clamp_max(Tensor self, Tensor max) -> Tensor

    return op.Min(self, max_)

def aten_clamp_min(self: TensorType, min: float) -> TensorType:

def aten_clamp_min_scalar(self, min_):
    # clamp_min(Tensor self, Scalar min) -> Tensor
    # NOTE: min_ is a rank 0 tensor.
    # TODO(justinchuby): Specify the type constraints.
    min_ = op.CastLike(min_, self)
    return op.Clip(self, min_, None)

    raise NotImplementedError()

def aten_clamp_min_tensor(self, min_):
    # clamp_min(Tensor self, Tensor min) -> Tensor
    # TODO(justinchuby): Specify the type constraints.
    return op.Max(self, min_)
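
A short sketch of why the scalar and tensor bounds lower differently (illustrative only; torch stands in for the ONNX ops): ONNX Clip takes scalar (rank-0) min/max, so a tensor bound goes through the element-wise Min/Max ops, which broadcast.

import torch

x = torch.tensor([[1.0, 5.0], [-3.0, 2.0]])
print(torch.clamp_min(x, 0.0))                     # rank-0 bound, like op.Clip(self, min_, None)
print(torch.maximum(x, torch.tensor([0.0, 3.0])))  # per-column tensor bound, like op.Max(self, min_)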


def aten_clip(
@@ -1958,10 +1989,12 @@ def aten_gru_cell(
    raise NotImplementedError()


def aten_gt(self: TensorType, other: TensorType) -> TensorType:
def aten_gt(self, other):
    # gt.Tensor(Tensor self, Tensor other) -> Tensor

    raise NotImplementedError()
    # TODO(justinchuby): Input spec: non bool tensor
    # Boolean inputs can be pre-casted by policy
    return op.Greater(self, other)
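
A hedged sketch of the pre-casting mentioned in the comment above; the wrapper name, the caller-side placement, and the INT32 choice are assumptions, not part of this change (aten_lt below would use op.Less the same way):

def aten_gt_with_bool_precast(self, other):
    # Hypothetical wrapper: ONNX Greater is not defined on bool inputs, so
    # cast boolean tensors to a numeric type first (6 is TensorProto.INT32).
    self_num = op.Cast(self, to=6)
    other_num = op.Cast(other, to=6)
    return op.Greater(self_num, other_num)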


def aten_hamming_window(window_length: int) -> TensorType:
@@ -2572,10 +2605,12 @@ def aten_lstm_mps_backward(
    raise NotImplementedError()


def aten_lt(self: TensorType, other: TensorType) -> TensorType:
def aten_lt(self, other):
    # lt.Tensor(Tensor self, Tensor other) -> Tensor

    raise NotImplementedError()
    # TODO(justinchuby): Input spec: non bool tensor
    # Boolean inputs can be pre-casted by policy
    return op.Less(self, other)


def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType:
@@ -3428,22 +3463,29 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType:
    raise NotImplementedError()


def aten_numpy_T(self: TensorType) -> TensorType:
    # numpy_T(Tensor(a) self) -> Tensor(a)

    raise NotImplementedError()


def aten_ones(size: INT64) -> TensorType:
def aten_ones(size: INT64, dtype: int = -1) -> TensorType:
    # ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

    raise NotImplementedError()
    one = op.Constant(value_float=1)
    if dtype != -1:
        one = op.Cast(one, to=dtype)  # type: ignore[arg-type]
    return op.Expand(one, size)  # type: ignore[arg-type]
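
A small numpy sketch of the Constant -> Cast -> Expand chain above (illustrative only; numpy stands in for the ONNX ops):

import numpy as np

size = np.array([2, 3])
one = np.float32(1.0)                     # op.Constant(value_float=1)
one = one.astype(np.int64)                # op.Cast(one, to=dtype) when dtype != -1 (7 is TensorProto.INT64)
print(np.broadcast_to(one, tuple(size)))  # op.Expand(one, size) -> [[1 1 1] [1 1 1]]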


def aten_ones_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
def aten_ones_like(self, dtype: int = -1):
    """ones_like.

    Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype
    before calling this function.
    """
    # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

    raise NotImplementedError()
    shape = op.Shape(self)
    if dtype == -1:
        one = op.CastLike(1, self)  # type: ignore[arg-type]
    else:
        one = op.Cast(1, to=dtype)  # type: ignore[arg-type]
    return op.Expand(one, shape)
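
A hedged sketch of the caller-side dtype conversion the docstring asks for; the mapping, the input tensor x, and the call site are illustrative, not part of this change:

from onnx import TensorProto

# The caller maps the torch dtype to the ONNX enum before calling in; a few
# example entries of such a mapping:
_TORCH_TO_ONNX = {
    "torch.float32": TensorProto.FLOAT,  # 1
    "torch.int64": TensorProto.INT64,    # 7
    "torch.bool": TensorProto.BOOL,      # 9
}
ones = aten_ones_like(x, dtype=_TORCH_TO_ONNX["torch.float32"])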


def aten_or(self: TensorType, other: TensorType) -> TensorType:
@@ -3916,10 +3958,19 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorType:
    raise NotImplementedError()


def aten_repeat(self: TensorType, repeats: INT64) -> TensorType:
def aten_repeat(self, repeats: INT64):
    # repeat(Tensor self, SymInt[] repeats) -> Tensor

    raise NotImplementedError()
    # FIXME(justinchuby): When repeats.shape == [0]

    # TODO(justinchuby): Make ones_like a function when onnxscript supports it
    # shape = ones_like(repeats) := {
    one = op.Constant(value_int=1)
    repeats_shape = op.Shape(repeats)
    shape = op.Expand(one, repeats_shape)
    # }
    self_expanded = op.Expand(self, shape)  # type: ignore[arg-type]
    return op.Tile(self_expanded, repeats)
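
A numpy sketch of why self is first expanded against a shape of ones with len(repeats) dimensions (illustrative only; numpy stands in for the ONNX ops):

import numpy as np

x = np.array([1, 2, 3])                                 # rank 1
repeats = np.array([2, 2])                              # torch.Tensor.repeat allows len(repeats) > x.ndim
shape = np.ones(len(repeats), dtype=np.int64)           # Expand(Constant(1), Shape(repeats)) -> [1, 1]
x_expanded = x * np.ones(tuple(shape), dtype=x.dtype)   # Expand(self, shape) broadcasts to shape (1, 3)
print(np.tile(x_expanded, repeats))                     # Tile -> shape (2, 6), same as torch.tensor([1, 2, 3]).repeat(2, 2)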


def aten_repeat_interleave(
@@ -4012,10 +4063,10 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> TensorType:
    raise NotImplementedError()


def aten_round(self: TensorType) -> TensorType:
def aten_round(self):
    # round(Tensor self) -> Tensor

    raise NotImplementedError()
    return op.Round(self)


def aten_row_indices(self: TensorType) -> TensorType:
@@ -4407,7 +4458,13 @@ def aten_symeig(
def aten_t(self: TensorType) -> TensorType:
    # t(Tensor(a) self) -> Tensor(a)

    raise NotImplementedError()
    # TODO(justinchuby): Make rank a function
    rank = op.Shape(op.Shape(self))
    if rank == 0 or rank == 1:
        result = self
    else:
        result = op.Transpose(self, perm=[1, 0])
    return result
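
The two branches mirror torch.t, which returns 0-D and 1-D inputs unchanged and swaps the axes of a 2-D input; a quick torch check (illustrative only, not part of this change):

import torch

v = torch.tensor([1.0, 2.0])                        # rank 1: returned unchanged
m = torch.tensor([[1.0, 2.0], [3.0, 4.0]])          # rank 2: axes swapped
assert torch.equal(torch.t(v), v)
assert torch.equal(torch.t(m), m.transpose(0, 1))   # same result as Transpose with perm=[1, 0]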


def aten_t_copy(self: TensorType) -> TensorType:
@@ -4552,6 +4609,13 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
    raise NotImplementedError()


def aten_transpose(self, dim0: int, dim1: int):
    # transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)

    # FIXME(justinchuby): onnxscript raises Unsupported expression type
    return op.Transpose(self, [dim0, dim1])
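
One common lowering when the input rank is known at export time (a sketch under that assumption with a hypothetical helper, not how this PR resolves the FIXME): ONNX Transpose takes a full permutation as its perm attribute, so dim0 and dim1 are swapped inside a complete perm list.

def transpose_with_known_rank(self, dim0: int, dim1: int, rank: int):
    # Hypothetical helper: build the identity permutation in Python, swap the
    # two requested axes, and pass the result as the Transpose perm attribute.
    perm = list(range(rank))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return op.Transpose(self, perm=perm)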


def aten_triangular_solve(
    self: TensorType,
    A: TensorType,