-
Notifications
You must be signed in to change notification settings - Fork 107
feat(atenlib): implement aten functions 1/n #247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
55f8dd9
8a3a587
f821b6a
aecc148
6555a55
f8385b0
468f86f
47b8380
00f1760
060f9db
497cb16
9bb4038
d24110a
cbfb867
875f235
27008e1
3a8737d
c5871c8
012905c
49be5ec
3a9c5f6
d4f09e8
ee3143e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| """Commonly shared functions for the function library.""" | ||
| from __future__ import annotations | ||
|
|
||
| import onnxscript | ||
| from onnxscript.onnx_opset import opset18 as op | ||
|
|
||
|
|
||
@onnxscript.script()
def ones_like(x, dtype: int):
    """Return a tensor of ones with the same shape as ``x``, cast to ``dtype``.

    ``dtype`` is an ONNX TensorProto element-type enum value.
    """
    # Cast the constant 1 to the requested element type, then broadcast it
    # to the (dynamic) shape of the input.
    one = op.Cast(1, to=dtype)
    return op.Expand(one, op.Shape(x))
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,11 +17,24 @@ | |
|
|
||
| from typing import Any, Optional, Sequence | ||
|
|
||
| import onnx.helper | ||
|
||
|
|
||
| import onnxscript | ||
| from onnxscript import BOOL, INT64 | ||
| from onnxscript.onnx_opset import default_opset as op | ||
| from onnxscript.function_libs.torch_aten.ops import common | ||
| from onnxscript.onnx_opset import opset18 as op | ||
| from onnxscript.onnx_types import TensorType | ||
|
|
||
|
|
||
@onnxscript.script()
def _ones_like(x, dtype: int):
    """Common function for ones_like.

    NOTE(review): this duplicates ``common.ones_like``; see the TODO below.
    """
    # TODO(justinchuby): Put this in another module
    # Cast 1 to the target element type, then broadcast it to x's shape.
    filler = op.Cast(1, to=dtype)
    return op.Expand(filler, op.Shape(x))
