From 80a45b35ecff2d4a3bd5c6494d5557e0be0ca094 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 9 May 2025 07:18:30 -0700 Subject: [PATCH 1/2] [torchlib] Fix scatter reduce on error cases --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ea43c2c4db..96d81fb3a3 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7649,7 +7649,7 @@ def aten_scatter_reduce( value = ir.tensor([1], dtype=dtype) reduction_init = "none" else: - value = 0 + value = ir.tensor([0], dtype=dtype) reduction_init = "none" cst = op.ConstantOfShape(op.Shape(src), value=value) From d7ce9424125aa43b8a034eddbb62278fe1894a20 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 9 May 2025 07:23:55 -0700 Subject: [PATCH 2/2] Update core.py --- onnxscript/function_libs/torch_lib/ops/core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 96d81fb3a3..9892e31052 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7627,6 +7627,8 @@ def aten_scatter_reduce( value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype) elif dtype == ir.DataType.BFLOAT16: value = ir.tensor([torch.finfo(torch.bfloat16).min], dtype=dtype) + elif dtype == ir.DataType.BOOL: + value = ir.tensor([False], dtype=dtype) else: value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype) reduction_init = "min" @@ -7638,7 +7640,9 @@ def aten_scatter_reduce( }: value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype) elif dtype == ir.DataType.BFLOAT16: - value = ir.tensor([torch.finfo(torch.bfloat16).min], dtype=dtype) + value = ir.tensor([torch.finfo(torch.bfloat16).max], dtype=dtype) + elif dtype == ir.DataType.BOOL: + value = ir.tensor([True], dtype=dtype) else: value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype) reduction_init = "max"