From f6eb9d22926f4a1fb4e8619394549a441ddf1fdd Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 25 Mar 2024 17:04:37 +0100 Subject: [PATCH 1/5] fix test " " --- .../function_libs/torch_lib/ops_test_data.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 7975f4ed47..879d7d1efa 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -851,17 +851,10 @@ def _where_input_wrangler( TorchLibOpInfo( "index_put_bool", core_ops.aten_index_put_bool, - ).skip( - matcher=lambda sample: not (sample.args[0][0].dtype == torch.bool), - reason="this Aten overload only supports tensor(bool) as indices", - ), - TorchLibOpInfo( - "index_put", - core_ops.aten_index_put, ) .skip( - matcher=lambda sample: not (sample.args[0][0].dtype == torch.int64), - reason="this Aten overload only supports tensor(int) as indices", + matcher=lambda sample: not (sample.args[0][0].dtype == torch.bool), + reason="this Aten overload only supports tensor(bool) as indices", ) .xfail( enabled_if=version_utils.onnxruntime_older_than("1.18"), @@ -869,6 +862,10 @@ def _where_input_wrangler( matcher=lambda sample: sample.kwargs.get("accumulate") is True, reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'", ), + TorchLibOpInfo( + "index_put", + core_ops.aten_index_put, + ), TorchLibOpInfo("ops.aten.index_put", core_ops.aten_index_put), TorchLibOpInfo("index_select", core_ops.aten_index_select), TorchLibOpInfo("isclose", core_ops.aten_isclose), From 57e1ce9fa8c9855b992fe6cae4659edc6bf2e755 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 25 Mar 2024 17:24:58 +0100 Subject: [PATCH 2/5] ut --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- onnxscript/tests/function_libs/torch_lib/ops_test_data.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index cbf1f50f71..021259ee06 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4051,7 +4051,7 @@ def aten_index_copy( raise NotImplementedError() -@torch_op("aten::index_put") +@torch_op(("aten::index_put", "aten::_unsafe_index_put")) def aten_index_put( self: TReal, indices: Sequence[INT64], diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 879d7d1efa..56ac0a41ff 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -866,6 +866,10 @@ def _where_input_wrangler( "index_put", core_ops.aten_index_put, ), + TorchLibOpInfo( + "_unsafe_index_put", + core_ops.aten_index_put, + ), TorchLibOpInfo("ops.aten.index_put", core_ops.aten_index_put), TorchLibOpInfo("index_select", core_ops.aten_index_select), TorchLibOpInfo("isclose", core_ops.aten_isclose), From b4c144d793e84c1f2672e2f8bed7e9f278f69be7 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 25 Mar 2024 17:29:37 +0100 Subject: [PATCH 3/5] fix test --- .../function_libs/torch_lib/ops_test_data.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 56ac0a41ff..54bccc1423 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -851,10 +851,17 @@ def _where_input_wrangler( TorchLibOpInfo( "index_put_bool", core_ops.aten_index_put_bool, - ) - .skip( + ).skip( matcher=lambda sample: not (sample.args[0][0].dtype == torch.bool), reason="this Aten overload only supports tensor(bool) as indices", + ), + TorchLibOpInfo( + "index_put", + core_ops.aten_index_put, + ) + .skip( + matcher=lambda sample: not (sample.args[0][0].dtype == torch.int64), + reason="this Aten overload only supports tensor(int) as indices", ) .xfail( enabled_if=version_utils.onnxruntime_older_than("1.18"), @@ -862,15 +869,8 @@ def _where_input_wrangler( matcher=lambda sample: sample.kwargs.get("accumulate") is True, reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'", ), - TorchLibOpInfo( - "index_put", - core_ops.aten_index_put, - ), - TorchLibOpInfo( - "_unsafe_index_put", - core_ops.aten_index_put, - ), TorchLibOpInfo("ops.aten.index_put", core_ops.aten_index_put), + TorchLibOpInfo("ops.aten._unsafe_index_put", core_ops.aten__unsafe_index_put), TorchLibOpInfo("index_select", core_ops.aten_index_select), TorchLibOpInfo("isclose", core_ops.aten_isclose), TorchLibOpInfo("isfinite", core_ops.aten_isfinite), From d87d7f1c0740bcf02c386388d94fab9cac871ea4 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 26 Mar 2024 08:43:25 +0100 Subject: [PATCH 4/5] add case --- onnxscript/tests/function_libs/torch_lib/extra_opinfo.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 2ed8136e87..3f8c148822 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -1963,6 +1963,13 @@ def __init__(self): sample_inputs_func=sample_inputs_index_put, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten._unsafe_index_put", + aten_name="_unsafe_index_put", + dtypes=common_dtype.floating_types(), + sample_inputs_func=sample_inputs_index_put, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.layer_norm", aten_name="layer_norm", From 6e271bb6729d29c7ee3e3ad716fcb90910c34896 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 26 Mar 2024 06:25:12 -0700 Subject: [PATCH 5/5] Update onnxscript/tests/function_libs/torch_lib/ops_test_data.py --- onnxscript/tests/function_libs/torch_lib/ops_test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 54bccc1423..2b6c13df7e 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -870,7 +870,7 @@ def _where_input_wrangler( reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'", ), TorchLibOpInfo("ops.aten.index_put", core_ops.aten_index_put), - TorchLibOpInfo("ops.aten._unsafe_index_put", core_ops.aten__unsafe_index_put), + TorchLibOpInfo("ops.aten._unsafe_index_put", core_ops.aten_index_put), TorchLibOpInfo("index_select", core_ops.aten_index_select), TorchLibOpInfo("isclose", core_ops.aten_isclose), TorchLibOpInfo("isfinite", core_ops.aten_isfinite),