
Commit 9ee8c92

xadupre, justinchuby, and titaiwangms authored
Fix include_self for scatter_reduce (#2090)
Implement logic for include_self. Fixes pytorch/pytorch#147617

Co-authored-by: Justin Chu <[email protected]>
Co-authored-by: Ti-Tai Wang <[email protected]>
1 parent 5d969c4 commit 9ee8c92
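
For context, here is a minimal PyTorch sketch (illustrative only, not taken from the commit) of the semantics being fixed: with include_self=False, the value already stored at a destination index must be excluded from the reduction, while positions that receive no update keep their original value.

import torch

x = torch.tensor([1.0, 2.0, 3.0])
index = torch.tensor([0, 0, 1])
src = torch.tensor([4.0, 5.0, 6.0])

# include_self=True: position 0 reduces {1, 4, 5}, position 1 reduces {2, 6}.
print(x.scatter_reduce(0, index, src, reduce="sum", include_self=True))
# tensor([10., 8., 3.])

# include_self=False: position 0 reduces {4, 5}, position 1 reduces {6},
# and position 2 keeps its original value because nothing is scattered to it.
print(x.scatter_reduce(0, index, src, reduce="sum", include_self=False))
# tensor([9., 6., 3.])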

File tree

4 files changed: +124 −13 lines changed


onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 52 additions & 0 deletions
@@ -14,6 +14,9 @@
 import math
 from typing import Any, Optional, Sequence, Tuple, Union
 
+import numpy as np
+import torch
+
 from onnxscript import (
     BFLOAT16,
     BOOL,
@@ -7599,13 +7602,62 @@ def aten_scatter_reduce(
         "amax": "max",
     }
     onnx_reduce = reduce_mode[reduce]
+    dtype = src.dtype or self.dtype
+    assert dtype is not None, "dtype should not be None"
+
     self_is_scalar = len(self.shape) == 0
     if self_is_scalar:  # assert (index_rank == 0 and rank_src == 0)
         neg_1 = op.Constant(value_ints=[-1])
         self = op.Reshape(self, neg_1)
         index = op.Reshape(index, neg_1)
         src = op.Reshape(src, neg_1)
+
+    if not include_self:
+        # The ONNX standard always assumes that the value from self takes
+        # part in the reduction. A first step is added to replace the
+        # impacted values with another one, chosen so that the result of
+        # the reduction is not changed whether or not it takes part in it.
+        # That value is -inf if the reduction is max, inf for min, 0 for
+        # add, 1 for mul. mean is not supported.
+        if onnx_reduce == "max":
+            if dtype in {
+                ir.DataType.FLOAT16,
+                ir.DataType.FLOAT,
+                ir.DataType.DOUBLE,
+            }:
+                value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype)
+            elif dtype == ir.DataType.BFLOAT16:
+                value = ir.tensor([torch.finfo(torch.bfloat16).min], dtype=dtype)
+            else:
+                value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype)
+            reduction_init = "min"
+        elif onnx_reduce == "min":
+            if dtype in {
+                ir.DataType.FLOAT16,
+                ir.DataType.FLOAT,
+                ir.DataType.DOUBLE,
+            }:
+                value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype)
+            elif dtype == ir.DataType.BFLOAT16:
+                value = ir.tensor([torch.finfo(torch.bfloat16).max], dtype=dtype)
+            else:
+                value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype)
+            reduction_init = "max"
+        elif onnx_reduce == "add":
+            value = ir.tensor([0], dtype=dtype)
+            reduction_init = "none"
+        elif onnx_reduce == "mul":
+            value = ir.tensor([1], dtype=dtype)
+            reduction_init = "none"
+        else:
+            value = 0
+            reduction_init = "none"
+
+        cst = op.ConstantOfShape(op.Shape(src), value=value)
+        self = op.ScatterElements(self, index, cst, axis=dim, reduction=reduction_init)
+
     result = op.ScatterElements(self, index, src, axis=dim, reduction=onnx_reduce)
+
     if self_is_scalar:
         result = op.Squeeze(result)
     return result
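
To see why the first ScatterElements call above makes the second behave like include_self=False, here is a small numpy sketch (illustrative only, assuming a 1-D amax case): pre-scattering the reduction's neutral element erases the destination values that would otherwise leak into the result.

import numpy as np

x = np.array([10.0, 10.0, 10.0])
indices = np.array([0, 0, 2])
updates = np.array([1.0, 2.0, 3.0])

# Step 1: overwrite every position that will receive an update with the
# neutral element of max (-inf); this mirrors the ConstantOfShape +
# ScatterElements pair inserted by the diff.
x_init = x.copy()
x_init[indices] = -np.inf

# Step 2: scatter with a max reduction; the destination value now takes
# part in the reduction without affecting the result.
out = x_init.copy()
for i, u in zip(indices, updates):
    out[i] = max(out[i], u)

print(out)  # [2. 10. 3.] == scatter_reduce(0, ..., "amax", include_self=False)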

onnxscript/optimizer/_constant_folding.py

Lines changed: 4 additions & 3 deletions
@@ -867,9 +867,10 @@ def _do_inference(self, node: ir.Node) -> None:
 
         # TODO: handle optional inputs
         def get_constant_value(x: ir.Value) -> onnx.TensorProto | None:
-            value = _get_numpy_value(x)
-            if isinstance(value, np.ndarray) and value.size < 20:
-                return onnx.numpy_helper.from_array(value, x.name)
+            value = _get_numpy_value(x, size_limit=20)
+            if value is not None:
+                assert x.const_value is not None
+                return ir.serde.serialize_tensor(x.const_value)
             return None
 
         def get_type(value: ir.Value) -> onnx.TypeProto | None:
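
The optimizer change above pushes the size check into _get_numpy_value through a size_limit argument and serializes via ir.serde.serialize_tensor rather than onnx.numpy_helper. The helper itself is outside this diff; the sketch below only illustrates the contract the new code appears to rely on (the function name and its details are assumptions, not the actual implementation).

from __future__ import annotations

import numpy as np
from onnxscript import ir

def _get_numpy_value_sketch(x: ir.Value, size_limit: int | None = None) -> np.ndarray | None:
    """Hypothetical stand-in: return x's constant as numpy, or None."""
    const_value = x.const_value
    if const_value is None:
        return None
    # With size_limit=20 the caller treats constants of 20 or more elements
    # as unavailable, matching the old `value.size < 20` check.
    if size_limit is not None and const_value.size >= size_limit:
        return None
    return const_value.numpy()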
tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo
+
+import unittest
+
+import onnxruntime
+import torch
+
+from tests.common import testutils
+
+
+class TorchLibe2eTest(testutils.TestBase):
+    def test_investigate_one_particular_model(self):
+        """This test can be used to investigate a particular issue."""
+        red, include, stype = "amin", False, "int32"
+        dtype = getattr(torch, stype)
+
+        class Model(torch.nn.Module):
+            def __init__(self, include, red):
+                super().__init__()
+                self.include = include
+                self.red = red
+
+            def forward(self, x, indices, updates):
+                x = x.clone()
+                return x.scatter_reduce(
+                    0, indices, updates, self.red, include_self=self.include
+                )
+
+        model = Model(include, red)
+        xs = (
+            torch.tensor([[-2, 0, 2], [2, -2, 0]], dtype=dtype),
+            torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.int64),
+            torch.tensor([[-1, -1, -1], [-1, -1, -1]], dtype=dtype),
+        )
+        expected = model(*xs)
+        model_path = (
+            f"test_aten_scatter_{red}_{'include' if include else 'exclude'}_{stype}.onnx"
+        )
+        torch.onnx.export(model, xs, model_path, dynamo=True)
+        feeds = dict(zip(["x", "indices", "updates"], [x.numpy() for x in xs]))
+
+        sess_options = onnxruntime.SessionOptions()
+        sess = onnxruntime.InferenceSession(
+            model_path, sess_options=sess_options, providers=["CPUExecutionProvider"]
+        )
+        got = sess.run(None, feeds)[0]
+        torch.testing.assert_close(expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 14 additions & 10 deletions
@@ -2026,26 +2026,30 @@ def _where_input_wrangler(
         variant_name="mean",
         reason="ONNX doesn't support reduce='mean' option",
     )
-    .skip(
-        # ONNX has not include_self parameter and default is include_self=True mode
-        matcher=lambda sample: sample.kwargs.get("include_self") is False,
-        reason="ONNX does't support include_self=False option",
+    .xfail(
+        variant_name="prod",
+        dtypes=(torch.float16, torch.float64),
+        reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 16 when reduction is 'mul'",
     )
     .xfail(
-        variant_name="amax",
-        reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'",
+        variant_name="sum",
+        dtypes=(torch.float16, torch.float64),
+        reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'",
     )
     .xfail(
-        variant_name="amin",
-        reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'",
+        variant_name="mean",
+        dtypes=(torch.bfloat16,),
+        reason="onnxruntime does not support ml_dtypes.bfloat16",
     )
     .xfail(
         variant_name="prod",
-        reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'prod'",
+        dtypes=(torch.bfloat16,),
+        reason="onnxruntime does not support ml_dtypes.bfloat16",
     )
     .xfail(
         variant_name="sum",
-        reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'",
+        dtypes=(torch.bfloat16,),
+        reason="onnxruntime does not support ml_dtypes.bfloat16",
     ),
     TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter),
     TorchLibOpInfo("slice", core_ops.aten_slice),
