@@ -48,20 +48,6 @@ def aten_acosh(self: TFloat) -> TFloat:
     return op.Acosh(self)


-def aten_adaptive_avg_pool1d(self: TensorType, output_size: Sequence[int]) -> TensorType:
-    # adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor
-
-    raise NotImplementedError()
-
-
-def aten_adaptive_max_pool1d(
-    self: TensorType, output_size: Sequence[int]
-) -> tuple[TensorType, TensorType]:
-    # adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)
-
-    raise NotImplementedError()
-
-
 @torch_op("aten::add")
 def aten_add(self: TReal, other: TReal, alpha: float = 1) -> TReal:
     # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
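Reviewer note: per the ATen signature in the context above, add.Tensor computes self + alpha * other. A minimal sketch of those semantics, assuming numpy and a hypothetical ref_add helper that is not part of this diff:

import numpy as np

def ref_add(self_t: np.ndarray, other: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    # add.Tensor semantics: alpha scales the second operand only
    return self_t + alpha * other

# 1 + 0.5 * 10 = 6.0, 2 + 0.5 * 20 = 12.0
assert ref_add(np.array([1.0, 2.0]), np.array([10.0, 20.0]), alpha=0.5).tolist() == [6.0, 12.0]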
@@ -198,20 +184,20 @@ def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:
     raise NotImplementedError()


-def aten_amax(
-    self: TensorType, dim: Optional[Sequence[int]] = None, keepdim: bool = False
-) -> TensorType:
+# @torch_op("aten::amax")  # FIXME: Uncomment when CI uses onnx 1.13
+def aten_amax(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
     # amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor

-    raise NotImplementedError()
+    # TODO(justinchuby): Make dim optional, keepdim bool
+    return op.ReduceMax(self, dim, keepdims=keepdim)


-def aten_amin(
-    self: TensorType, dim: Optional[Sequence[int]] = None, keepdim: bool = False
-) -> TensorType:
+# @torch_op("aten::amin")  # FIXME: Uncomment when CI uses onnx 1.13
+def aten_amin(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
     # amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor

-    raise NotImplementedError()
+    # TODO(justinchuby): Make dim optional, keepdim bool
+    return op.ReduceMin(self, dim, keepdims=keepdim)


 def aten_aminmax(
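Reviewer note: ReduceMax/ReduceMin with keepdims correspond to numpy's max/min with axis and keepdims, which is what amax/amin require. A minimal sketch, assuming a single reduction axis and a hypothetical ref_amax helper:

import numpy as np

def ref_amax(x: np.ndarray, dim: int, keepdim: bool = False) -> np.ndarray:
    # Mirrors op.ReduceMax(self, dim, keepdims=keepdim): reduce along `dim`,
    # optionally keeping the reduced axis as size 1.
    return x.max(axis=dim, keepdims=keepdim)

x = np.array([[1, 5], [4, 2]])
assert ref_amax(x, dim=1).tolist() == [5, 4]
assert ref_amax(x, dim=1, keepdim=True).shape == (2, 1)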
@@ -2190,10 +2176,21 @@ def aten_index_reduce(
     raise NotImplementedError()


-def aten_index_select(self: TensorType, dim: int, index: TensorType) -> TensorType:
+# FIXME(#277): Script when attributes can come before inputs
+@torch_op("aten::index_select", trace_only=True)
+def aten_index_select(self: TTensor, dim: int, index: TInt) -> TTensor:
     # index_select(Tensor self, int dim, Tensor index) -> Tensor

-    raise NotImplementedError()
+    if op.Size(op.Shape(self)) == 0:
+        result = self
+    else:
+        # Index can be a scalar. Reshape it to a rank 1 tensor.
+        index = op.Reshape(index, op.Constant(value_floats=[-1]))
+        index = op.Cast(index, to=INT64.dtype)
+
+        result = op.Gather(self, index, axis=dim)
+
+    return result


 def aten_index_select_backward(
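Reviewer note: Gather(self, index, axis=dim) is numpy's take along an axis, which matches index_select. A minimal sketch, assuming a hypothetical ref_index_select helper (as an aside, ONNX Reshape expects an int64 shape tensor, so the value_floats constant above may be worth double-checking):

import numpy as np

def ref_index_select(x: np.ndarray, dim: int, index: np.ndarray) -> np.ndarray:
    # Equivalent of op.Gather(x, index, axis=dim): pick slices along `dim`
    # at the positions listed in `index`, after flattening index to rank 1.
    return np.take(x, index.reshape(-1).astype(np.int64), axis=dim)

x = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
assert ref_index_select(x, dim=1, index=np.array([2, 0])).tolist() == [[2, 0], [5, 3]]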
@@ -4194,10 +4191,11 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     return op.Reciprocal(op.Sqrt(self))


-def aten_rsub(self: TensorType, other: TensorType, alpha: float = 1) -> TensorType:
+@torch_op("aten::rsub")
+def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     # rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

-    raise NotImplementedError()
+    return op.Sub(other, op.Mul(self, alpha))


 def aten_scalar_tensor(s: float) -> TensorType:
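Reviewer note: rsub swaps the subtraction order: the result is other - alpha * self, matching op.Sub(other, op.Mul(self, alpha)). A minimal sketch with a hypothetical ref_rsub helper:

import numpy as np

def ref_rsub(self_t: np.ndarray, other: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    # rsub.Tensor: subtract the scaled first operand from the second
    return other - alpha * self_t

# 10 - 2 * 1 = 8.0, 20 - 2 * 2 = 16.0
assert ref_rsub(np.array([1.0, 2.0]), np.array([10.0, 20.0]), alpha=2.0).tolist() == [8.0, 16.0]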
@@ -4698,11 +4696,26 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::transpose", trace_only=True)
 def aten_transpose(self, dim0: int, dim1: int):
     # transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)

     # FIXME(justinchuby): onnxscript raises Unsupported expression type
-    return op.Transpose(self, [dim0, dim1])
+    # Script the function when this is fixed
+    self_rank = op.Size(op.Shape(self))
+
+    if self_rank == 0:
+        result = self
+    else:
+        # Python code, change when onnxscript supports this
+        self_rank_val = self_rank.value  # type: ignore[attr-defined]
+        dims = list(range(self_rank_val))
+        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
+        # Python code ends
+
+        result = op.Transpose(self, perm=dims)
+
+    return result


 def aten_triangular_solve(
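Reviewer note: the Python block above builds an identity permutation, swaps dim0 and dim1, and hands it to Transpose, i.e. numpy's swapaxes. A minimal sketch with a hypothetical ref_transpose helper:

import numpy as np

def ref_transpose(x: np.ndarray, dim0: int, dim1: int) -> np.ndarray:
    # Same scheme as the traced function: identity perm, swap two entries,
    # then apply it, which is exactly np.swapaxes(x, dim0, dim1).
    dims = list(range(x.ndim))
    dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
    return np.transpose(x, axes=dims)

x = np.zeros((2, 3, 4))
assert ref_transpose(x, 0, 2).shape == (4, 3, 2)
assert np.array_equal(ref_transpose(x, 0, 2), np.swapaxes(x, 0, 2))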