Skip to content

[torchlib] Set allowzero=True on Reshape where appropriate #2346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 28, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4390,7 +4390,7 @@
reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape)

# Reshape and expand the index.
idx = op.Reshape(idx, reshape_list)
idx = op.Reshape(idx, reshape_list, allowzero=True)

Check warning on line 4393 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4393

Added line #L4393 was not covered by tests
idx = op.Expand(idx, values_shape)

# Flatten the index to 1D and unsqueeze to form a column vector.
Expand Down Expand Up @@ -4531,6 +4531,7 @@
bn_input = op.Reshape(
input,
op.Concat(op.Constant(value_ints=[1, -1]), op.Shape(input, start=2), axis=0),
allowzero=True
)
weight = op.Tile(weight, batch_size)
bias = op.Tile(bias, batch_size)
Expand All @@ -4547,7 +4548,7 @@
momentum=1.0 - momentum,
training_mode=False,
)
return op.Reshape(norm, op.Shape(input))
return op.Reshape(norm, op.Shape(input), allowzero=True)

Check warning on line 4551 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4551

Added line #L4551 was not covered by tests


def aten_int_repr(self: TensorType) -> TensorType:
Expand Down Expand Up @@ -6237,14 +6238,14 @@
group_tensor = op.Reshape(group, neg_1)
# A 0 in the shape list normally keeps that dimension's value unchanged; InstanceNorm needs shape [0, group, -1].
# NOTE(review): with allowzero=True added below, a literal 0 in the shape means a zero-size dimension
# instead of "copy the input dimension" — this Reshape should NOT set allowzero=True, or the 0 must be
# replaced with the actual batch size. Confirm against the ONNX Reshape spec.
shape_input = op.Concat(op.Constant(value_ints=[0]), group_tensor, neg_1, axis=0)
input_reshaped = op.Reshape(input, shape_input)
input_reshaped = op.Reshape(input, shape_input, allowzero=True)

Check warning on line 6241 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L6241

Added line #L6241 was not covered by tests
weight_inst_norm = op.Expand(op.CastLike(1.0, input), group_tensor)
bias_inst_norm = op.Expand(op.CastLike(0.0, input), group_tensor)
norm = op.InstanceNormalization(
input_reshaped, weight_inst_norm, bias_inst_norm, epsilon=eps
)
# Reshape back to input's shape
norm = op.Reshape(norm, op.Shape(input))
norm = op.Reshape(norm, op.Shape(input), allowzero=True)

Check warning on line 6248 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L6248

Added line #L6248 was not covered by tests
# Using the input weight and bias to do affine
# But need to unsqueeze to the target shape for easy broadcasting
input_rank = Rank(input)
Expand All @@ -6259,7 +6260,7 @@
# The returned shape for mean and vstd should be [N, group, -1]
N = op.Shape(input, start=0, end=1)
shape_N_group_neg1 = op.Concat(N, group_tensor, neg_1, axis=0)
input_N_group_neg1 = op.Reshape(input, shape_N_group_neg1)
input_N_group_neg1 = op.Reshape(input, shape_N_group_neg1, allowzero=True)

Check warning on line 6263 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L6263

Added line #L6263 was not covered by tests
# The output size is [N, group], so dims = [2]
axes = op.Constant(value_ints=[2])
# Get mean which size is [N, group, 1], for broadcasting
Expand Down Expand Up @@ -6693,7 +6694,7 @@
)
depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD")
output_shape = op.Concat(batch_dims, op.Shape(depth_to_space)[1:], axis=0)
return op.Reshape(depth_to_space, output_shape)
return op.Reshape(depth_to_space, output_shape, allowzero=True)

Check warning on line 6697 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L6697

Added line #L6697 was not covered by tests


@torch_op("aten::pixel_unshuffle")
Expand All @@ -6709,7 +6710,7 @@
)
space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor)
output_shape = op.Concat(batch_dims, op.Shape(space_to_depth)[1:], axis=0)
return op.Reshape(space_to_depth, output_shape)
return op.Reshape(space_to_depth, output_shape, allowzero=True)

Check warning on line 6713 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L6713

Added line #L6713 was not covered by tests


def aten_poisson(self: TensorType, generator: Optional[str] = None) -> TensorType:
Expand Down Expand Up @@ -8390,7 +8391,7 @@
exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d)
self_shape = op.Shape(self)
self_final_shape = op.Concat(exapnd_ones, self_shape, axis=0)
self = op.Reshape(self, self_final_shape)
self = op.Reshape(self, self_final_shape, allowzero=True)

Check warning on line 8394 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8394

Added line #L8394 was not covered by tests

return op.Tile(self, dims)

Expand Down Expand Up @@ -8630,7 +8631,7 @@
final_shape = op.Concat(head_part_rank, *sizes, axis=0)
else:
final_shape = op.Concat(head_part_rank, *sizes, tail_part_rank, axis=0)
return op.Reshape(self, final_shape)
return op.Reshape(self, final_shape, allowzero=True)

Check warning on line 8634 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8634

Added line #L8634 was not covered by tests


@torch_op("aten::unfold", trace_only=True)
Expand Down Expand Up @@ -8706,11 +8707,11 @@
unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
input_size = op.Shape(self)
if return_inverse:
inverse_indices = op.Reshape(inverse_indices, input_size)
inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)

Check warning on line 8710 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8710

Added line #L8710 was not covered by tests
else:
input_numel = op.ReduceProd(input_size, keepdims=False)
if input_numel == 0:
inverse_indices = op.Reshape(inverse_indices, input_size)
inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)

Check warning on line 8714 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8714

Added line #L8714 was not covered by tests
else:
inverse_indices = op.ConstantOfShape([0])
inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
Expand All @@ -8729,11 +8730,11 @@
unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=True)
input_size = op.Shape(self)
if return_inverse:
inverse_indices = op.Reshape(inverse_indices, input_size)
inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)

Check warning on line 8733 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8733

Added line #L8733 was not covered by tests
else:
input_numel = op.ReduceProd(input_size, keepdims=False)
if input_numel == 0:
inverse_indices = op.Reshape(inverse_indices, input_size)
inverse_indices = op.Reshape(inverse_indices, input_size, allowzero=True)

Check warning on line 8737 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8737

Added line #L8737 was not covered by tests
else:
inverse_indices = op.ConstantOfShape([0])
inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
Expand Down Expand Up @@ -9019,7 +9020,7 @@
"""view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""

size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
return op.Reshape(self, size)
return op.Reshape(self, size, allowzero=True)

Check warning on line 9023 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L9023

Added line #L9023 was not covered by tests


@torch_op(("aten::view", "aten::_unsafe_view"), complex=True)
Expand All @@ -9028,15 +9029,15 @@

size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0)
return op.Reshape(self, complex_size)
return op.Reshape(self, complex_size, allowzero=True)

Check warning on line 9032 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L9032

Added line #L9032 was not covered by tests


@torch_op("aten::view_as")
def aten_view_as(self: TTensor, other: TTensor2) -> TTensor:
    """view_as(Tensor(a) self, Tensor other) -> Tensor(a)

    Reshape `self` to the runtime shape of `other`. allowzero=True makes a
    zero in the target shape produce a genuinely empty dimension rather than
    copying the corresponding input dimension, matching PyTorch semantics.
    """
    # Take other's shape at runtime and reshape self to it in one step.
    return op.Reshape(self, op.Shape(other), allowzero=True)

Check warning on line 9040 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L9040

Added line #L9040 was not covered by tests


@torch_op("aten::view_as_complex", trace_only=True)
Expand Down
Loading