Skip to content

Commit 10f9a1f

Browse files
authored
Clean up Shape calls | chore(torchlib) (#1163)
Update calls to `Shape` to use the `start` and `end` arguments to simplify the graph and avoid `Gather` nodes.
1 parent 804ed01 commit 10f9a1f

File tree

1 file changed

+4
-7
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+4
-7
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4058,12 +4058,9 @@ def aten_index_put_bool(
40584058
# change array([F,F,T,F,F]) to array([2])
40594059
index = op.ArgMax(index_int) # assume index only have 1 True
40604060
# change array([2]) to array([2,2,2,2,2])
4061-
self_dim_1 = op.Gather(op.Shape(self), 1)
4062-
index_dim_0 = op.Gather(op.Shape(index), 0)
4063-
neg_1 = op.Constant(value_ints=[-1])
4064-
shape = op.Concat(
4065-
op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0
4066-
)
4061+
self_dim_1 = op.Shape(self, start=1, end=2)
4062+
index_dim_0 = op.Shape(index, start=0, end=1)
4063+
shape = op.Concat(self_dim_1, index_dim_0, axis=0)
40674064
new_ind = op.Expand(index, shape)
40684065
new_ind_t = op.Transpose(new_ind)
40694066

@@ -7512,7 +7509,7 @@ def _center_window_around_zeros_if_needed(
75127509
window: TFloatOrBFloat16, n_fft: int
75137510
) -> TFloatOrBFloat16:
75147511
# first dimension
7515-
n_win = op.Gather(op.Shape(window), 0)
7512+
n_win = op.Shape(window, start=0, end=1)
75167513
# Center window around zeros if needed (required by ONNX's STFT)
75177514
if n_win < n_fft:
75187515
left = (n_fft - n_win) / 2

0 commit comments

Comments
 (0)