Commit 6bdfcfd

[torchlib] Fix calls to Unsqueeze to provide correct 1d axes (#2186)
Discovered in onnx/onnx#6886 (comment): the `axes` input in calls to Unsqueeze is sometimes 0-d. This is incorrect according to the ONNX spec, which requires `axes` to be a 1-D tensor. This PR fixes the instances I could find.
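For context, a minimal sketch of what a spec-conformant Unsqueeze looks like at the ONNX level (using only the public `onnx` helper/checker APIs; the graph and tensor names here are illustrative, not part of this commit). Since opset 13, `axes` is an input rather than an attribute, and it must be a 1-D INT64 tensor:

import numpy as np
import onnx
from onnx import helper, numpy_helper

# A single axis is encoded as the 1-D tensor [1], never as a scalar 1.
axes = numpy_helper.from_array(np.array([1], dtype=np.int64), name="axes")

node = helper.make_node("Unsqueeze", inputs=["x", "axes"], outputs=["y"])
graph = helper.make_graph(
    [node],
    "unsqueeze_1d_axes",
    inputs=[helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [3])],
    outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [3, 1])],
    initializer=[axes],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)  # passes with 1-D axes; a 0-D axes tensor violates the spec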
1 parent 005568a commit 6bdfcfd

File tree

3 files changed: +18 -20 lines

onnxscript/function_libs/torch_lib/ops/core.py
onnxscript/function_libs/torch_lib/ops/nn.py
onnxscript/function_libs/torch_lib/ops/special.py

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 10 additions & 12 deletions
@@ -2991,8 +2991,8 @@ def _aten_embedding_bag_onnx(
     indices_1d = op.Reshape(indices, neg_1)
     # Get weight out according to indices_1d,
     new_weight = op.Gather(weight, indices_1d)
-    # This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
-    new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=1))
+    # This happens after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
+    new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=[1]))
     weight_dim_1 = op.Reshape(op.Shape(weight, start=1), neg_1)
     indices_size = op.Shape(indices_1d)

@@ -3131,8 +3131,8 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
     # Get weight out according to indices,
     # e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
     indices_weight = op.Gather(weight, indices)
-    # This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
-    indices_weight = op.Mul(indices_weight, op.Unsqueeze(per_sample_weights, axes=1))
+    # This happens after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
+    indices_weight = op.Mul(indices_weight, op.Unsqueeze(per_sample_weights, axes=[1]))
 
     # The element in sequence must be FLOAT32 dtype due to ORT bug
     indices_weight = op.Cast(indices_weight, to=FLOAT.dtype)
@@ -4145,7 +4145,6 @@ def _shape_of_broadcast_tensors(*args: TensorType) -> INT64:
     return op.Shape(broadcasted)
 
 
-@torch_op("aten::index.Tensor", private=True, trace_only=True)
 def _aten_index_onnx(
     self: TensorType,
     indices: Sequence[Optional[INT64]],
@@ -4173,7 +4172,7 @@ def _aten_index_onnx(
     not_none_indices = [idx for idx in indices if idx is not None]
     broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices)
     final_index = op.Concat(
-        *(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices),
+        *(op.Unsqueeze(op.Expand(idx, broadcast_shape), [-1]) for idx in not_none_indices),
         axis=-1,
     )

@@ -7706,13 +7705,13 @@ def aten_select_backward(
     raise NotImplementedError()
 
 
-@torch_op("aten::select_scatter")
+@torch_op("aten::select_scatter", trace_only=True)
 def aten_select_scatter(self: TensorType, src: TensorType, dim: int, index: int) -> TensorType:
     """select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor"""
 
     # Change src rank to self rank according to dim
     # e.g. if self is [2,3,4], src is [2,4], dim=1, then update is [2,1,4]
-    update = op.Unsqueeze(src, axes=dim)
+    update = op.Unsqueeze(src, axes=[dim])
     # Change index rank to the same as 'update' [2,1,4]
     indices = op.Expand(index, op.Shape(update))
     return op.ScatterElements(self, indices, update, axis=dim, reduction="none")
@@ -7880,7 +7879,7 @@ def aten_slice_scatter(
         zero,
         op.Unsqueeze(step, zero),
     )
-    index_base = op.Unsqueeze(index_base, -1)
+    index_base = op.Unsqueeze(index_base, [-1])
 
     # Use trace only to construct the perm attribute in Transpose
     dims = None
@@ -8623,7 +8622,7 @@ def aten_unfold(self: TTensor, dimension: int, size: int, step: int) -> TTensor:
 
     self_rank = len(self.shape)
     if self_rank == 0:
-        result = op.Unsqueeze(self, 0)
+        result = op.Unsqueeze(self, [0])
     else:
         # Handle negative dimension
         if dimension < 0:
@@ -8792,8 +8791,7 @@ def aten_unsafe_split_with_sizes(
 def aten_unsqueeze(self: TTensor, dim: int) -> TTensor:
     """unsqueeze(Tensor(a) self, int dim) -> Tensor(a)"""
 
-    dim = op.Cast(dim, to=INT64.dtype)
-    return op.Unsqueeze(self, dim)
+    return op.Unsqueeze(self, [dim])
 
 
 def aten_unsqueeze_copy(self: TensorType, dim: int) -> TensorType:

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 7 additions & 7 deletions
@@ -1002,7 +1002,7 @@ def _aten_max_pool_onnx(
 ) -> TFloatOrUInt8:
     self_rank_is_unbatched_rank = Rank(self) == unbatched_rank
     if self_rank_is_unbatched_rank:  # C,H,W -> N,C,H,W and N=1
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+        self = op.Unsqueeze(self, [0])
 
     pool_result, _ = op.MaxPool(
         self,
@@ -1014,7 +1014,7 @@ def _aten_max_pool_onnx(
     )
 
     if self_rank_is_unbatched_rank:
-        pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
+        pool_result = op.Squeeze(pool_result, [0])
 
     return pool_result

@@ -1136,7 +1136,7 @@ def _aten_max_pool_with_indices_onnx(
 ) -> Tuple[TFloatOrUInt8, INT64]:
     self_rank_is_unbatched_rank = Rank(self) == unbatched_rank
     if self_rank_is_unbatched_rank:
-        self = op.Unsqueeze(self, axes=0)
+        self = op.Unsqueeze(self, axes=[0])
 
     pool_result, indices = op.MaxPool(
         self,
@@ -1191,8 +1191,8 @@ def _aten_max_pool_with_indices_onnx(
     indices = op.Sub(indices, delta)
 
     if self_rank_is_unbatched_rank:
-        pool_result = op.Squeeze(pool_result, op.Constant(value_ints=[0]))
-        indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
+        pool_result = op.Squeeze(pool_result, [0])
+        indices = op.Squeeze(indices, [0])
 
     return (pool_result, indices)

@@ -1365,11 +1365,11 @@ def aten_nll_loss(
 
     self_rank_is_1 = Rank(self) == 1
     if self_rank_is_1:  # self rank should be at least 2
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+        self = op.Unsqueeze(self, [0])
 
     rank_target = Rank(target)
     if rank_target == 0:  # target rank should be at least 1
-        target = op.Unsqueeze(target, op.Constant(value_ints=[0]))
+        target = op.Unsqueeze(target, [0])
 
     if reduction == 0:
         reduction_str = "none"

onnxscript/function_libs/torch_lib/ops/special.py

Lines changed: 1 addition & 1 deletion
@@ -219,7 +219,7 @@ def aten_special_log_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
 
     self_is_scalar = len(self.shape) == 0
     if self_is_scalar:
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+        self = op.Unsqueeze(self, [0])
     result = op.LogSoftmax(self, axis=dim)
     if dtype != -1:
         result = op.Cast(result, to=dtype)
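As a quick sanity check of the fixed call shape, a sketch using onnxscript's eager mode (the function name and dimensions below are illustrative, not part of this commit; eager evaluation assumes onnxruntime is installed):

import numpy as np
from onnxscript import FLOAT, script
from onnxscript import opset18 as op

@script()
def unsqueeze_axis1(x: FLOAT["N", "M"]) -> FLOAT["N", 1, "M"]:
    # `axes` passed as a 1-D constant -- the pattern this commit enforces
    return op.Unsqueeze(x, [1])

x = np.arange(6, dtype=np.float32).reshape(2, 3)
print(unsqueeze_axis1(x).shape)  # eager evaluation -> (2, 1, 3)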
