Add ops(unflatten) | feat(atenlib) #612


Merged · 8 commits · Apr 7, 2023
Changes from 4 commits
24 changes: 24 additions & 0 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -6031,6 +6031,30 @@ def aten_type_as(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::unflatten")
def aten_unflatten(self: TReal, dim: INT64, sizes: INT64):
"""unflatten(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)"""

self_size = op.Shape(self)

if dim < 0:
# PyTorch accepts negative dim as reversed counting
self_rank = op.Size(self_size)
dim = self_rank + dim

head_start_idx = op.Constant(value_ints=[0])
head_end_idx = op.Reshape(dim, op.Constant(value_ints=[1]))
Contributor Author:
@gramalingam @justinchuby ORTEvaluator doesn't seem to understand:

    head_end_idx = op.Reshape(dim, [1])

head_part_rank = op.Slice(self_size, head_start_idx, head_end_idx)

tail_start_idx = op.Reshape(dim + 1, op.Constant(value_ints=[1]))
tail_end_idx = op.Constant(value_ints=[_INT64_MAX])
tail_part_rank = op.Slice(self_size, tail_start_idx, tail_end_idx)

final_shape = op.Concat(head_part_rank, sizes, tail_part_rank, axis=0)

return op.Reshape(self, final_shape)


def aten_unfold(self: TensorType, dimension: int, size: int, step: int) -> TensorType:
"""unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)"""

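As a reading aid (not part of the diff), here is a minimal NumPy sketch of the shape arithmetic that aten_unflatten above performs with Shape/Slice/Concat/Reshape; the helper name and the sample shapes are illustrative assumptions.

import numpy as np

# Sketch only (not from the PR): NumPy equivalent of the ONNX pattern above.
def unflatten_sketch(x, dim, sizes):
    self_size = np.array(x.shape, dtype=np.int64)     # op.Shape(self)
    if dim < 0:                                        # negative dim counts from the end
        dim = self_size.size + dim
    head = self_size[:dim]                             # op.Slice(self_size, [0], [dim])
    tail = self_size[dim + 1:]                         # op.Slice(self_size, [dim + 1], [INT64_MAX])
    final_shape = np.concatenate([head, np.asarray(sizes, dtype=np.int64), tail])
    return x.reshape(final_shape)                      # op.Reshape(self, final_shape)

x = np.arange(2 * 12 * 5).reshape(2, 12, 5)
print(unflatten_sketch(x, 1, (3, 4)).shape)            # (2, 3, 4, 5), matching torch.unflatten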
18 changes: 18 additions & 0 deletions onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -424,6 +424,13 @@ def _sum_input_wrangler(
return args, kwargs


def _unflatten_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args[1] = np.array(args[1], dtype=np.int64)
return args, kwargs


def _where_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
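As an aside (not part of the diff), a tiny sketch of what the _unflatten_input_wrangler above does to an OpInfo-style sample before it reaches the ONNX function; the sample values are made up for illustration.

from typing import Any
import numpy as np

def _unflatten_input_wrangler(
    args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
    # Same body as the wrangler in the diff above: the sizes argument is converted
    # to an int64 array so it can be passed to the ONNX function as a tensor input.
    args[1] = np.array(args[1], dtype=np.int64)
    return args, kwargs

# Hypothetical OpInfo-style sample: positional args are (dim, sizes).
args, kwargs = _unflatten_input_wrangler([1, (3, 4)], {})
print(args)  # [1, array([3, 4])], where the array has dtype int64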
@@ -609,6 +616,7 @@ def _where_input_wrangler(
"tril": core_ops.aten_tril,
"triu": core_ops.aten_triu,
"trunc": core_ops.aten_trunc,
"unflatten": (core_ops.aten_unflatten, _unflatten_input_wrangler),
"unsqueeze": core_ops.aten_unsqueeze,
"view": core_ops.aten_view,
"where": (core_ops.aten_where, _where_input_wrangler),
@@ -808,6 +816,11 @@ def _where_input_wrangler(
reason="ORT Graph attribute inferencing failed on rank-1 input",
test_class_name="TestOutputConsistencyFullGraph",
),
xfail(
"unflatten",
reason="fixme: ORT fails with invalid model: 'INVALID_ARGUMENT : Failed to load model with error: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1)'",
test_class_name="TestOutputConsistencyFullGraph",
),
)


@@ -1063,6 +1076,11 @@ def _where_input_wrangler(
matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)),
reason="this Aten overload only support one tensor as input and one int as args by design",
),
skip(
"unflatten",
matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
reason="0 dim in ONNX is undefined behavior.",
),
Collaborator:
Can we create a condition to handle this case?

Contributor Author:
To raise a warning? Not sure what we should give back to users.

Collaborator (@justinchuby, Apr 7, 2023):
What is PyTorch's behavior?

> To raise a warning? Not sure what we should give back to users.

I was suggesting handling it like PyTorch does, if possible.

Contributor Author:
PyTorch accepts this because zero-sized dimensions are in its spec. For example:

a = torch.randn(4, 0)
# tensor([], size=(4, 0))

torch.unflatten(a, 0, (2, 2))
# tensor([], size=(2, 2, 0))

Contributor Author (@titaiwangms, Apr 7, 2023):
I wonder if torch_libs has a way to inform users when a difference between the specs is hit? In the TorchScript exporter, we could raise UnsupportedONNXError.

Collaborator (@justinchuby, Apr 7, 2023):
I see. We could skip this if it is too complicated and come back to it as needed. I would update the skip reason to say something to the effect of "fixme: logic not implemented for size 0 inputs".

Collaborator:
We don't raise errors in atenlib. The exporter can raise errors when it sees a difference after OpSchema is implemented.

)

duplicate_opinfo(OPS_DB, "all", ("all_dim",))