@@ -4390,7 +4390,7 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
     reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape)
 
     # Reshape and expand the index.
-    idx = op.Reshape(idx, reshape_list)
+    idx = op.Reshape(idx, reshape_list, allowzero=True)
     idx = op.Expand(idx, values_shape)
 
     # Flatten the index to 1D and unsqueeze to form a column vector.
@@ -4547,7 +4547,7 @@ def aten_instance_norm(
         momentum=1.0 - momentum,
         training_mode=False,
     )
-    return op.Reshape(norm, op.Shape(input))
+    return op.Reshape(norm, op.Shape(input), allowzero=True)
 
 
 def aten_int_repr(self: TensorType) -> TensorType:
@@ -6244,7 +6244,7 @@ def _aten_native_group_norm_onnx(
         input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
     )
     # Reshape back to input's shape
-    norm = op.Reshape(norm, op.Shape(input))
+    norm = op.Reshape(norm, op.Shape(input), allowzero=True)
     # Using the input weight and bias to do affine
     # But need to unsqueeze to the target shape for broading cast easy
     input_rank = Rank(input)
@@ -6693,7 +6693,7 @@ def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal:
     )
     depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD")
     output_shape = op.Concat(batch_dims, op.Shape(depth_to_space)[1:], axis=0)
-    return op.Reshape(depth_to_space, output_shape)
+    return op.Reshape(depth_to_space, output_shape, allowzero=True)
 
 
 @torch_op("aten::pixel_unshuffle")
@@ -6709,7 +6709,7 @@ def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal:
     )
     space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor)
     output_shape = op.Concat(batch_dims, op.Shape(space_to_depth)[1:], axis=0)
-    return op.Reshape(space_to_depth, output_shape)
+    return op.Reshape(space_to_depth, output_shape, allowzero=True)
 
 
 def aten_poisson(self: TensorType, generator: Optional[str] = None) -> TensorType:
@@ -8390,7 +8390,7 @@ def aten_tile(self: TTensor, dims: INT64) -> TTensor:
         exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d)
         self_shape = op.Shape(self)
         self_final_shape = op.Concat(exapnd_ones, self_shape, axis=0)
-        self = op.Reshape(self, self_final_shape)
+        self = op.Reshape(self, self_final_shape, allowzero=True)
 
     return op.Tile(self, dims)
 
@@ -8630,7 +8630,7 @@ def aten_unflatten(self: TReal, dim: int, sizes: Sequence[INT64]):
         final_shape = op.Concat(head_part_rank, *sizes, axis=0)
     else:
         final_shape = op.Concat(head_part_rank, *sizes, tail_part_rank, axis=0)
-    return op.Reshape(self, final_shape)
+    return op.Reshape(self, final_shape, allowzero=True)
 
 
 @torch_op("aten::unfold", trace_only=True)
@@ -8706,11 +8706,11 @@ def aten__unique(
     unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
     input_size = op.Shape(self)
     if return_inverse:
-        inverse_indices = op.Reshape(inverse_indices, input_size)
+        inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
     else:
         input_numel = op.ReduceProd(input_size, keepdims=False)
         if input_numel == 0:
-            inverse_indices = op.Reshape(inverse_indices, input_size)
+            inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
         else:
             inverse_indices = op.ConstantOfShape([0])
             inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
@@ -8729,11 +8729,11 @@ def aten__unique2(
     unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=True)
     input_size = op.Shape(self)
     if return_inverse:
-        inverse_indices = op.Reshape(inverse_indices, input_size)
+        inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
     else:
         input_numel = op.ReduceProd(input_size, keepdims=False)
         if input_numel == 0:
-            inverse_indices = op.Reshape(inverse_indices, input_size)
+            inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
         else:
             inverse_indices = op.ConstantOfShape([0])
             inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
@@ -9019,7 +9019,7 @@ def aten_view(self: TTensor, size: IntType) -> TTensor:
     """view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
 
     size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
-    return op.Reshape(self, size)
+    return op.Reshape(self, size, allowzero=True)
 
 
 @torch_op(("aten::view", "aten::_unsafe_view"), complex=True)
@@ -9028,15 +9028,15 @@ def aten_view_complex(self: TTensor, size: IntType) -> TTensor:
 
     size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
     complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0)
-    return op.Reshape(self, complex_size)
+    return op.Reshape(self, complex_size, allowzero=True)
 
 
 @torch_op("aten::view_as")
 def aten_view_as(self: TTensor, other: TTensor2) -> TTensor:
     """view_as(Tensor(a) self, Tensor other) -> Tensor(a)"""
 
     size = op.Shape(other)
-    return op.Reshape(self, size)
+    return op.Reshape(self, size, allowzero=True)
 
 
 @torch_op("aten::view_as_complex", trace_only=True)
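Every hunk above makes the same substitution: `op.Reshape(x, shape)` becomes `op.Reshape(x, shape, allowzero=True)`. In ONNX Reshape (opset 14 and later), the `allowzero` attribute controls how a 0 in the target shape is interpreted: with the default `allowzero=0`, the corresponding dimension is copied from the data tensor, while `allowzero=1` keeps the 0 as a genuine zero-sized dimension. The sketch below is not part of this commit; the `onnx_like_reshape` helper and the example shapes are illustrative only, mimicking the two behaviours with NumPy to show why the literal interpretation matters when the target shape comes from `op.Shape` of a tensor that may have a zero-sized dimension.

```python
import numpy as np


def onnx_like_reshape(data: np.ndarray, shape, allowzero: bool = False) -> np.ndarray:
    """Mimic ONNX Reshape: with allowzero=False (the default) a 0 in `shape`
    copies the corresponding dimension from `data`; with allowzero=True the 0
    is honored as a real zero-sized dimension (NumPy-style)."""
    target = list(shape)
    if not allowzero:
        target = [data.shape[i] if s == 0 else s for i, s in enumerate(target)]
    return data.reshape(target)


# An intermediate result whose leading dim differs from the original input,
# e.g. a flattened/normalized tensor inside one of the ops above.
intermediate = np.zeros((1, 0), dtype=np.float32)
# Target shape taken from the original input, which has a zero-sized dim.
target_shape = [0, 3]

print(onnx_like_reshape(intermediate, target_shape, allowzero=True).shape)  # (0, 3)
# With the default semantics the 0 is replaced by intermediate.shape[0] == 1,
# asking for a (1, 3) view of a zero-element tensor, which raises ValueError:
# onnx_like_reshape(intermediate, target_shape, allowzero=False)
```

Passing `allowzero=True` therefore makes these reshapes round-trip shapes exactly as computed by `op.Shape`, including zero-sized dimensions, instead of silently borrowing a dimension from the tensor being reshaped.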