@@ -176,10 +176,35 @@ def aten_align_to(self: TensorType, names: Sequence[str]) -> TensorType:
176176 raise NotImplementedError ()
177177
178178
@torch_op("aten::all")
def aten_all(self: TTensor) -> BOOL:
    """all(Tensor self) -> Tensor

    Returns a scalar BOOL that is True iff every element of `self` is truthy.
    """

    if op.Size(op.Shape(self)) == 0:
        # Rank-0 input: nothing to reduce; reinterpret the scalar as bool.
        result = op.Cast(self, to=BOOL.dtype)
    else:
        # ONNX has no boolean reduction, so reduce as INT64: the minimum of the
        # 0/1 values is 1 exactly when all elements are truthy.
        as_bool = op.Cast(self, to=BOOL.dtype)
        as_int64 = op.Cast(as_bool, to=INT64.dtype)
        reduced = op.ReduceMin(as_int64, keepdims=0)
        result = op.Cast(reduced, to=BOOL.dtype)

    return result
192+
193+
@torch_op("aten::all", overload=True)
def aten_all_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL:
    """all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor

    Reduces over the single dimension `dim`, returning True where every
    element along that dimension is truthy. `keepdim` retains the reduced
    dimension with size 1.
    """

    if op.Size(op.Shape(self)) == 0:
        # Rank-0 input: there is no dimension to reduce; reinterpret as bool.
        result = op.Cast(self, to=BOOL.dtype)
    else:
        # ONNX has no boolean reduction, so reduce as INT64: min of 0/1 values
        # is 1 exactly when all elements along `dim` are truthy.
        self_bool = op.Cast(self, to=BOOL.dtype)
        self_int = op.Cast(self_bool, to=INT64.dtype)
        # ReduceMin takes the axes as a 1-D tensor input.
        dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
        result_int = op.ReduceMin(self_int, dims, keepdims=keepdim)
        result = op.Cast(result_int, to=BOOL.dtype)

    return result
183208
184209
185210def aten_allclose (
@@ -2899,10 +2924,13 @@ def aten_isclose(
28992924 raise NotImplementedError ()
29002925
29012926
@torch_op("aten::isfinite")
def aten_isfinite(self: TensorType) -> TensorType:
    """isfinite(Tensor self) -> Tensor

    Element-wise test: True where the value is neither infinite nor NaN.
    """

    not_infinite = op.Not(op.IsInf(self))
    # TODO: The test case doesn't cover the NaN condition.
    not_nan = op.Not(op.IsNaN(self))
    return op.And(not_infinite, not_nan)
29062934
29072935
29082936@torch_op ("aten::isinf" )
@@ -5187,10 +5215,11 @@ def aten_split_copy(self: TensorType, split_size: INT64, dim: int = 0) -> Tensor
51875215 raise NotImplementedError ()
51885216
51895217
@torch_op("aten::split_with_sizes")
def aten_split_with_sizes(self: TTensor, split_sizes: INT64, dim: int = 0) -> TTensor:
    """split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]"""

    # SplitToSequence produces a sequence of tensors whose sizes along `dim`
    # are given by `split_sizes`, mirroring torch's list-of-tensors result.
    split_result = op.SplitToSequence(self, split_sizes, axis=dim)
    return split_result
51945223
51955224
51965225def aten_split_with_sizes_copy (
0 commit comments