@@ -803,10 +803,13 @@ def aten_clamp_min(self: TReal, min_: TReal) -> TReal:
803803 return result
804804
805805
@torch_op("aten::clone")
def aten_clone(
    self: TTensor, memory_format: str = ""  # pylint: disable=unused-argument
) -> TTensor:
    # clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor

    # ONNX graphs have no notion of memory layout, so the memory_format
    # hint is accepted but ignored; Identity emits a copy of the input.
    result = op.Identity(self)
    return result
810813
811814
812815def aten_coalesce (self : TensorType ) -> TensorType :
@@ -1406,10 +1409,11 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2) -> TensorType:
14061409 raise NotImplementedError ()
14071410
14081411
@torch_op("aten::div")
def aten_div(self: TReal, other: TReal) -> TReal:
    # div.Tensor(Tensor self, Tensor other) -> Tensor

    # Element-wise true division via ONNX Div.
    quotient = op.Div(self, other)
    return quotient
14131417
14141418
14151419def aten_divide (self : TensorType , other : TensorType ) -> TensorType :
@@ -1529,16 +1533,21 @@ def aten_empty_strided(size: INT64, stride: INT64) -> TensorType:
15291533 raise NotImplementedError ()
15301534
15311535
@torch_op("aten::eq")
def aten_eq(self: TTensor, other: TTensor) -> BOOL:
    # eq.Tensor(Tensor self, Tensor other) -> Tensor

    # Element-wise equality; returns a boolean tensor.
    result = op.Equal(self, other)
    return result
15361541
15371542
@torch_op("aten::equal")
def aten_equal(self: TTensor, other: TTensor) -> BOOL:
    # equal(Tensor self, Tensor other) -> bool

    # Tensors are equal exactly when every element-wise difference is
    # zero, i.e. when sum(|self - other|) == 0.
    # NOTE(review): this assumes the inputs have matching shapes and a
    # numeric dtype — aten::equal returns False on a shape mismatch,
    # which this formulation does not check; confirm against callers.
    diff = op.Sub(self, other)
    total_abs_diff = op.ReduceSum(op.Abs(diff), keepdims=0)
    return op.Equal(total_abs_diff, 0)
15421551
15431552
15441553@torch_op ("aten::erf" )
@@ -1576,10 +1585,12 @@ def aten_exp2(self: TFloat) -> TFloat:
15761585 return op .Pow (two , self )
15771586
15781587
@torch_op("aten::expand")
def aten_expand(
    self: TTensor, size: INT64, implicit: bool = False  # pylint: disable=unused-argument
) -> TTensor:
    # expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)

    # Keep `implicit` in the signature for backward compatibility with the
    # aten schema and existing callers; it has no ONNX equivalent and is
    # ignored, mirroring how aten_clone keeps its unused memory_format.
    size = op.Cast(size, to=INT64.dtype)  # Expand requires an INT64 shape input
    return op.Expand(self, size)
15831594
15841595
15851596def aten_expand_as (self : TensorType , other : TensorType ) -> TensorType :
@@ -4046,10 +4057,12 @@ def aten_repeat_interleave(
40464057 raise NotImplementedError ()
40474058
40484059
@torch_op("aten::reshape")
def aten_reshape(self: TTensor, shape: INT64) -> TTensor:
    # reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)

    # ONNX Reshape only accepts INT64 for 'shape', so cast defensively.
    target_shape = op.Cast(shape, to=INT64.dtype)
    return op.Reshape(self, target_shape)
40534066
40544067
40554068def aten_reshape_as (self : TensorType , other : TensorType ) -> TensorType :
@@ -4484,7 +4497,9 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens
44844497 raise NotImplementedError ()
44854498
44864499
def aten_sum(
    self: TensorType, dim: Optional[int] = None, keepdim: bool = False, dtype: int = -1
) -> TensorType:
    # sum(Tensor self, *, ScalarType? dtype=None) -> Tensor

    # Not yet implemented; signature mirrors the aten overloads
    # (full reduction plus the dim/keepdim variant).
    raise NotImplementedError()
@@ -4903,10 +4918,12 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:
49034918 raise NotImplementedError ()
49044919
49054920
@torch_op("aten::view")
def aten_view(self: TTensor, size: INT64) -> TTensor:
    # view(Tensor(a) self, SymInt[] size) -> Tensor(a)

    # ONNX has no zero-copy view; lower to Reshape, whose second input
    # must be INT64, hence the defensive cast.
    target_shape = op.Cast(size, to=INT64.dtype)
    return op.Reshape(self, target_shape)
49104927
49114928
49124929def aten_view_as (self : TensorType , other : TensorType ) -> TensorType :
0 commit comments