|
17 | 17 |
|
18 | 18 | from typing import Any, Optional, Sequence |
19 | 19 |
|
| 20 | +import onnx.helper |
| 21 | + |
20 | 22 | from onnxscript import BOOL, INT64 |
21 | | -from onnxscript.onnx_opset import default_opset as op |
| 23 | +from onnxscript.function_libs.torch_aten.ops import common |
| 24 | +from onnxscript.onnx_opset import opset18 as op |
22 | 25 | from onnxscript.onnx_types import TensorType |
23 | 26 |
|
24 | 27 |
|
@@ -747,16 +750,31 @@ def aten_clamp( |
747 | 750 | raise NotImplementedError() |
748 | 751 |
|
749 | 752 |
|
750 | | -def aten_clamp_max(self: TensorType, max: float) -> TensorType: |
def aten_clamp_max_scalar(self, max_):
    # clamp_max(Tensor self, Scalar max) -> Tensor

    # Cast the scalar bound to self's dtype so Clip's same-type constraint holds.
    upper_bound = op.CastLike(max_, self)
    return op.Clip(self, None, upper_bound)
754 | 758 |
|
755 | 759 |
|
756 | | -def aten_clamp_min(self: TensorType, min: float) -> TensorType: |
def aten_clamp_max_tensor(self, max_):
    # clamp_max(Tensor self, Tensor max) -> Tensor

    # Element-wise Min caps every element of self at the corresponding max_.
    capped = op.Min(self, max_)
    return capped
| 764 | + |
| 765 | + |
def aten_clamp_min_scalar(self, min_):
    # clamp_min(Tensor self, Scalar min) -> Tensor
    # NOTE: min_ is a rank 0 tensor.
    # TODO(justinchuby): Specify the type constraints.

    # Cast the scalar bound to self's dtype so Clip's same-type constraint holds.
    lower_bound = op.CastLike(min_, self)
    return op.Clip(self, lower_bound, None)
758 | 772 |
|
759 | | - raise NotImplementedError() |
| 773 | + |
def aten_clamp_min_tensor(self, min_):
    # clamp_min(Tensor self, Tensor min) -> Tensor
    # TODO(justinchuby): Specify the type constraints.

    # Element-wise Max raises every element of self to at least min_.
    raised = op.Max(self, min_)
    return raised
760 | 778 |
|
761 | 779 |
|
762 | 780 | def aten_clip( |
@@ -1958,10 +1976,12 @@ def aten_gru_cell( |
1958 | 1976 | raise NotImplementedError() |
1959 | 1977 |
|
1960 | 1978 |
|
1961 | | -def aten_gt(self: TensorType, other: TensorType) -> TensorType: |
def aten_gt(self, other):
    # gt.Tensor(Tensor self, Tensor other) -> Tensor

    # TODO(justinchuby): Input spec: non bool tensor
    # Boolean inputs can be pre-casted by policy
    comparison = op.Greater(self, other)
    return comparison
1965 | 1985 |
|
1966 | 1986 |
|
1967 | 1987 | def aten_hamming_window(window_length: int) -> TensorType: |
@@ -2572,10 +2592,12 @@ def aten_lstm_mps_backward( |
2572 | 2592 | raise NotImplementedError() |
2573 | 2593 |
|
2574 | 2594 |
|
2575 | | -def aten_lt(self: TensorType, other: TensorType) -> TensorType: |
def aten_lt(self, other):
    # lt.Tensor(Tensor self, Tensor other) -> Tensor

    # TODO(justinchuby): Input spec: non bool tensor
    # Boolean inputs can be pre-casted by policy
    comparison = op.Less(self, other)
    return comparison
2579 | 2601 |
|
2580 | 2602 |
|
2581 | 2603 | def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType: |
@@ -3440,10 +3462,17 @@ def aten_ones(size: INT64) -> TensorType: |
3440 | 3462 | raise NotImplementedError() |
3441 | 3463 |
|
3442 | 3464 |
|
3443 | | -def aten_ones_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType: |
def aten_ones_like(self, dtype: Optional[int] = None):
    """ones_like.

    Note: dtype is a torch enum. We need to convert it to ONNX dtype.

    NOTE(review): torch's ones_like keeps self's dtype when dtype is None;
    defaulting to FLOAT here may diverge from that — confirm with callers.
    """
    # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

    # TODO(justinchuby): Create a helper to convert torch dtype to ONNX dtype
    if dtype is None:
        # Fallback dtype when the caller does not specify one.
        dtype = onnx.TensorProto.FLOAT
    # Delegate to the shared helper; `common` is a sibling module, not an Opset.
    return common.ones_like(self, dtype)
3447 | 3476 |
|
3448 | 3477 |
|
3449 | 3478 | def aten_or(self: TensorType, other: TensorType) -> TensorType: |
@@ -3916,10 +3945,13 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT |
3916 | 3945 | raise NotImplementedError() |
3917 | 3946 |
|
3918 | 3947 |
|
3919 | | -def aten_repeat(self: TensorType, repeats: INT64) -> TensorType: |
def aten_repeat(self, repeats: INT64):
    # repeat(Tensor self, SymInt[] repeats) -> Tensor

    # FIXME(justinchuby): 'common' is not an instance of type Opset but <class 'module'>.
    # Expand against an all-ones shape of the same length as `repeats` so that
    # `self` is broadcast to the rank of `repeats` before tiling.
    ones_shape = common.ones_like(repeats, onnx.TensorProto.INT64)
    broadcast_self = op.Expand(self, ones_shape)
    return op.Tile(broadcast_self, repeats)
3923 | 3955 |
|
3924 | 3956 |
|
3925 | 3957 | def aten_repeat_interleave( |
@@ -4012,10 +4044,10 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> Te |
4012 | 4044 | raise NotImplementedError() |
4013 | 4045 |
|
4014 | 4046 |
|
4015 | | -def aten_round(self: TensorType) -> TensorType: |
def aten_round(self):
    # round(Tensor self) -> Tensor

    # ONNX Round rounds halfway cases to the nearest even integer.
    rounded = op.Round(self)
    return rounded
4019 | 4051 |
|
4020 | 4052 |
|
4021 | 4053 | def aten_row_indices(self: TensorType) -> TensorType: |
|
0 commit comments