Skip to content

Commit ed3df93

Browse files
authored
Add op(unfold) | feat(torchlib) (#893)
from: #534
1 parent 9383d1a commit ed3df93

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6718,10 +6718,47 @@ def aten_unflatten(self: TReal, dim: INT64, sizes: INT64):
67186718
return op.Reshape(self, final_shape)
67196719

67206720

@torch_op("aten::unfold", trace_only=True)
def aten_unfold(self: TTensor, dimension: int, size: int, step: int) -> TTensor:
    """unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)

    Returns a view-like tensor containing all windows of length `size` taken
    every `step` elements along axis `dimension`; the window elements end up
    in a new trailing dimension (PyTorch `Tensor.unfold` semantics).
    """

    self_rank = len(self.shape)
    if self_rank == 0:
        # A 0-d tensor unfolds into a 1-d tensor with a single element.
        result = op.Unsqueeze(self, 0)
    else:
        # Number of windows extracted along `dimension`.
        dim_size = self.shape[dimension]
        target_end = (dim_size - size) // step + 1
        if target_end > 1:  # the rank of the final result will be self_rank + 1
            self_rank = self_rank + 1
        # perm needs to be a list[int], so it has to be generated in trace_only mode
        perm = list(range(self_rank))
        # from [0,1,2,3,4] -> [0,1,3,4,2] when dimension=1
        perm.append(perm.pop(dimension + 1))
        result = _aten_unfold_onnx(self, dimension, size, step, target_end, perm)
    return result
@torch_op("aten::unfold", private=True)
def _aten_unfold_onnx(
    self: TTensor, dim: int, size: int, step: int, target_end: int, perm: Sequence[int]
) -> TTensor:
    """Slice `target_end` windows of length `size` (stride `step`) along axis
    `dim`, stack them on a new axis via ConcatFromSequence, then Transpose
    with `perm`.

    `target_end` and `perm` are precomputed by the trace-only wrapper
    `aten_unfold`, because `perm` must be a static list[int] at trace time.
    """
    dims = op.Reshape(op.Constant(value_int=dim), op.Constant(value_ints=[-1]))
    # FIXME: the dtype for this function cannot work, default to float
    seq_result = op.SequenceEmpty()
    i = op.Constant(value_ints=[0])
    cond = i < target_end
    while cond:  # a `for` loop cannot be traced here, so use a `while` loop
        starts = i * step  # starts is [0, step, step*2, step*3, ...]
        ends = starts + size  # ends is [0+size, step+size, step*2+size, step*3+size, ...]
        slice_result = op.Slice(self, starts, ends, dims)
        # the sequence only supports float32; cast back to the input dtype at the end
        slice_result_float32 = op.Cast(slice_result, to=FLOAT.dtype)
        seq_result = op.SequenceInsert(seq_result, slice_result_float32)
        i = i + 1
        cond = i < target_end
    concat_result = op.ConcatFromSequence(seq_result, axis=dim, new_axis=1)
    result = op.Transpose(concat_result, perm=perm)
    # Restore the original dtype that was lost by the float32 casts above.
    return op.CastLike(result, self)
67256762

67266763

67276764
def aten_unfold_backward(

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def _where_input_wrangler(
430430
TorchLibOpInfo("acos", core_ops.aten_acos),
431431
TorchLibOpInfo("acosh", core_ops.aten_acosh),
432432
TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
433-
TorchLibOpInfo("addbmm", core_ops.aten_addbmm),
433+
TorchLibOpInfo("addbmm", core_ops.aten_addbmm, tolerance={torch.float32: (2e-5, 2e-5)}),
434434
TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
435435
TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
436436
TorchLibOpInfo("addmm", core_ops.aten_addmm),
@@ -1166,6 +1166,7 @@ def _where_input_wrangler(
11661166
matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
11671167
reason="fixme: Logic not implemented for size 0 inputs in op.Reshape",
11681168
),
1169+
TorchLibOpInfo("unfold", core_ops.aten_unfold, trace_only=True),
11691170
TorchLibOpInfo("unsqueeze", core_ops.aten_unsqueeze),
11701171
TorchLibOpInfo("view", core_ops.aten_view),
11711172
TorchLibOpInfo("view_as", core_ops.aten_view_as),

0 commit comments

Comments
 (0)