Add op(unfold) | feat(torchlib) #893


Merged
merged 13 commits on Jul 19, 2023
41 changes: 39 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -6718,10 +6718,47 @@ def aten_unflatten(self: TReal, dim: INT64, sizes: INT64):
    return op.Reshape(self, final_shape)


def aten_unfold(self: TensorType, dimension: int, size: int, step: int) -> TensorType:
@torch_op("aten::unfold", trace_only=True)
def aten_unfold(self: TTensor, dimension: int, size: int, step: int) -> TTensor:
    """unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)"""

    raise NotImplementedError()
    self_rank = len(self.shape)
    if self_rank == 0:
        result = op.Unsqueeze(self, 0)
    else:
        dim_size = self.shape[dimension]
        # number of sliding windows produced along `dimension`
        target_end = (dim_size - size) // step + 1
        if target_end > 1:  # the rank of the final result will be self_rank + 1
            self_rank = self_rank + 1
        # perm needs to be a list[int], so it has to be generated in trace-only mode
        perm = list(range(self_rank))
        # e.g. [0, 1, 2, 3, 4] -> [0, 1, 3, 4, 2] when dimension=1
        perm.append(perm.pop(dimension + 1))
Collaborator:
Would be great to add a comment to explain this logic, thanks!

@xiaowuhu (Contributor Author), Jul 19, 2023:
added comment.

Collaborator:
Why do we do this though? What does it mean? Is there a reference logic?

        result = _aten_unfold_onnx(self, dimension, size, step, target_end, perm)
    return result
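
On the reviewer's question above: torch.Tensor.unfold places the window index at `dimension` and the window contents in a new last axis, whereas ConcatFromSequence(axis=dim, new_axis=1) stacks the windows at `dim` and leaves the contents at dim + 1. The final Transpose therefore moves axis dimension + 1 to the end, which is exactly what perm.append(perm.pop(dimension + 1)) encodes. A minimal PyTorch sketch of that equivalence (illustrative only, not part of the PR; shapes and names are made up):

import torch

x = torch.arange(24).reshape(2, 3, 4)
dimension, size, step = 1, 2, 1
target_end = (x.shape[dimension] - size) // step + 1  # (3 - 2) // 1 + 1 = 2 windows

expected = x.unfold(dimension, size, step)  # shape (2, 2, 4, 2)

# Mirror the ONNX plan: one Slice per window, stack at `dimension`
# (ConcatFromSequence with new_axis=1), then move axis dimension + 1 to the
# end, i.e. the Transpose whose perm pops index dimension + 1.
slices = [x.narrow(dimension, i * step, size) for i in range(target_end)]
stacked = torch.stack(slices, dim=dimension)  # shape (2, 2, 2, 4)
result = stacked.movedim(dimension + 1, -1)   # shape (2, 2, 4, 2)
assert torch.equal(result, expected)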


@torch_op("aten::unfold", private=True)
def _aten_unfold_onnx(
self: TTensor, dim: int, size: int, step: int, target_end: int, perm: Sequence[int]
) -> TTensor:
dims = op.Reshape(op.Constant(value_int=dim), op.Constant(value_ints=[-1]))
# FIXME: the dtype for this function cannot work, default to float
seq_result = op.SequenceEmpty()
i = op.Constant(value_ints=[0])
cond = i < target_end
while cond: # because for loop cannot work here, so use while loop
starts = i * step # starts is [0, step, step*2, step*3, ...]
ends = starts + size # ends is [0+size, step+size, step*2+size, step*3+size, ...]
slice_result = op.Slice(self, starts, ends, dims)
# sequence only support float32
slice_result_float32 = op.Cast(slice_result, to=FLOAT.dtype)
seq_result = op.SequenceInsert(seq_result, slice_result_float32)
i = i + 1
cond = i < target_end
concat_result = op.ConcatFromSequence(seq_result, axis=dim, new_axis=1)
result = op.Transpose(concat_result, perm=perm)
return op.CastLike(result, self)
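
A NumPy sketch (illustrative, not from the PR; shapes are made up) of the computation the loop above builds, including the float32 round-trip forced by SequenceEmpty:

import numpy as np

x = np.arange(24, dtype=np.int64).reshape(2, 3, 4)
dim, size, step, target_end = 1, 2, 1, 2

# Each Slice takes [i*step, i*step + size) along `dim`; the cast to float32
# mirrors the sequence's fixed element type.
seq = []
for i in range(target_end):
    starts = i * step
    window = np.take(x, np.arange(starts, starts + size), axis=dim)
    seq.append(window.astype(np.float32))
stacked = np.stack(seq, axis=dim)       # ConcatFromSequence(axis=dim, new_axis=1)
perm = list(range(stacked.ndim))
perm.append(perm.pop(dim + 1))          # [0, 1, 2, 3] -> [0, 1, 3, 2]
result = np.transpose(stacked, perm).astype(x.dtype)  # CastLike(result, self)
assert result.shape == (2, 2, 4, 2)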


def aten_unfold_backward(
3 changes: 2 additions & 1 deletion onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -430,7 +430,7 @@ def _where_input_wrangler(
TorchLibOpInfo("acos", core_ops.aten_acos),
TorchLibOpInfo("acosh", core_ops.aten_acosh),
TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
TorchLibOpInfo("addbmm", core_ops.aten_addbmm),
TorchLibOpInfo("addbmm", core_ops.aten_addbmm, tolerance={torch.float32: (2e-5, 2e-5)}),
TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
TorchLibOpInfo("addmm", core_ops.aten_addmm),
@@ -1166,6 +1166,7 @@ def _where_input_wrangler(
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
        reason="fixme: Logic not implemented for size 0 inputs in op.Reshape",
    ),
    TorchLibOpInfo("unfold", core_ops.aten_unfold, trace_only=True),
    TorchLibOpInfo("unsqueeze", core_ops.aten_unsqueeze),
    TorchLibOpInfo("view", core_ops.aten_view),
    TorchLibOpInfo("view_as", core_ops.aten_view_as),