-
Notifications
You must be signed in to change notification settings - Fork 107
feat(atenlib): add op(unfold) #534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
dd719c0
0ed5fa8
5535660
1311962
e8e34c7
068c834
df3c94c
60d5604
b3c41a3
535d5bc
97399ae
346cc9e
51fbe4c
0084cb1
838700d
9c4a385
2fe7160
25657fe
8ada8bf
b77b2a3
eeac06f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5740,10 +5740,43 @@ def aten_type_as(self: TensorType, other: TensorType) -> TensorType: | |
| raise NotImplementedError() | ||
|
|
||
|
|
||
| def aten_unfold(self: TensorType, dimension: int, size: int, step: int) -> TensorType: | ||
| @torch_op("aten::unfold", trace_only=True) # FIXME: Seems ast.For was not supported | ||
| def aten_unfold(self: TTensor, dimension: int, size: int, step: int) -> TTensor: | ||
| """unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)""" | ||
|
|
||
| raise NotImplementedError() | ||
| self_shape = op.Shape(self) | ||
| self_rank = op.Size(self_shape) | ||
| if self_rank == 0: | ||
| result = op.Unsqueeze(self, 0) | ||
| else: | ||
| dims = op.Constant(value_ints=[dimension]) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We may need to use Expand because Constant requires a compile time constant |
||
| dim_size = op.Gather(self_shape, dims, axis=0) | ||
| # target = ((i-size)/step + 1) * step != i-size+step | ||
| target_end = op.Squeeze(((dim_size - size) / step + 1) * step) | ||
| seq_result = op.SequenceEmpty() | ||
|
|
||
| for i in range(0, target_end, step): | ||
|
||
| starts = op.Constant(value_ints=[i]) | ||
| ends = starts + size | ||
| slice_result = op.Slice(self, starts, ends, dims) | ||
| seq_result = op.SequenceInsert(seq_result, slice_result) | ||
| concat_result = op.ConcatFromSequence(seq_result, axis=dimension, new_axis=1) | ||
|
|
||
| # Generate permute of the new shape | ||
| # Below logic equal to: | ||
| # perm = [0,1,2,3,4] | ||
| # perm.append(perm.pop(dimension+1)) | ||
|
|
||
| rank_result = op.Squeeze(op.Size(op.Shape(concat_result))) | ||
| dim_plus_1 = dimension + 1 | ||
| dim_plus_2 = dimension + 2 | ||
| perm_prefix = op.Range(0, dim_plus_1, 1) | ||
| perm_suffix = op.Range(dim_plus_2, rank_result, 1) | ||
| per_dim = op.Range(dim_plus_1, dim_plus_2, 1) | ||
| perm = op.Concat(perm_prefix, perm_suffix, per_dim, axis=0) | ||
|
|
||
| result = op.Transpose(concat_result, perm=perm) | ||
| return result | ||
|
|
||
|
|
||
| def aten_unfold_backward( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand what the op is supposed to do. Is there a description I can read somewhere? Looks like it is a variant of a slice-like op with params (size, step) in (dimension), followed by some form of transpose?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The unfold do below things:
x = [1,2,3,4,5,6],
unfold it to [1,2],[2,3],[3,4]... when stride=2,step=1
unfold it to [1,2],[3,4],[5,6] when stride=2,step=2
it is a Core aten op.