Support dynamic shapes for aten_unfold #2407

Merged 2 commits on Jun 20, 2025
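For context (not part of the PR), `torch.Tensor.unfold(dimension, size, step)` extracts sliding windows of length `size`, taken every `step` elements along `dimension`, with each window placed in a new trailing axis. A minimal pure-Python sketch of the 1-D case (`unfold_1d` is a hypothetical helper, not from this repository):

```python
# Hypothetical helper illustrating unfold semantics for a 1-D sequence,
# assuming positive `size` and `step`.
def unfold_1d(seq, size, step):
    # Windows start at 0, step, 2*step, ... while a full window still fits.
    starts = range(0, len(seq) - size + 1, step)
    return [list(seq[lo:lo + size]) for lo in starts]

# Each window of length `size` becomes the last dimension of the result.
windows = unfold_1d([1, 2, 3, 4, 5, 6, 7], size=3, step=2)
# windows == [[1, 2, 3], [3, 4, 5], [5, 6, 7]]
```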
49 changes: 28 additions & 21 deletions in onnxscript/function_libs/torch_lib/ops/core.py

@@ -8655,29 +8655,36 @@
     # Handle negative dimension
     if dimension < 0:
         dimension = dimension + self_rank
-    dim_size = self.shape[dimension]
-
-    low_indices = range(0, dim_size, step)
-    hi_indices = range(size, dim_size + 1, step)
-    stack = [
-        op.Slice(
-            self,
-            op.Constant(value_ints=[low]),
-            op.Constant(value_ints=[hi]),
-            op.Constant(value_ints=[dimension]),
-        )
-        for low, hi in zip(low_indices, hi_indices)
-    ]
+    input_shape = op.Shape(self)
+    dim_size = op.Gather(input_shape, op.Constant(value_ints=[dimension]))

+    # Create indices for each window
+    window_starts = op.Range(0, op.Sub(dim_size, size - 1), step)

+    # Create the base indices for one window
+    window_indices = list(range(size))

+    # Broadcast to create all indices
+    starts_expanded = op.Unsqueeze(window_starts, [1])  # [num_windows, 1]
+    indices_expanded = op.Unsqueeze(window_indices, [0])  # [1, size]
+    all_indices = op.Add(starts_expanded, indices_expanded)  # [num_windows, size]

+    # Gather along the specified dimension
+    result = op.Gather(self, all_indices, axis=dimension)

-    # perm need to be list[int], so have to be generated in trace_only mode
-    perm = list(range(self_rank))
-    # from [0,1,2,3,4] -> [0,1,3,4,2] when dimension=1
-    perm.append(perm.pop(dimension))
-    unsqueeze = [
-        op.Unsqueeze(op.Transpose(t, perm=perm), op.Constant(value_ints=[dimension]))
-        for t in stack
-    ]
-    result = op.Concat(*unsqueeze, axis=dimension)
+    # The result shape is now [..., num_windows, size, ...] with num_windows at position 'dimension'.
+    # We need to move the size dimension to the end:
+    # Current shape: [..., num_windows, size, ...]
+    # Target shape: [..., num_windows, ..., size]
+
+    # Move the size dimension (at position dimension+1) to the end
+    perm = list(range(self_rank + 1))
+    perm.append(perm.pop(dimension + 1))

+    result = op.Transpose(result, perm=perm)

     return result
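The new implementation computes gather indices instead of emitting one `Slice` per window, so `dim_size` can be a runtime value rather than a Python int. A NumPy sketch of the same Range + broadcast-Add + Gather + Transpose sequence (`unfold_via_gather` is illustrative, not from the PR):

```python
import numpy as np

# Illustrative NumPy stand-in for the ONNX ops used in the new code path;
# the function name and example values are assumptions, not from the PR.
def unfold_via_gather(x, dimension, size, step):
    dim_size = x.shape[dimension]
    # op.Range(0, dim_size - (size - 1), step): one start per window
    window_starts = np.arange(0, dim_size - size + 1, step)
    # Broadcast add: [num_windows, 1] + [1, size] -> [num_windows, size]
    all_indices = window_starts[:, None] + np.arange(size)[None, :]
    # op.Gather(self, all_indices, axis=dimension)
    result = np.take(x, all_indices, axis=dimension)
    # Final Transpose: move the size axis (now at dimension + 1) to the end
    return np.moveaxis(result, dimension + 1, -1)

x = np.arange(24).reshape(2, 12)
out = unfold_via_gather(x, dimension=1, size=4, step=3)
# out.shape == (2, 3, 4): 3 windows of length 4 along dim 1
```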


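The final `Transpose` uses a permutation built at trace time (it must be a `list[int]`, so it cannot come from graph ops). A small sketch of how `perm.append(perm.pop(dimension + 1))` moves the size axis to the end, with example values for `self_rank` and `dimension` chosen here for illustration:

```python
# After Gather the result has rank self_rank + 1: the window axis sits at
# `dimension` and the size axis at `dimension + 1`. Popping index
# dimension + 1 and appending it moves the size axis to the end.
self_rank, dimension = 4, 1          # example values, not from the PR
perm = list(range(self_rank + 1))    # [0, 1, 2, 3, 4]
perm.append(perm.pop(dimension + 1))
# perm == [0, 1, 3, 4, 2]
```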