@@ -186,15 +186,15 @@ def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:
186186 raise NotImplementedError ()
187187
188188
@torch_op("aten::amax")
def aten_amax(self: TReal, dim: INT64, keepdim: int = 0) -> TReal:
    # amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
    # TODO(justinchuby): Make dim optional, keepdim bool

    # Reduce to the maximum over the requested axes; keepdims mirrors
    # the ATen keepdim flag (passed through as an int attribute).
    reduced = op.ReduceMax(self, dim, keepdims=keepdim)
    return reduced
195195
196196
197- # @torch_op("aten::amin") # FIXME(#249): Uncomment when CI uses onnx 1.13
197+ @torch_op ("aten::amin" )
198198def aten_amin (self : TReal , dim : INT64 , keepdim : int = 0 ) -> TReal :
199199 # amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
200200
@@ -710,10 +710,15 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType:
710710 raise NotImplementedError ()
711711
712712
@torch_op("aten::cat", trace_only=True)
def aten_cat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
    # cat(Tensor[] tensors, int dim=0) -> Tensor

    # trace_only: iterating a Python sequence (and len()) is not supported
    # inside a scripted onnxscript function, so this body runs at trace time.
    seq = op.SequenceEmpty()
    # Iterate the sequence directly instead of indexing via range(len(...)).
    for tensor in tensors:
        seq = op.SequenceInsert(seq, tensor)
    # Concatenate every element of the ONNX sequence along the given axis.
    return op.ConcatFromSequence(seq, axis=dim)
717722
718723
719724def aten_ccol_indices (self : TensorType ) -> TensorType :
@@ -1506,16 +1511,15 @@ def aten_einsum(
15061511 raise NotImplementedError ()
15071512
15081513
@torch_op("aten::embedding")
def aten_embedding(
    weight: TTensor,
    indices: TTensor,
    **_,
) -> TTensor:
    # embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor

    # padding_idx / scale_grad_by_freq / sparse are swallowed by **_ and
    # ignored — presumably they only matter for training-time gradients;
    # confirm against the exporter's expectations.
    # An embedding lookup is row selection from the weight matrix.
    rows = op.Gather(weight, indices)
    return rows
15191523
15201524
15211525def aten_embedding_backward (
@@ -1570,10 +1574,29 @@ def aten_embedding_sparse_backward(
15701574 raise NotImplementedError ()
15711575
15721576
@torch_op("aten::empty")
def aten_empty(size: IntType, dtype: int = FLOAT.dtype) -> TTensor:  # type: ignore[type-var]
    # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

    # torch.empty leaves contents unspecified, so filling with zeros
    # (like np.empty is simulated here) is an acceptable implementation.
    shape = op.Cast(size, to=INT64.dtype)
    fill_value = op.Cast(op.Constant(value_float=0), to=dtype)
    return op.Expand(fill_value, shape)
1588+
@torch_op("aten::empty_like")
def aten_empty_like(self: TTensor, dtype: int = -1) -> TTensor:
    # empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

    # dtype == -1 is the sentinel for "inherit self's dtype" (ScalarType?
    # cannot be represented directly as an attribute).
    if dtype == -1:
        fill = op.CastLike(0, self)
    else:
        fill = op.Cast(0, to=dtype)

    shape = op.Shape(self)
    return op.Expand(fill, shape)
15781601
15791602def aten_empty_quantized (
@@ -1957,10 +1980,11 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:
19571980 raise NotImplementedError ()
19581981
19591982
@torch_op("aten::ge")
def aten_ge(self: TReal, other: TReal) -> BOOL:
    # ge.Tensor(Tensor self, Tensor other) -> Tensor

    # aten::ge is element-wise `self >= other`. op.Greater implements the
    # strict `>` comparison and would return False for equal elements, so
    # GreaterOrEqual (ONNX opset >= 12) is the correct mapping.
    return op.GreaterOrEqual(self, other)
19651989
19661990def aten_geqrf (self : TensorType ) -> tuple [TensorType , TensorType ]:
@@ -2514,10 +2538,11 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
25142538 raise NotImplementedError ()
25152539
25162540
@torch_op("aten::le")
def aten_le(self: TReal, other: TReal) -> BOOL:
    # le.Tensor(Tensor self, Tensor other) -> Tensor

    # aten::le is element-wise `self <= other`. op.Less implements the
    # strict `<` comparison and would return False for equal elements, so
    # LessOrEqual (ONNX opset >= 12) is the correct mapping.
    return op.LessOrEqual(self, other)
25222547
25232548def aten_lerp (self : TensorType , end : TensorType , weight : TensorType ) -> TensorType :
@@ -2680,7 +2705,7 @@ def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> T
26802705 raise NotImplementedError ()
26812706
26822707
2683- @torch_op ("aten::logsumexp" , trace_only = True ) # FIXME(#249): Script when CI uses onnx 1.13
2708+ @torch_op ("aten::logsumexp" )
26842709def aten_logsumexp (self : TReal , dim : INT64 , keepdim : int = False ) -> TReal :
26852710 # logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
26862711
@@ -2902,10 +2927,11 @@ def aten_max_pool3d(
29022927 raise NotImplementedError ()
29032928
29042929
@torch_op("aten::maximum")
def aten_maximum(self: TReal, other: TReal) -> TReal:
    # maximum(Tensor self, Tensor other) -> Tensor

    # ONNX Max is variadic; with two inputs it is the element-wise maximum.
    result = op.Max(self, other)
    return result
29102936
29112937def aten_mean (self : TensorType , dtype : Optional [int ] = None ) -> TensorType :
@@ -2932,10 +2958,11 @@ def aten_min(self: TensorType) -> TensorType:
29322958 raise NotImplementedError ()
29332959
29342960
@torch_op("aten::minimum")
def aten_minimum(self: TReal, other: TReal) -> TReal:
    # minimum(Tensor self, Tensor other) -> Tensor

    # ONNX Min is variadic; with two inputs it is the element-wise minimum.
    result = op.Min(self, other)
    return result
29402967
29412968def aten_miopen_batch_norm (
@@ -4393,16 +4420,30 @@ def aten_sinh(self: TFloat) -> TFloat:
43934420 return op .Sinh (self )
43944421
43954422
@torch_op("aten::slice")
def aten_slice(
    self: TTensor,
    dim: int = 0,
    start: Optional[INT64] = None,
    end: Optional[INT64] = None,
    step: INT64 = 1,
) -> TTensor:
    # slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)

    # TODO: using OptionalHasElement() to check start/end value
    # ONNX Slice takes starts/ends/axes/steps as 1-D INT64 tensors, so each
    # argument is cast to INT64 and flattened with Reshape(-1).
    # NOTE(review): start/end may be None per the ATen schema; this presumably
    # relies on onnxscript's handling of optional inputs — confirm.
    starts = op.Reshape(op.Cast(start, to=INT64.dtype), op.Constant(value_ints=[-1]))
    ends = op.Reshape(op.Cast(end, to=INT64.dtype), op.Constant(value_ints=[-1]))
    axes = op.Reshape(op.Cast(dim, to=INT64.dtype), op.Constant(value_ints=[-1]))
    steps = op.Reshape(op.Cast(step, to=INT64.dtype), op.Constant(value_ints=[-1]))

    return op.Slice(self, starts, ends, axes, steps)
44064447
44074448
44084449def aten_slice_backward (
0 commit comments