Skip to content

Commit 52eb1bc

Browse files
authored
Merge branch 'main' into xiaowu/addOp(BatchNorm)
2 parents b294330 + 35fa2c9 commit 52eb1bc

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3955,16 +3955,36 @@ def aten_negative(self: TensorType) -> TensorType:
39553955
raise NotImplementedError()
39563956

39573957

3958-
def aten_new_empty(self: TensorType, size: INT64) -> TensorType:
3958+
@torch_op("aten::new_empty")
3959+
def aten_new_empty(self: TTensor, size: INT64, dtype: int = -1) -> TTensor:
39593960
# new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
39603961

3961-
raise NotImplementedError()
3962+
# using zero to simulate empty array
3963+
zero = op.Constant(value_float=0.0)
3964+
result = op.Expand(zero, size)
3965+
if dtype == -1:
3966+
result = op.CastLike(result, self)
3967+
else:
3968+
result = op.Cast(result, to=dtype)
3969+
return result
39623970

39633971

3964-
def aten_new_empty_strided(self: TensorType, size: INT64, stride: INT64) -> TensorType:
3972+
@torch_op("aten::new_empty_strided")
3973+
def aten_new_empty_strided(
3974+
self: TTensor,
3975+
size: INT64,
3976+
stride: INT64, # pylint: disable=unused-argument
3977+
dtype: int = -1,
3978+
) -> TTensor:
39653979
# new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
39663980

3967-
raise NotImplementedError()
3981+
# using zero to simulate empty array
3982+
zero = op.ConstantOfShape(size)
3983+
if dtype == -1:
3984+
result = op.CastLike(zero, self)
3985+
else:
3986+
result = op.Cast(zero, to=dtype)
3987+
return result
39683988

39693989

39703990
@torch_op("aten::new_full")

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,8 @@ def _where_input_wrangler(
306306
# "native_dropout": core_ops.aten_native_dropout, # native_dropout is not in OPS_DB
307307
"ne": core_ops.aten_ne,
308308
"neg": core_ops.aten_neg,
309+
"new_empty": core_ops.aten_new_empty,
310+
"new_empty_strided": core_ops.aten_new_empty_strided,
309311
"new_full": core_ops.aten_new_full,
310312
"nn.functional.adaptive_avg_pool1d": nn_ops.aten_adaptive_avg_pool1d,
311313
"nn.functional.adaptive_avg_pool2d": nn_ops.aten_adaptive_avg_pool2d,
@@ -400,6 +402,8 @@ def _where_input_wrangler(
400402
skip("empty_like", reason="Using zeros_like to simulate empty_like"),
401403
xfail("logcumsumexp", reason="naive implementation not numerically stable"),
402404
xfail("logsumexp", reason="ONNX Runtime 1.13 does not support ReduceLogSumExp-18"),
405+
xfail("new_empty", reason="Using zeros to simulate empty"),
406+
xfail("new_empty_strided", reason="Using zeros to simulate empty"),
403407
xfail(
404408
"nn.functional.upsample_nearest2d",
405409
reason="enable when ONNX Runtime does support opset18",

0 commit comments

Comments
 (0)