Skip to content
Closed
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions onnxscript/function_libs/torch_aten/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5740,10 +5740,43 @@ def aten_type_as(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


def aten_unfold(self: TensorType, dimension: int, size: int, step: int) -> TensorType:
@torch_op("aten::unfold", trace_only=True) # FIXME: Seems ast.For was not supported
def aten_unfold(self: TTensor, dimension: int, size: int, step: int) -> TTensor:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what the op is supposed to do. Is there a description I can read somewhere? Looks like it is a variant of a slice-like op with params (size, step) in (dimension), followed by some form of transpose?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The unfold do below things:
x = [1,2,3,4,5,6],
unfold it to [1,2],[2,3],[3,4]... when stride=2,step=1
unfold it to [1,2],[3,4],[5,6] when stride=2,step=2
it is a Core aten op.

"""unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)"""

raise NotImplementedError()
self_shape = op.Shape(self)
self_rank = op.Size(self_shape)
if self_rank == 0:
result = op.Unsqueeze(self, 0)
else:
dims = op.Constant(value_ints=[dimension])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may need to use Expand because Constant requires a compile time constant

dim_size = op.Gather(self_shape, dims, axis=0)
# target = ((i-size)/step + 1) * step != i-size+step
target_end = op.Squeeze(((dim_size - size) / step + 1) * step)
seq_result = op.SequenceEmpty()

for i in range(0, target_end, step):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gramalingam Do you have recommendation on how this loop should be created?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if not set trace_only=True, will get this error:

onnxscript\analysis.py:87: in defs
    raise ValueError(f"Unsupported statement type {type(stmt)!r}.")
E   ValueError: Unsupported statement type <class 'ast.For'>.

starts = op.Constant(value_ints=[i])
ends = starts + size
slice_result = op.Slice(self, starts, ends, dims)
seq_result = op.SequenceInsert(seq_result, slice_result)
concat_result = op.ConcatFromSequence(seq_result, axis=dimension, new_axis=1)

# Generate permute of the new shape
# Below logic equal to:
# perm = [0,1,2,3,4]
# perm.append(perm.pop(dimension+1))

rank_result = op.Squeeze(op.Size(op.Shape(concat_result)))
dim_plus_1 = dimension + 1
dim_plus_2 = dimension + 2
perm_prefix = op.Range(0, dim_plus_1, 1)
perm_suffix = op.Range(dim_plus_2, rank_result, 1)
per_dim = op.Range(dim_plus_1, dim_plus_2, 1)
perm = op.Concat(perm_prefix, perm_suffix, per_dim, axis=0)

result = op.Transpose(concat_result, perm=perm)
return result


def aten_unfold_backward(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ def _where_input_wrangler(
"slice": core_ops.aten_slice,
"sum": (core_ops.aten_sum_dim_IntList, _sum_input_wrangler),
"transpose": core_ops.aten_transpose,
"unfold": core_ops.aten_unfold,
"zeros_like": core_ops.aten_zeros_like,
}

Expand Down