Fix Op(unflatten) #2070

Merged
merged 1 commit on Feb 21, 2025
20 changes: 14 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -8433,16 +8433,16 @@
     return op.SplitToSequence(self, split_sizes, axis=dim, keepdims=False)


-@torch_op("aten::unflatten.int")
-def aten_unflatten(self: TReal, dim: INT64, sizes: INT64):
+@torch_op("aten::unflatten.int", trace_only=True)
+def aten_unflatten(self: TReal, dim: int, sizes: Sequence[INT64]):
     """unflatten(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)"""

     self_size = op.Shape(self)

     # PyTorch accepts negative dim as reversed counting
-    self_rank = op.Size(self_size)
-    dim = self_rank + dim
-    dim = dim % self_rank
+    self_rank = len(self.shape)
+    if dim < 0:
+        dim = self_rank + dim

     head_start_idx = op.Constant(value_ints=[0])
     head_end_idx = op.Reshape(dim, op.Constant(value_ints=[1]))
@@ -8452,8 +8452,16 @@
     tail_end_idx = op.Constant(value_ints=[_INT64_MAX])
     tail_part_rank = op.Slice(self_size, tail_start_idx, tail_end_idx)

-    final_shape = op.Concat(head_part_rank, sizes, tail_part_rank, axis=0)
+    sizes = [op.Reshape(size, op.Constant(value_ints=[1])) for size in sizes]

+    # corner case 1: head part is None
+    if dim == 0:
+        final_shape = op.Concat(*sizes, tail_part_rank, axis=0)
+    # corner case 2: tail part is None
+    elif dim == self_rank - 1:
+        final_shape = op.Concat(head_part_rank, *sizes, axis=0)
+    else:
+        final_shape = op.Concat(head_part_rank, *sizes, tail_part_rank, axis=0)
     return op.Reshape(self, final_shape)


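For context, a minimal standalone sketch (not part of the diff) of the shape behavior the rewritten op has to reproduce, calling torch.unflatten directly; the tensor shapes and dim values below are illustrative only:

import torch

x = torch.arange(24).reshape(2, 12)

# dim == 0: there is no head part, so the target shape is sizes + tail
# (the `if dim == 0` branch concatenates only *sizes and tail_part_rank).
print(torch.unflatten(x, 0, (1, 2)).shape)   # torch.Size([1, 2, 12])

# dim == rank - 1: there is no tail part, so the target shape is head + sizes
# (the `elif dim == self_rank - 1` branch concatenates head_part_rank and *sizes).
print(torch.unflatten(x, -1, (3, 4)).shape)  # torch.Size([2, 3, 4])

y = torch.arange(24).reshape(2, 6, 2)

# General case: head + sizes + tail, matching the final `else` branch.
print(torch.unflatten(y, 1, (2, 3)).shape)   # torch.Size([2, 2, 3, 2])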
14 changes: 1 addition & 13 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -429,13 +429,6 @@ def _sum_input_wrangler(
     return args, kwargs


-def _unflatten_input_wrangler(
-    args: list[Any], kwargs: dict[str, Any]
-) -> tuple[list[Any], dict[str, Any]]:
-    args[1] = np.array(args[1], dtype=np.int64)
-    return args, kwargs
-
-
 def _where_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
@@ -1471,14 +1464,9 @@ def _where_input_wrangler(
     TorchLibOpInfo(
         "unflatten",
         core_ops.aten_unflatten,
-        input_wrangler=_unflatten_input_wrangler,
-    )
-    .xfail(
+    ).xfail(
         matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
         reason="fixme: Logic not implemented for size 0 inputs in op.Reshape",
-    )
-    .xfail(
-        reason="fixme: https://github.com/pytorch/pytorch/issues/146336",
     ),
     TorchLibOpInfo("unfold", core_ops.aten_unfold),
     TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold),
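The retained .xfail entry still marks samples whose input has a zero-sized dimension as expected failures. A quick standalone sketch of how that matcher classifies inputs; the Sample dataclass here is only a hypothetical stand-in for the OpInfo sample object the test harness actually passes in:

from dataclasses import dataclass

import torch


@dataclass
class Sample:
    # Hypothetical stand-in: only the .input attribute used by the matcher.
    input: torch.Tensor


def matcher(sample: Sample) -> bool:
    return any(dim == 0 for dim in sample.input.shape)


print(matcher(Sample(torch.zeros(2, 12))))  # False -> sample runs normally
print(matcher(Sample(torch.zeros(2, 0))))   # True  -> sample is expected to fail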