Skip to content

Commit 7c74399

Browse files
authored
feat(atenlib): add ops(normal, narrow) (#440)
1 parent 5b46ad5 commit 7c74399

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3765,10 +3765,25 @@ def aten_nansum(
37653765
raise NotImplementedError()
37663766

37673767

3768-
def aten_narrow(self: TensorType, dim: int, start: INT64, length: INT64) -> TensorType:
3768+
@torch_op("aten::narrow")
3769+
def aten_narrow(self: TTensor, dim: INT64, start: INT64, length: INT64) -> TTensor:
37693770
# narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)
37703771

3771-
raise NotImplementedError()
3772+
dim_rank = op.Size(op.Shape(dim))
3773+
if dim_rank == 0:
3774+
dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
3775+
3776+
start_rank = op.Size(op.Shape(start))
3777+
if start_rank == 0:
3778+
start = op.Reshape(start, op.Constant(value_ints=[-1]))
3779+
3780+
length_rank = op.Size(op.Shape(length))
3781+
if length_rank == 0:
3782+
length = op.Reshape(length, op.Constant(value_ints=[-1]))
3783+
3784+
end = op.Add(start, length)
3785+
result = op.Slice(self, start, end, dim)
3786+
return result
37723787

37733788

37743789
def aten_narrow_copy(self: TensorType, dim: int, start: INT64, length: INT64) -> TensorType:
@@ -4036,12 +4051,20 @@ def aten_norm_except_dim(v: TensorType, pow: int = 2, dim: int = 0) -> TensorTyp
40364051
raise NotImplementedError()
40374052

40384053

4054+
@torch_op("aten::normal")
40394055
def aten_normal(
4040-
self: TensorType, mean: float = 0.0, std: float = 1.0, generator: Optional[str] = None
4041-
) -> TensorType:
4056+
self: TTensor,
4057+
mean: float = 0.0,
4058+
std: float = 1.0,
4059+
) -> TFloat: # type: ignore[type-var]
40424060
# normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor
40434061

4044-
raise NotImplementedError()
4062+
self_rank = op.Size(op.Shape(self))
4063+
if self_rank == 0:
4064+
self = op.Reshape(self, op.Constant(value_ints=[-1]))
4065+
4066+
result = op.RandomNormalLike(self, mean=mean, scale=std)
4067+
return result
40454068

40464069

40474070
def aten_not_equal(self: TensorType, other: TensorType) -> TensorType:

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ def _where_input_wrangler(
303303
"minimum": core_ops.aten_minimum,
304304
"mm": core_ops.aten_mm,
305305
"mul": core_ops.aten_mul,
306+
"narrow": core_ops.aten_narrow,
306307
# "native_dropout": core_ops.aten_native_dropout, # native_dropout is not in OPS_DB
307308
"ne": core_ops.aten_ne,
308309
"neg": core_ops.aten_neg,
@@ -325,6 +326,7 @@ def _where_input_wrangler(
325326
_upsample_input_wrangler,
326327
),
327328
"nonzero": core_ops.aten_nonzero,
329+
"normal": core_ops.aten_normal,
328330
"ones": core_ops.aten_ones,
329331
"permute": core_ops.aten_permute,
330332
"pow": core_ops.aten_pow,
@@ -408,6 +410,8 @@ def _where_input_wrangler(
408410
"nn.functional.upsample_nearest2d",
409411
reason="enable when ONNX Runtime does support opset18",
410412
),
413+
xfail("normal", reason="Random numbers are not close"),
414+
xfail("normal", variant_name="number_mean", reason="Random numbers are not close"),
411415
xfail("round", variant_name="decimals_0", reason="The op does not support decimals"),
412416
xfail("round", variant_name="decimals_3", reason="The op does not support decimals"),
413417
xfail("round", variant_name="decimals_neg_3", reason="The op does not support decimals"),
@@ -455,6 +459,11 @@ def _where_input_wrangler(
455459
matcher=lambda sample: sample.kwargs.get("as_tuple") is not None,
456460
reason="as_tuple=True is not supported",
457461
),
462+
skip(
463+
"normal",
464+
matcher=lambda sample: len(sample.args) > 0 and not isinstance(sample.args[0], float),
465+
reason="ORT only accept float type for args[0] 'mean'",
466+
),
458467
skip(
459468
"nn.functional.adaptive_avg_pool1d",
460469
# Shape should be [N, C, D1]

0 commit comments

Comments
 (0)