@@ -1841,18 +1841,23 @@ def aten_from_file(
18411841 raise NotImplementedError ()
18421842
18431843
1844- def aten_full (size : INT64 , fill_value : float ) -> TensorType :
1844+ @torch_op ("aten::full" )
1845+ def aten_full (size : INT64 , fill_value , dtype : int = FLOAT .dtype ):
18451846 # full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
18461847
1847- raise NotImplementedError ( )
1848+ fill_value = op . Cast ( fill_value , to = dtype )
18481849
1850+ return op .Expand (fill_value , size )
18491851
1850- def aten_full_like (
1851- self : TensorType , fill_value : float , memory_format : Optional [ str ] = None
1852- ) -> TensorType :
1852+
1853+ @ torch_op ( "aten::full_like" )
1854+ def aten_full_like ( self , fill_value , dtype : int = FLOAT . dtype ) :
18531855 # full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
18541856
1855- raise NotImplementedError ()
1857+ fill_value = op .Cast (fill_value , to = dtype )
1858+ self_shape = op .Shape (self )
1859+
1860+ return op .Expand (fill_value , self_shape )
18561861
18571862
18581863def aten_fused_moving_avg_obs_fake_quant (
@@ -3447,10 +3452,15 @@ def aten_new_empty_strided(self: TensorType, size: INT64, stride: INT64) -> Tens
34473452 raise NotImplementedError ()
34483453
34493454
3450- def aten_new_full (self : TensorType , size : INT64 , fill_value : float ) -> TensorType :
3455+ @torch_op ("aten::new_full" )
3456+ def aten_new_full (
3457+ self , size : INT64 , fill_value , dtype : int = FLOAT .dtype
3458+ ): # pylint: disable=unused-argument
34513459 # new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
34523460
3453- raise NotImplementedError ()
3461+ fill_value = op .Cast (fill_value , to = dtype )
3462+
3463+ return op .Expand (fill_value , size )
34543464
34553465
34563466def aten_new_ones (self : TensorType , size : INT64 ) -> TensorType :
@@ -4928,10 +4938,11 @@ def aten_vstack(tensors: Sequence[TensorType]) -> TensorType:
49284938 raise NotImplementedError ()
49294939
49304940
4931- def aten_where (condition : TensorType ) -> TensorType :
4932- # where(Tensor condition) -> Tensor[]
4941+ @torch_op ("aten::where" )
4942+ def aten_where (self : TTensor , condition : BOOL , other : TTensor ) -> TTensor :
4943+ # where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
49334944
4934- raise NotImplementedError ( )
4945+ return op . Where ( condition , self , other )
49354946
49364947
49374948def aten_xlogy (self : TensorType , other : TensorType ) -> TensorType :
0 commit comments