Skip to content

Commit 35fa2c9

Browse files
authored
feat(atenlib): add ops (new_empty, new_empty_strided) (#436)
1 parent 0298154 commit 35fa2c9

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3937,16 +3937,36 @@ def aten_negative(self: TensorType) -> TensorType:
39373937
raise NotImplementedError()
39383938

39393939

3940-
def aten_new_empty(self: TensorType, size: INT64) -> TensorType:
3940+
@torch_op("aten::new_empty")
3941+
def aten_new_empty(self: TTensor, size: INT64, dtype: int = -1) -> TTensor:
39413942
# new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
39423943

3943-
raise NotImplementedError()
3944+
# using zero to simulate empty array
3945+
zero = op.Constant(value_float=0.0)
3946+
result = op.Expand(zero, size)
3947+
if dtype == -1:
3948+
result = op.CastLike(result, self)
3949+
else:
3950+
result = op.Cast(result, to=dtype)
3951+
return result
39443952

39453953

3946-
def aten_new_empty_strided(self: TensorType, size: INT64, stride: INT64) -> TensorType:
3954+
@torch_op("aten::new_empty_strided")
3955+
def aten_new_empty_strided(
3956+
self: TTensor,
3957+
size: INT64,
3958+
stride: INT64, # pylint: disable=unused-argument
3959+
dtype: int = -1,
3960+
) -> TTensor:
39473961
# new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
39483962

3949-
raise NotImplementedError()
3963+
# using zero to simulate empty array
3964+
zero = op.ConstantOfShape(size)
3965+
if dtype == -1:
3966+
result = op.CastLike(zero, self)
3967+
else:
3968+
result = op.Cast(zero, to=dtype)
3969+
return result
39503970

39513971

39523972
@torch_op("aten::new_full")

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,8 @@ def _where_input_wrangler(
302302
"mul": core_ops.aten_mul,
303303
"ne": core_ops.aten_ne,
304304
"neg": core_ops.aten_neg,
305+
"new_empty": core_ops.aten_new_empty,
306+
"new_empty_strided": core_ops.aten_new_empty_strided,
305307
"new_full": core_ops.aten_new_full,
306308
"nn.functional.adaptive_avg_pool1d": nn_ops.aten_adaptive_avg_pool1d,
307309
"nn.functional.adaptive_avg_pool2d": nn_ops.aten_adaptive_avg_pool2d,
@@ -396,6 +398,8 @@ def _where_input_wrangler(
396398
skip("empty_like", reason="Using zeros_like to simulate empty_like"),
397399
xfail("logcumsumexp", reason="naive implementation not numerically stable"),
398400
xfail("logsumexp", reason="ONNX Runtime 1.13 does not support ReduceLogSumExp-18"),
401+
xfail("new_empty", reason="Using zeros to simulate empty"),
402+
xfail("new_empty_strided", reason="Using zeros to simulate empty"),
399403
xfail(
400404
"nn.functional.upsample_nearest2d",
401405
reason="enable when ONNX Runtime does support opset18",

0 commit comments

Comments
 (0)