Skip to content

Commit b8a7671

Browse files
xadupre, justinchuby, and titaiwangms
authored and committed
Fix include_self for scatter_reduce (microsoft#2090)
Implement logic for include_self. Fixes pytorch/pytorch#147617 --------- Co-authored-by: Justin Chu <[email protected]> Co-authored-by: Ti-Tai Wang <[email protected]>
1 parent 1ff713e commit b8a7671

File tree

4 files changed

+124
-13
lines changed

4 files changed

+124
-13
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
import math
1515
from typing import Any, Optional, Sequence, Tuple, Union
1616

17+
import numpy as np
18+
import torch
19+
1720
from onnxscript import (
1821
BFLOAT16,
1922
BOOL,
@@ -7512,13 +7515,62 @@ def aten_scatter_reduce(
75127515
"amax": "max",
75137516
}
75147517
onnx_reduce = reduce_mode[reduce]
7518+
dtype = src.dtype or self.dtype
7519+
assert dtype is not None, "dtype should be not None"
7520+
75157521
self_is_scalar = len(self.shape) == 0
75167522
if self_is_scalar: # assert (index_rank == 0 and rank_src == 0)
75177523
neg_1 = op.Constant(value_ints=[-1])
75187524
self = op.Reshape(self, neg_1)
75197525
index = op.Reshape(index, neg_1)
75207526
src = op.Reshape(src, neg_1)
7527+
7528+
if not include_self:
7529+
# The ONNX standard always assumes the value from self is part of the reduction.
7530+
# A first step is added to replace the impacted value by another one
7531+
# chosen in a way that the result of the reduction is not changed
7532+
# whether or not it takes part in it.
7533+
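# It is the lowest representable value of the dtype if the reduction is max, the highest for min, 0 for add, 1 for mul.
# NOTE(review): the bfloat16 branch for min uses torch.finfo(torch.bfloat16).min, unlike the other min branches which use .max -- confirm.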
7534+
# mean is not supported.
7535+
if onnx_reduce == "max":
7536+
if dtype in {
7537+
ir.DataType.FLOAT16,
7538+
ir.DataType.FLOAT,
7539+
ir.DataType.DOUBLE,
7540+
}:
7541+
value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype)
7542+
elif dtype == ir.DataType.BFLOAT16:
7543+
value = ir.tensor([torch.finfo(torch.bfloat16).min], dtype=dtype)
7544+
else:
7545+
value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype)
7546+
reduction_init = "min"
7547+
elif onnx_reduce == "min":
7548+
if dtype in {
7549+
ir.DataType.FLOAT16,
7550+
ir.DataType.FLOAT,
7551+
ir.DataType.DOUBLE,
7552+
}:
7553+
value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype)
7554+
elif dtype == ir.DataType.BFLOAT16:
7555+
value = ir.tensor([torch.finfo(torch.bfloat16).min], dtype=dtype)
7556+
else:
7557+
value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype)
7558+
reduction_init = "max"
7559+
elif onnx_reduce == "add":
7560+
value = ir.tensor([0], dtype=dtype)
7561+
reduction_init = "none"
7562+
elif onnx_reduce == "mul":
7563+
value = ir.tensor([1], dtype=dtype)
7564+
reduction_init = "none"
7565+
else:
7566+
value = 0
7567+
reduction_init = "none"
7568+
7569+
cst = op.ConstantOfShape(op.Shape(src), value=value)
7570+
self = op.ScatterElements(self, index, cst, axis=dim, reduction=reduction_init)
7571+
75217572
result = op.ScatterElements(self, index, src, axis=dim, reduction=onnx_reduce)
7573+
75227574
if self_is_scalar:
75237575
result = op.Squeeze(result)
75247576
return result

onnxscript/optimizer/_constant_folding.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -867,9 +867,10 @@ def _do_inference(self, node: ir.Node) -> None:
867867

868868
# TODO: handle optional inputs
869869
def get_constant_value(x: ir.Value) -> onnx.TensorProto | None:
870-
value = _get_numpy_value(x)
871-
if isinstance(value, np.ndarray) and value.size < 20:
872-
return onnx.numpy_helper.from_array(value, x.name)
870+
value = _get_numpy_value(x, size_limit=20)
871+
if value is not None:
872+
assert x.const_value is not None
873+
return ir.serde.serialize_tensor(x.const_value)
873874
return None
874875

875876
def get_type(value: ir.Value) -> onnx.TypeProto | None:
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo
5+
6+
import unittest
7+
8+
import onnxruntime
9+
import torch
10+
11+
from tests.common import testutils
12+
13+
14+
class TorchLibe2eTest(testutils.TestBase):
15+
def test_investigate_one_particular_model(self):
16+
"""This test can be used to investigate a particular issue."""
17+
red, include, stype = "amin", False, "int32"
18+
dtype = getattr(torch, stype)
19+
20+
class Model(torch.nn.Module):
21+
def __init__(self, include, red):
22+
super().__init__()
23+
self.include = include
24+
self.red = red
25+
26+
def forward(self, x, indices, updates):
27+
x = x.clone()
28+
return x.scatter_reduce(
29+
0, indices, updates, self.red, include_self=self.include
30+
)
31+
32+
model = Model(include, red)
33+
xs = (
34+
torch.tensor([[-2, 0, 2], [2, -2, 0]], dtype=dtype),
35+
torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.int64),
36+
torch.tensor([[-1, -1, -1], [-1, -1, -1]], dtype=dtype),
37+
)
38+
expected = model(*xs)
39+
model_path = (
40+
f"test_aten_scatter_{red}_{'include' if include else 'exclude'}_{stype}.onnx"
41+
)
42+
torch.onnx.export(model, xs, model_path, dynamo=True)
43+
feeds = dict(zip(["x", "indices", "updates"], [x.numpy() for x in xs]))
44+
45+
sess_options = onnxruntime.SessionOptions()
46+
sess = onnxruntime.InferenceSession(
47+
model_path, sess_options=sess_options, providers=["CPUExecutionProvider"]
48+
)
49+
got = sess.run(None, feeds)[0]
50+
torch.testing.assert_close(expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5)
51+
52+
53+
if __name__ == "__main__":
54+
unittest.main()

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2100,26 +2100,30 @@ def _where_input_wrangler(
21002100
variant_name="mean",
21012101
reason="ONNX doesn't support reduce='mean' option",
21022102
)
2103-
.skip(
2104-
# ONNX has no include_self parameter; its behavior corresponds to include_self=True
2105-
matcher=lambda sample: sample.kwargs.get("include_self") is False,
2106-
reason="ONNX does't support include_self=False option",
2103+
.xfail(
2104+
variant_name="prod",
2105+
dtypes=(torch.float16, torch.float64),
2106+
reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 16 when reduction is 'mul'",
21072107
)
21082108
.xfail(
2109-
variant_name="amax",
2110-
reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'",
2109+
variant_name="sum",
2110+
dtypes=(torch.float16, torch.float64),
2111+
reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'",
21112112
)
21122113
.xfail(
2113-
variant_name="amin",
2114-
reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'",
2114+
variant_name="mean",
2115+
dtypes=(torch.bfloat16,),
2116+
reason="onnxruntime does not support ml_dtypes.bfloat16",
21152117
)
21162118
.xfail(
21172119
variant_name="prod",
2118-
reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'prod'",
2120+
dtypes=(torch.bfloat16,),
2121+
reason="onnxruntime does not support ml_dtypes.bfloat16",
21192122
)
21202123
.xfail(
21212124
variant_name="sum",
2122-
reason="fixme: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'add'",
2125+
dtypes=(torch.bfloat16,),
2126+
reason="onnxruntime does not support ml_dtypes.bfloat16",
21232127
),
21242128
TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter),
21252129
TorchLibOpInfo("slice", core_ops.aten_slice),

0 commit comments

Comments
 (0)