1616from onnxscript import BOOL , DOUBLE , FLOAT , INT16 , INT32 , INT64
1717from onnxscript .function_libs .torch_aten .registration import torch_op
1818from onnxscript .function_libs .torch_aten .typing import (
19+ IntType ,
1920 TFloat ,
2021 TFloatOrBFloat16 ,
2122 TInt ,
@@ -1642,10 +1643,10 @@ def aten_exp2(self: TFloat) -> TFloat:
16421643
16431644
@torch_op("aten::expand")
def aten_expand(self: TTensor, size: TInt) -> TTensor:
    # expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)

    # ONNX Expand requires an INT64 shape tensor, so normalize `size` first.
    target_shape = op.Cast(size, to=INT64.dtype)
    return op.Expand(self, target_shape)
16501651
16511652
@@ -3518,10 +3519,11 @@ def aten_new_empty_strided(self: TensorType, size: INT64, stride: INT64) -> Tens
35183519
@torch_op("aten::new_full")
def aten_new_full(
    self, size: IntType, fill_value, dtype: int = FLOAT.dtype
):  # pylint: disable=unused-argument
    # new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

    # Expand needs an INT64 shape; the scalar is cast to the requested dtype
    # before being broadcast to that shape. `self` only supplies the aten
    # signature, hence the unused-argument suppression above.
    target_shape = op.Cast(size, to=INT64.dtype)
    scalar = op.Cast(fill_value, to=dtype)
    return op.Expand(scalar, target_shape)
@@ -3585,12 +3587,12 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType:
35853587
35863588
@torch_op("aten::ones")
def aten_ones(size: IntType, dtype: int = FLOAT.dtype):
    # ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

    # Shape input for Expand must be INT64.
    shape = op.Cast(size, to=INT64.dtype)
    # Build a scalar 1 in the requested dtype and broadcast it to `shape`.
    fill = op.Cast(op.Constant(value_float=1), to=dtype)
    return op.Expand(fill, shape)
35953597
35963598
@@ -4088,13 +4090,14 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT
40884090
40894091
40904092@torch_op ("aten::repeat" )
4091- def aten_repeat (self : TTensor , repeats : INT64 ) -> TTensor :
4093+ def aten_repeat (self : TTensor , repeats : TInt ) -> TTensor :
40924094 # repeat(Tensor self, SymInt[] repeats) -> Tensor
40934095
40944096 if op .Size (repeats ) == 0 :
40954097 result = self
40964098 else :
40974099 # TODO(justinchuby): Make ones_like a function when onnxscript supports it
4100+ repeats = op .Cast (repeats , to = INT64 .dtype )
40984101 # shape = ones_like(repeats) := {
40994102 one = op .Constant (value_int = 1 )
41004103 repeats_shape = op .Shape (repeats )
@@ -4114,10 +4117,11 @@ def aten_repeat_interleave(
41144117
41154118
@torch_op("aten::reshape")
def aten_reshape(self: TTensor, shape: IntType) -> TTensor:
    # reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)

    # ONNX Reshape only accepts an INT64 shape input, so cast inline.
    return op.Reshape(self, op.Cast(shape, to=INT64.dtype))
41224126
41234127
@@ -4975,7 +4979,7 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:
49754979
49764980
49774981@torch_op ("aten::view" )
4978- def aten_view (self : TTensor , size : INT64 ) -> TTensor :
4982+ def aten_view (self : TTensor , size : IntType ) -> TTensor :
49794983 # view(Tensor(a) self, SymInt[] size) -> Tensor(a)
49804984
49814985 size = op .Cast (size , to = INT64 .dtype ) # Reshape only support INT64 as second input
@@ -5044,12 +5048,12 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType:
50445048
50455049
@torch_op("aten::zeros")
def aten_zeros(size: IntType, dtype: int = FLOAT.dtype):
    # zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

    # Shape input for Expand must be INT64.
    shape = op.Cast(size, to=INT64.dtype)
    # Build a scalar 0 in the requested dtype and broadcast it to `shape`.
    fill = op.Cast(op.Constant(value_float=0), to=dtype)

    return op.Expand(fill, shape)
50555059
0 commit comments