@@ -331,20 +331,38 @@ def aten_arctanh(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_argmax(
-    self: TensorType, dim: Optional[int] = None, keepdim: bool = False
-) -> TensorType:
+@torch_op("aten::argmax", trace_only=True)
+def aten_argmax(self: TReal, dim: Optional[int] = None, keepdim: bool = False) -> TReal:
     # argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
 
-    raise NotImplementedError()
+    self_is_scaler = op.Size(op.Shape(self)) == 0
+    if self_is_scaler:
+        self = op.Reshape(self, op.Constant(value_ints=[-1]))
+    elif dim is None:  # should use OptionalHasElement(dim)
+        self = op.Reshape(self, op.Constant(value_ints=[-1]))
 
+    result = op.ArgMax(self, axis=dim, keepdims=keepdim)
+    if self_is_scaler:
+        result = op.Squeeze(result)
 
-def aten_argmin(
-    self: TensorType, dim: Optional[int] = None, keepdim: bool = False
-) -> TensorType:
+    return result
+
+
+@torch_op("aten::argmin", trace_only=True)
+def aten_argmin(self: TReal, dim: Optional[int] = None, keepdim: bool = False) -> TReal:
     # argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
 
-    raise NotImplementedError()
+    self_is_scaler = op.Size(op.Shape(self)) == 0
+    if self_is_scaler:
+        self = op.Reshape(self, op.Constant(value_ints=[-1]))
+    elif dim is None:  # should use OptionalHasElement(dim)
+        self = op.Reshape(self, op.Constant(value_ints=[-1]))
+
+    result = op.ArgMin(self, axis=dim, keepdims=keepdim)
+    if self_is_scaler:
+        result = op.Squeeze(result)
+
+    return result
 
 
 def aten_argsort(self: TensorType, dim: int = -1, descending: bool = False) -> TensorType:
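Note on the semantics the hunk above targets: `torch.argmax`/`torch.argmin` with `dim=None` operate on the flattened input, and a 0-d (scalar) input yields a 0-d result. That is what the `Reshape(self, [-1])` and the final `Squeeze` emulate, since ONNX `ArgMax`/`ArgMin` always reduce along a concrete axis. A minimal PyTorch check of those semantics, illustrative only and not part of this commit:

```python
import torch

# dim=None flattens first; ONNX ArgMax has no such mode,
# hence the Reshape(self, [-1]) in the diff above.
x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
assert torch.argmax(x).item() == 1            # index into the flattened tensor
assert torch.argmax(x, dim=1).tolist() == [1, 0]
assert torch.argmin(x).item() == 0

# A 0-d (scalar) input returns a 0-d result, hence the final Squeeze.
scalar = torch.tensor(7.0)
assert torch.argmax(scalar).ndim == 0
```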
@@ -1383,10 +1401,11 @@ def aten_det(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::detach")
 def aten_detach(self: TensorType) -> TensorType:
     # detach(Tensor(a) self) -> Tensor(a)
 
-    raise NotImplementedError()
+    return op.Identity(self)
 
 
 def aten_detach_copy(self: TensorType) -> TensorType:
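`detach()` only severs autograd tracking and has no numeric effect, so lowering it to an ONNX `Identity` node preserves values exactly. A quick illustration of that equivalence in eager PyTorch, not part of this commit:

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x.detach()

# Same values; only the autograd flag changes, which is why
# a plain Identity node is a faithful ONNX mapping.
assert torch.equal(x, y)
assert not y.requires_grad
```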