Skip to content

Commit 8d98094

Browse files
authored
[torchlib] Fix scatter reduce on error cases (#2287)
Fix three errors ```pytb value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype) ^^^^^^^^^^^^^^^^^^^^^^^ File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/numpy/_core/getlimits.py", line 706, in __init__ raise ValueError("Invalid integer data type %r." % (self.kind,)) ValueError: Invalid integer data type 'b'. ``` ```pytb Traceback (most recent call last): File "/Users/runner/work/torch-onnx-op-matrix/torch-onnx-op-matrix/op_matrix/onnx_dynamo_op_survey.py", line 54, in check_single_op onnx.checker.check_model(onnx_model, full_check=True) # type: ignore ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/onnx/checker.py", line 180, in check_model C.check_model( onnx.onnx_cpp2py_export.checker.ValidationError: Mismatched attribute type in 'node_ConstantOfShape_1 : value'. Expected: 'TENSOR', actual: 'INT' ==> Context: Bad node spec for node. Name: node_ConstantOfShape_1 OpType: ConstantOfShape ``` Fix a case for bfloat16 when min should be max.
1 parent a0cf581 commit 8d98094

File tree

1 file changed

+6
-2
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+6
-2
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7627,6 +7627,8 @@ def aten_scatter_reduce(
76277627
value = ir.tensor([np.finfo(dtype.numpy()).min], dtype=dtype)
76287628
elif dtype == ir.DataType.BFLOAT16:
76297629
value = ir.tensor([torch.finfo(torch.bfloat16).min], dtype=dtype)
7630+
elif dtype == ir.DataType.BOOL:
7631+
value = ir.tensor([False], dtype=dtype)
76307632
else:
76317633
value = ir.tensor([np.iinfo(dtype.numpy()).min], dtype=dtype)
76327634
reduction_init = "min"
@@ -7638,7 +7640,9 @@ def aten_scatter_reduce(
76387640
}:
76397641
value = ir.tensor([np.finfo(dtype.numpy()).max], dtype=dtype)
76407642
elif dtype == ir.DataType.BFLOAT16:
7641-
value = ir.tensor([torch.finfo(torch.bfloat16).min], dtype=dtype)
7643+
value = ir.tensor([torch.finfo(torch.bfloat16).max], dtype=dtype)
7644+
elif dtype == ir.DataType.BOOL:
7645+
value = ir.tensor([True], dtype=dtype)
76427646
else:
76437647
value = ir.tensor([np.iinfo(dtype.numpy()).max], dtype=dtype)
76447648
reduction_init = "max"
@@ -7649,7 +7653,7 @@ def aten_scatter_reduce(
76497653
value = ir.tensor([1], dtype=dtype)
76507654
reduction_init = "none"
76517655
else:
7652-
value = 0
7656+
value = ir.tensor([0], dtype=dtype)
76537657
reduction_init = "none"
76547658

76557659
cst = op.ConstantOfShape(op.Shape(src), value=value)

0 commit comments

Comments
 (0)