diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 4a338ecc1d..555806e299 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -6718,10 +6718,47 @@ def aten_unflatten(self: TReal, dim: INT64, sizes: INT64):
     return op.Reshape(self, final_shape)
 
 
-def aten_unfold(self: TensorType, dimension: int, size: int, step: int) -> TensorType:
+@torch_op("aten::unfold", trace_only=True)
+def aten_unfold(self: TTensor, dimension: int, size: int, step: int) -> TTensor:
     """unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)"""
 
-    raise NotImplementedError()
+    self_rank = len(self.shape)
+    if self_rank == 0:
+        result = op.Unsqueeze(self, 0)
+    else:
+        dim_size = self.shape[dimension]
+        target_end = (dim_size - size) // step + 1
+        if target_end > 1:  # the rank of the final result will be self_rank + 1
+            self_rank = self_rank + 1
+        # perm needs to be a list[int], so it has to be generated in trace_only mode
+        perm = list(range(self_rank))
+        # e.g. [0, 1, 2, 3, 4] -> [0, 1, 3, 4, 2] when dimension=1
+        perm.append(perm.pop(dimension + 1))
+        result = _aten_unfold_onnx(self, dimension, size, step, target_end, perm)
+    return result
+
+
+@torch_op("aten::unfold", private=True)
+def _aten_unfold_onnx(
+    self: TTensor, dim: int, size: int, step: int, target_end: int, perm: Sequence[int]
+) -> TTensor:
+    dims = op.Reshape(op.Constant(value_int=dim), op.Constant(value_ints=[-1]))
+    # FIXME: the dtype attribute of SequenceEmpty does not work here, so it defaults to float
+    seq_result = op.SequenceEmpty()
+    i = op.Constant(value_ints=[0])
+    cond = i < target_end
+    while cond:  # a for loop cannot be used here, so use a while loop instead
+        starts = i * step  # starts is [0, step, step*2, step*3, ...]
+        ends = starts + size  # ends is [0+size, step+size, step*2+size, step*3+size, ...]
+        slice_result = op.Slice(self, starts, ends, dims)
+        # the sequence only supports float32
+        slice_result_float32 = op.Cast(slice_result, to=FLOAT.dtype)
+        seq_result = op.SequenceInsert(seq_result, slice_result_float32)
+        i = i + 1
+        cond = i < target_end
+    concat_result = op.ConcatFromSequence(seq_result, axis=dim, new_axis=1)
+    result = op.Transpose(concat_result, perm=perm)
+    return op.CastLike(result, self)
 
 
 def aten_unfold_backward(
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 4c7b027984..db7cd97da5 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -430,7 +430,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("acos", core_ops.aten_acos),
     TorchLibOpInfo("acosh", core_ops.aten_acosh),
     TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
-    TorchLibOpInfo("addbmm", core_ops.aten_addbmm),
+    TorchLibOpInfo("addbmm", core_ops.aten_addbmm, tolerance={torch.float32: (2e-5, 2e-5)}),
     TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
     TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
    TorchLibOpInfo("addmm", core_ops.aten_addmm),
@@ -1166,6 +1166,7 @@ def _where_input_wrangler(
         matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
         reason="fixme: Logic not implemented for size 0 inputs in op.Reshape",
     ),
+    TorchLibOpInfo("unfold", core_ops.aten_unfold, trace_only=True),
     TorchLibOpInfo("unsqueeze", core_ops.aten_unsqueeze),
     TorchLibOpInfo("view", core_ops.aten_view),
     TorchLibOpInfo("view_as", core_ops.aten_view_as),
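For reference, a minimal standalone sketch (not part of this patch) of the sliding-window semantics the new kernel reproduces: `torch.Tensor.unfold` takes windows of length `size` every `step` elements along the given dimension and appends the window axis last, which the Slice/ConcatFromSequence/Transpose sequence above mirrors with `starts = i*step`, `ends = starts + size`, and `target_end = (dim_size - size) // step + 1` windows.

# Illustrative check only; names here are local to this sketch.
import torch

x = torch.arange(8.0)
size, step = 3, 2
target_end = (x.shape[0] - size) // step + 1  # (8 - 3) // 2 + 1 == 3 windows
windows = torch.stack(
    [x[i * step : i * step + size] for i in range(target_end)]
)  # each window i spans [i*step, i*step + size), as in the while loop above
assert torch.equal(windows, x.unfold(0, size, step))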