|
|
||
|
|
||
| def aten_abs(self: TensorType) -> TensorType: | ||
| # abs(Tensor self) -> Tensor | ||
|
|
||
|
|
@@ -747,16 +760,31 @@ def aten_clamp( | |
| raise NotImplementedError() | ||
|
|
||
|
|
||
| def aten_clamp_max(self: TensorType, max: float) -> TensorType: | ||
| def aten_clamp_max_scalar(self, max_): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it possible to validate the shape of max_ to know it is a scalar or tensor instead of having such information in the function name?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I haven't thought of a good way. Any suggestions?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, for validation, we can specify it in the signature / via some data structure. We can discuss today.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. They still need to be two functions though. |
||
| # clamp_max(Tensor self, Scalar max) -> Tensor | ||
|
|
||
| raise NotImplementedError() | ||
| max_ = op.CastLike(max_, self) | ||
| return op.Clip(self, None, max_) | ||
|
|
||
|
|
||
def aten_clamp_max_tensor(self, max_):
    """Clamp each element of ``self`` to be at most the matching element of ``max_``."""
    # clamp_max(Tensor self, Tensor max) -> Tensor
    # Fixed: the comment previously said "Scalar max", but this is the Tensor
    # overload (parallel to clamp_min's tensor variant).
    # TODO(justinchuby): Specify the type constraints.
    # An elementwise Min implements a tensor-valued upper bound; broadcasting
    # follows ONNX Min semantics.
    return op.Min(self, max_)
|
|
||
|
|
||
| def aten_clamp_min(self: TensorType, min: float) -> TensorType: | ||
def aten_clamp_min_scalar(self, min_):
    """Clamp every element of ``self`` to be at least the scalar ``min_``."""
    # clamp_min(Tensor self, Scalar min) -> Tensor
    # NOTE: min_ is a rank 0 tensor.
    # TODO(justinchuby): Specify the type constraints.
    # Match the input's element type before clipping.
    lower_bound = op.CastLike(min_, self)
    return op.Clip(self, lower_bound, None)
|
|
||
| raise NotImplementedError() | ||
|
|
||
def aten_clamp_min_tensor(self, min_):
    """Clamp each element of ``self`` to be at least the matching element of ``min_``."""
    # clamp_min(Tensor self, Tensor min) -> Tensor
    # TODO(justinchuby): Specify the type constraints.
    # An elementwise Max implements a tensor-valued lower bound.
    clamped = op.Max(self, min_)
    return clamped
|
|
||
|
|
||
| def aten_clip( | ||
|
|
@@ -1958,10 +1986,12 @@ def aten_gru_cell( | |
| raise NotImplementedError() | ||
|
|
||
|
|
||
| def aten_gt(self: TensorType, other: TensorType) -> TensorType: | ||
def aten_gt(self, other):
    """Elementwise ``self > other``, returning a boolean tensor."""
    # gt.Tensor(Tensor self, Tensor other) -> Tensor
    # TODO(justinchuby): Input spec: non bool tensor
    # Boolean inputs can be pre-casted by policy
    result = op.Greater(self, other)
    return result
|
|
||
|
|
||
| def aten_hamming_window(window_length: int) -> TensorType: | ||
|
|
@@ -2572,10 +2602,12 @@ def aten_lstm_mps_backward( | |
| raise NotImplementedError() | ||
|
|
||
|
|
||
| def aten_lt(self: TensorType, other: TensorType) -> TensorType: | ||
def aten_lt(self, other):
    """Elementwise ``self < other``, returning a boolean tensor."""
    # lt.Tensor(Tensor self, Tensor other) -> Tensor
    # TODO(justinchuby): Input spec: non bool tensor
    # Boolean inputs can be pre-casted by policy
    result = op.Less(self, other)
    return result
|
|
||
|
|
||
| def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType: | ||
|
|
@@ -3440,10 +3472,15 @@ def aten_ones(size: INT64) -> TensorType: | |
| raise NotImplementedError() | ||
|
|
||
|
|
||
| def aten_ones_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType: | ||
def aten_ones_like(self, dtype: int = onnx.TensorProto.FLOAT):
    """Return a tensor of ones shaped like ``self``.

    Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype
    before calling this function.
    """
    # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
    # NOTE(review): the default is FLOAT rather than "same dtype as self" —
    # confirm this matches the intended dispatch/conversion policy.
    # Delegate to the shared helper in the common module.
    return common.ones_like(self, dtype)
|
||
|
|
||
|
|
||
| def aten_or(self: TensorType, other: TensorType) -> TensorType: | ||
|
|
@@ -3916,10 +3953,12 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT | |
| raise NotImplementedError() | ||
|
|
||
|
|
||
| def aten_repeat(self: TensorType, repeats: INT64) -> TensorType: | ||
def aten_repeat(self, repeats: INT64):
    """Repeat ``self`` along each dimension according to ``repeats``."""
    # repeat(Tensor self, SymInt[] repeats) -> Tensor
    # Expand against an all-ones shape first so self is broadcast up to
    # len(repeats) dimensions before tiling.
    ones_shape = _ones_like(repeats, onnx.TensorProto.INT64)
    broadcast = op.Expand(self, ones_shape)
    return op.Tile(broadcast, repeats)
|
|
||
|
|
||
| def aten_repeat_interleave( | ||
|
|
@@ -4012,10 +4051,10 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> Te | |
| raise NotImplementedError() | ||
|
|
||
|
|
||
| def aten_round(self: TensorType) -> TensorType: | ||
def aten_round(self):
    """Round each element of ``self`` to the nearest integer value."""
    # round(Tensor self) -> Tensor
    # NOTE(review): ONNX Round rounds halves to even — confirm this matches
    # torch.round's rounding mode.
    rounded = op.Round(self)
    return rounded
|
|
||
|
|
||
| def aten_row_indices(self: TensorType) -> TensorType: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.