Skip to content

Commit 8c0046f

Browse files
authored
[torchlib] Set allowzero=True on Reshape where appropriate (#2346)
When we reshape using a dynamically computed target shape, that shape can contain zero-sized dimensions. Setting allowzero=True on Reshape makes ONNX preserve those zeros as actual zero dimensions (instead of copying the corresponding input dimension), so this change accounts for those cases.
1 parent 5a8b9e6 commit 8c0046f

File tree

2 files changed

+15
-21
lines changed

2 files changed

+15
-21
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4390,7 +4390,7 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
43904390
reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape)
43914391

43924392
# Reshape and expand the index.
4393-
idx = op.Reshape(idx, reshape_list)
4393+
idx = op.Reshape(idx, reshape_list, allowzero=True)
43944394
idx = op.Expand(idx, values_shape)
43954395

43964396
# Flatten the index to 1D and unsqueeze to form a column vector.
@@ -4547,7 +4547,7 @@ def aten_instance_norm(
45474547
momentum=1.0 - momentum,
45484548
training_mode=False,
45494549
)
4550-
return op.Reshape(norm, op.Shape(input))
4550+
return op.Reshape(norm, op.Shape(input), allowzero=True)
45514551

45524552

45534553
def aten_int_repr(self: TensorType) -> TensorType:
@@ -6244,7 +6244,7 @@ def _aten_native_group_norm_onnx(
62446244
input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
62456245
)
62466246
# Reshape back to input's shape
6247-
norm = op.Reshape(norm, op.Shape(input))
6247+
norm = op.Reshape(norm, op.Shape(input), allowzero=True)
62486248
# Using the input weight and bias to do affine
62496249
# But need to unsqueeze to the target shape for broading cast easy
62506250
input_rank = Rank(input)
@@ -6693,7 +6693,7 @@ def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal:
66936693
)
66946694
depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD")
66956695
output_shape = op.Concat(batch_dims, op.Shape(depth_to_space)[1:], axis=0)
6696-
return op.Reshape(depth_to_space, output_shape)
6696+
return op.Reshape(depth_to_space, output_shape, allowzero=True)
66976697

66986698

66996699
@torch_op("aten::pixel_unshuffle")
@@ -6709,7 +6709,7 @@ def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal:
67096709
)
67106710
space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor)
67116711
output_shape = op.Concat(batch_dims, op.Shape(space_to_depth)[1:], axis=0)
6712-
return op.Reshape(space_to_depth, output_shape)
6712+
return op.Reshape(space_to_depth, output_shape, allowzero=True)
67136713

67146714

67156715
def aten_poisson(self: TensorType, generator: Optional[str] = None) -> TensorType:
@@ -8390,7 +8390,7 @@ def aten_tile(self: TTensor, dims: INT64) -> TTensor:
83908390
exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d)
83918391
self_shape = op.Shape(self)
83928392
self_final_shape = op.Concat(exapnd_ones, self_shape, axis=0)
8393-
self = op.Reshape(self, self_final_shape)
8393+
self = op.Reshape(self, self_final_shape, allowzero=True)
83948394

83958395
return op.Tile(self, dims)
83968396

@@ -8630,7 +8630,7 @@ def aten_unflatten(self: TReal, dim: int, sizes: Sequence[INT64]):
86308630
final_shape = op.Concat(head_part_rank, *sizes, axis=0)
86318631
else:
86328632
final_shape = op.Concat(head_part_rank, *sizes, tail_part_rank, axis=0)
8633-
return op.Reshape(self, final_shape)
8633+
return op.Reshape(self, final_shape, allowzero=True)
86348634

86358635

86368636
@torch_op("aten::unfold", trace_only=True)
@@ -8706,11 +8706,11 @@ def aten__unique(
87068706
unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
87078707
input_size = op.Shape(self)
87088708
if return_inverse:
8709-
inverse_indices = op.Reshape(inverse_indices, input_size)
8709+
inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
87108710
else:
87118711
input_numel = op.ReduceProd(input_size, keepdims=False)
87128712
if input_numel == 0:
8713-
inverse_indices = op.Reshape(inverse_indices, input_size)
8713+
inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
87148714
else:
87158715
inverse_indices = op.ConstantOfShape([0])
87168716
inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
@@ -8729,11 +8729,11 @@ def aten__unique2(
87298729
unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=True)
87308730
input_size = op.Shape(self)
87318731
if return_inverse:
8732-
inverse_indices = op.Reshape(inverse_indices, input_size)
8732+
inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
87338733
else:
87348734
input_numel = op.ReduceProd(input_size, keepdims=False)
87358735
if input_numel == 0:
8736-
inverse_indices = op.Reshape(inverse_indices, input_size)
8736+
inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)
87378737
else:
87388738
inverse_indices = op.ConstantOfShape([0])
87398739
inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
@@ -9019,7 +9019,7 @@ def aten_view(self: TTensor, size: IntType) -> TTensor:
90199019
"""view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
90209020

90219021
size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
9022-
return op.Reshape(self, size)
9022+
return op.Reshape(self, size, allowzero=True)
90239023

90249024

90259025
@torch_op(("aten::view", "aten::_unsafe_view"), complex=True)
@@ -9028,15 +9028,15 @@ def aten_view_complex(self: TTensor, size: IntType) -> TTensor:
90289028

90299029
size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
90309030
complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0)
9031-
return op.Reshape(self, complex_size)
9031+
return op.Reshape(self, complex_size, allowzero=True)
90329032

90339033

90349034
@torch_op("aten::view_as")
90359035
def aten_view_as(self: TTensor, other: TTensor2) -> TTensor:
90369036
"""view_as(Tensor(a) self, Tensor other) -> Tensor(a)"""
90379037

90389038
size = op.Shape(other)
9039-
return op.Reshape(self, size)
9039+
return op.Reshape(self, size, allowzero=True)
90409040

90419041

90429042
@torch_op("aten::view_as_complex", trace_only=True)

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,13 +1457,7 @@ def _where_input_wrangler(
14571457
dtypes=(torch.bool,),
14581458
reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905",
14591459
),
1460-
TorchLibOpInfo(
1461-
"unflatten",
1462-
core_ops.aten_unflatten,
1463-
).xfail(
1464-
matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
1465-
reason="fixme: Logic not implemented for size 0 inputs in op.Reshape",
1466-
),
1460+
TorchLibOpInfo("unflatten", core_ops.aten_unflatten),
14671461
TorchLibOpInfo("unfold", core_ops.aten_unfold),
14681462
TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold),
14691463
TorchLibOpInfo("unsqueeze", core_ops.aten_unsqueeze),

0 commit comments

Comments
 (0)