Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions onnxscript/function_libs/torch_aten/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3930,16 +3930,36 @@ def aten_negative(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::new_empty", trace_only=True)
def aten_new_empty(self: TTensor, size: INT64, dtype: int = -1) -> TTensor:
    # new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

    # ONNX has no notion of uninitialized memory, so we materialize a
    # zero-filled tensor of the requested shape instead. Callers of
    # new_empty must not rely on the values, only the shape and dtype.
    zero = op.Constant(value_float=0.0)
    result = op.Expand(zero, size)
    # dtype == -1 encodes "dtype not given": inherit self's dtype via CastLike;
    # otherwise cast to the requested ONNX dtype.
    # NOTE: this dispatch must happen at trace time (trace_only=True above).
    # In graph mode it would lower to an ONNX If whose two branches produce
    # different output types, which violates the If operator's constraint
    # that both branches have identical types (see onnx/onnx#4872).
    if dtype == -1:
        result = op.CastLike(result, self)
    else:
        result = op.Cast(result, to=dtype)
    return result


@torch_op("aten::new_empty_strided", trace_only=True)
def aten_new_empty_strided(
    self: TTensor,
    size: INT64,
    stride: INT64,  # pylint: disable=unused-argument
    dtype: int = -1,
) -> TTensor:
    # new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

    # ONNX tensors are always contiguous, so `stride` cannot be honored and
    # is deliberately ignored; only the requested shape is reproduced.
    # Zeros stand in for uninitialized memory (values are unspecified for
    # the aten op, so any fill is behavior-compatible).
    zero = op.ConstantOfShape(size)
    # dtype == -1 encodes "dtype not given": inherit self's dtype.
    # The dispatch runs at trace time (trace_only=True): in graph mode an
    # ONNX If would require both branches to yield the same output type,
    # which CastLike vs Cast cannot guarantee (see onnx/onnx#4872).
    if dtype == -1:
        result = op.CastLike(zero, self)
    else:
        result = op.Cast(zero, to=dtype)
    return result


@torch_op("aten::new_full")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ def _where_input_wrangler(
"mul": core_ops.aten_mul,
"ne": core_ops.aten_ne,
"neg": core_ops.aten_neg,
"new_empty": core_ops.aten_new_empty,
"new_empty_strided": core_ops.aten_new_empty_strided,
"new_full": core_ops.aten_new_full,
"nn.functional.adaptive_avg_pool1d": nn_ops.aten_adaptive_avg_pool1d,
"nn.functional.adaptive_avg_pool2d": nn_ops.aten_adaptive_avg_pool2d,
Expand Down Expand Up @@ -396,6 +398,8 @@ def _where_input_wrangler(
skip("empty_like", reason="Using zeros_like to simulate empty_like"),
xfail("logcumsumexp", reason="naive implementation not numerically stable"),
xfail("logsumexp", reason="ONNX Runtime 1.13 does not support ReduceLogSumExp-18"),
xfail("new_empty", reason="Using zeros to simulate empty"),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xiaowuhu I also realized this can succeed unexpectedly. In that case a skip may be a better choice than xfail.

xfail("new_empty_strided", reason="Using zeros to simulate empty"),
xfail(
"nn.functional.upsample_nearest2d",
reason="enable when ONNX Runtime does support opset18",
Expand Down