
Commit 84dfcad

[torchlib] Fix prod (#2038)
1 parent dbf2353 commit 84dfcad

File tree

onnxscript/function_libs/torch_lib/ops/core.py
tests/function_libs/torch_lib/ops_test_data.py

2 files changed, 27 insertions(+), 3 deletions(-)


onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 13 additions & 3 deletions
@@ -6682,11 +6682,21 @@ def aten_prelu_backward(
     raise NotImplementedError()
 
 
-@torch_op("aten::prod.dim_int", trace_only=True)
-def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal:
+@torch_op("aten::prod", trace_only=True)
+def aten_prod(self: TReal, dtype: int = -1) -> TReal:
     """prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"""
 
-    # Todo: add test for this function later
+    if dtype != -1 and dtype is not None:
+        self = op.Cast(self, to=dtype)
+    return op.ReduceProd(self)
+
+
+@torch_op("aten::prod.dim_int", trace_only=True)
+def aten_prod_dim_int(self: TReal, dim: int, keepdim: bool = False, dtype: int = -1) -> TReal:
+    """prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
+
+    if dtype != -1 and dtype is not None:
+        self = op.Cast(self, to=dtype)
     return op.ReduceProd(self, axes=[dim], keepdims=keepdim)
 
 
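For context, a minimal sketch in plain PyTorch (not part of this commit) of the two ATen overloads the functions above now cover: aten::prod reduces over every element, aten::prod.dim_int reduces along a single dimension, and in both cases an explicit dtype casts the input before the reduction, matching the Cast followed by ReduceProd lowering in the diff.

import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

# aten::prod: full reduction over all elements
torch.prod(x)                       # tensor(24.)

# aten::prod.dim_int: reduction along one dimension
torch.prod(x, dim=1)                # tensor([ 2., 12.])
torch.prod(x, dim=1, keepdim=True)  # tensor([[ 2.], [12.]])

# dtype casts the input first, then multiplies (here 1.5 -> 1, 2.5 -> 2)
torch.prod(torch.tensor([1.5, 2.5]), dtype=torch.int64)  # tensor(2)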

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 14 additions & 0 deletions
@@ -1271,6 +1271,19 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("polar", core_ops.aten_polar),
     TorchLibOpInfo("pow", core_ops.aten_pow),
+    TorchLibOpInfo("prod", core_ops.aten_prod).skip(
+        matcher=lambda sample: sample.kwargs.get("dim") is not None
+        or sample.kwargs.get("keepdim") is not None
+        or sample.kwargs.get("dtype") != -1,
+        reason="this Aten overload only accept 1 inputs: self",
+    ),
+    TorchLibOpInfo("prod_dim_int", core_ops.aten_prod_dim_int).skip(
+        matcher=lambda sample: (
+            sample.kwargs.get("dim") is None and sample.kwargs.get("keepdim") is None
+        )
+        or sample.kwargs.get("dtype") != -1,
+        reason="this Aten overload can accept 3 inputs:(self, dim, keepdim)",
+    ),
     TorchLibOpInfo("nn.functional.prelu", core_ops.aten_prelu),
     TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True),
     TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True),
@@ -2203,6 +2216,7 @@ def _where_input_wrangler(
     OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",)
 )
 ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",))
+ops_test_common.duplicate_opinfo(OPS_DB, "prod", ("prod_dim_int",))
 ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",))
 ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",))
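The duplicate_opinfo call clones the prod OpInfo under the name prod_dim_int, so both torchlib functions are exercised against the same sample set, and the two skip matchers then partition those samples between them. A standalone sketch of that routing (sample kwargs mocked as plain dicts; not repo code): samples without dim/keepdim stay with the "prod" entry, samples carrying dim or keepdim go to "prod_dim_int", and both entries skip any sample whose dtype kwarg is not -1.

def prod_skip(kwargs):
    # Mirrors the matcher on the "prod" entry: skip anything aten_prod cannot take.
    return (
        kwargs.get("dim") is not None
        or kwargs.get("keepdim") is not None
        or kwargs.get("dtype") != -1
    )

def prod_dim_int_skip(kwargs):
    # Mirrors the matcher on the "prod_dim_int" entry: skip full-reduction samples.
    return (
        kwargs.get("dim") is None and kwargs.get("keepdim") is None
    ) or kwargs.get("dtype") != -1

print(prod_skip({"dtype": -1}))                    # False: aten_prod runs this sample
print(prod_dim_int_skip({"dtype": -1}))            # True:  skipped here
print(prod_skip({"dim": 1, "dtype": -1}))          # True:  skipped here
print(prod_dim_int_skip({"dim": 1, "dtype": -1}))  # False: aten_prod_dim_int runs it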
