Skip to content

Make aten::contiguous and device_put no-op | fix(torchlib) #835

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,17 +1425,14 @@ def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> TTens
return op.Pad(self, onnx_padding, value)


@torch_op("aten::contiguous", trace_only=True)
def aten_contiguous(self: TTensor, memory_format: str = "contiguous_format") -> TTensor:
@torch_op("aten::contiguous")
def aten_contiguous(
self: TTensor, memory_format: str = "contiguous_format" # pylint: disable=unused-argument
) -> TTensor:
"""contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)"""

if memory_format in ["contiguous_format", "preserve_format"]:
return op.Identity(self)
else:
# TODO: Find out a way to annotate constraints for argument, as part of the function meta data structure.
raise NotImplementedError(
"memory_format value supports 'contiguous_format' or 'preserve_format' only."
)
# ONNX does not have the notion of memory_format. It is always treated as a no-op.
return op.Identity(self)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we do so, users calling this function will have a wrong misunderstanding that all of formats were processed successfully, which is not right.
If we can handle this op in an earlier phase by exporter, this should be fine, and we'd better leave a comment here.

Copy link
Collaborator Author

@justinchuby justinchuby Jul 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand the format is an internal representation and does not affect computation in terms of the result? I gathered that from https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html Please feel free to correct me

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BowenBao do you have more info on this op? Do you think we should filter it out in a fx pass?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chatted offline. Handling in fx pass offers a more fundamental solution that asserts correctness, yet it requires much larger effort and targets only edge cases, which does not cut it in terms of priorities. Hence the approach in this PR is preferred.



@torch_op("aten::conv1d", trace_only=True)
Expand Down
10 changes: 8 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from onnxscript import INT64
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TTensor
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType

Expand Down Expand Up @@ -246,10 +247,15 @@ def prims_cosh(self: TensorType) -> TensorType:
raise NotImplementedError()


def prims_device_put(a: TensorType, device: str) -> TensorType:
@torch_op("prims::device_put")
def prims_device_put(
a: TTensor,
device: str = "unspecified", # pylint: disable=unused-argument
) -> TTensor:
"""device_put(Tensor a, Device device) -> Tensor"""

raise NotImplementedError()
# ONNX does not have the notion of a "device". So we just return the input
return op.Identity(a)


def prims_digamma(self: TensorType) -> TensorType:
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,7 +1232,7 @@ def _where_input_wrangler(
reason="fixme: Tensor-likes are not close. https://github.com/microsoft/onnxruntime/issues/16007",
),
TorchLibOpInfo("cumsum", core_ops.aten_cumsum, trace_only=True),
TorchLibOpInfo("contiguous", core_ops.aten_contiguous, trace_only=True),
TorchLibOpInfo("contiguous", core_ops.aten_contiguous),
TorchLibOpInfo(
"convolution",
core_ops.aten_convolution,
Expand Down