Eliminate unnecessary ScatterND #2422

Merged: 4 commits merged into main from rama/scatternd on Jun 30, 2025

Conversation

gramalingam (Collaborator)

Identify ScatterND(data, indices, updates) ops that can be replaced by Identity(updates). This pattern is generated by the translation of x[:, ...] = y in PyTorch. Specifically, the scatter indices take the form [[0], ..., [S-1]] along the first dimension, where S is the size of the first dimension of the data tensor being updated. In effect, the scatter-update ends up assigning a new value to the entire tensor.
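
For illustration (not the rewriter code itself), a minimal NumPy sketch of the ScatterND semantics, ignoring the reduction attribute, shows why indices of this form make the output identical to updates:

```python
import numpy as np

def scatter_nd(data, indices, updates):
    # Reference semantics of ONNX ScatterND (no reduction): copy `data`,
    # then write each `updates` slice at the position given by the
    # corresponding row of `indices`.
    output = np.copy(data)
    for i in np.ndindex(indices.shape[:-1]):
        output[tuple(indices[i])] = updates[i]
    return output

data = np.zeros((4, 3), dtype=np.float32)
updates = np.arange(12, dtype=np.float32).reshape(4, 3)
# Indices [[0], [1], ..., [S-1]] cover every position along the first
# dimension, so the scatter rewrites the whole tensor: the result is
# exactly `updates`, and the ScatterND can become Identity(updates).
indices = np.arange(4).reshape(4, 1)

assert np.array_equal(scatter_nd(data, indices, updates), updates)
```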

Signed-off-by: Ganesan Ramalingam <[email protected]>
Signed-off-by: Ganesan Ramalingam <[email protected]>

codecov bot commented Jun 28, 2025

❌ 10 Tests Failed:

| Tests completed | Failed | Passed | Skipped |
| --- | --- | --- | --- |
| 13236 | 10 | 13226 | 2528 |
View the top 3 failed test(s) by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0316_test_ceil_example
Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_ceil_example'

The above exception was the direct cause of the following exception:
.nox\test_torch_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_ceil_example' (e=No module named 'tests.onnx_backend_test_code.test_ceil_example') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_ceil_example.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_ceil_example.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset13
E   
E   @script()
E   def bck_test_ceil_example(x: FLOAT[2]) -> (FLOAT[2]):
E       y = opset13.Ceil(x)
E       return y
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0743_test_max_int64
Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_max_int64'

The above exception was the direct cause of the following exception:
.nox\test_torch_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_max_int64' (e=No module named 'tests.onnx_backend_test_code.test_max_int64') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_max_int64.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_max_int64.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import INT64
E   from onnxscript.onnx_opset import opset13
E   
E   @script()
E   def bck_test_max_int64(data_0: INT64[3], data_1: INT64[3]) -> (INT64[3]):
E       result = opset13.Max(data_0, data_1)
E       return result
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1392_test_tile_precomputed
Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_tile_precomputed'

The above exception was the direct cause of the following exception:
.nox\test_torch_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_tile_precomputed' (e=No module named 'tests.onnx_backend_test_code.test_tile_precomputed') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_tile_precomputed.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_tile_precomputed.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset13
E   
E   @script()
E   def bck_test_tile_precomputed(x: FLOAT[2,2], y: INT64[2]) -> (FLOAT[4,4]):
E       z = opset13.Tile(x, y)
E       return z


@gramalingam (Collaborator, Author)

Can the torchlib implementation of slice_scatter be optimized to do the same, I wonder? I thought this part was potentially inside the PyTorch decomposition of x[:, :, :, s] = ..., but perhaps it can be done in slice_scatter itself?

@justinchuby (Collaborator)

> Can the torchlib implementation of slice_scatter be optimized to do the same, I wonder? I thought this part was potentially inside the PyTorch decomposition of x[:, :, :, s] = ..., but perhaps it can be done in slice_scatter itself?

Definitely!

@justinchuby (Collaborator) commented Jun 28, 2025

The torchlib implementation also seems off: `SymInt? start=None, SymInt? end=None` means they can be None even though the comment says otherwise. We should improve torchlib for this op. Related: #2372

@gramalingam (Collaborator, Author)

> The torchlib implementation also seems off: `SymInt? start=None, SymInt? end=None` means they can be None even though the comment says otherwise. We should improve torchlib for this op. Related: #2372

Right. I was puzzled by the comments. Not entirely sure what's happening there.

Signed-off-by: Ganesan Ramalingam <[email protected]>
Signed-off-by: Ganesan Ramalingam <[email protected]>
@titaiwangms titaiwangms self-requested a review June 30, 2025 17:18
@gramalingam gramalingam enabled auto-merge (squash) June 30, 2025 17:19
shape = op.Shape(data, start=0)                 # full shape of `data`
dim = op.Gather(shape, axis, axis=0)            # size of `data` along `axis`
full_range = op.Range(0, dim, 1)                # [0, 1, ..., dim-1]
full_range_2d = op.Unsqueeze(full_range, [-1])  # [[0], [1], ..., [dim-1]]
Contributor:

I don't see these ops in the repro: pytorch/pytorch#157289

Contributor:

Maybe you can delete:

remove_redundant_scatternd = pattern.RewriteRule(

Collaborator (Author):

I think we can consolidate the rules separately. (I am thinking of trying out Copilot to do it.)

Collaborator (Author):

May need to make the other dimensions symbolic. Otherwise, all of these ops will be constant-folded and the indices become a constant.
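
As a hypothetical sketch of that point (not the actual test from this PR), declaring a symbolic string dimension with `onnx.helper` keeps the index-building ops from being folded away:

```python
from onnx import TensorProto, helper

# Hypothetical sketch: with the symbolic batch dim "N", the Shape -> Gather ->
# Range -> Unsqueeze chain that builds the indices cannot be constant-folded,
# so the ScatterND pattern remains visible to the rewriter. With all-static
# dims, the optimizer would fold the indices into a constant initializer.
data = helper.make_tensor_value_info("data", TensorProto.FLOAT, ["N", 8, 16])
updates = helper.make_tensor_value_info("updates", TensorProto.FLOAT, ["N", 8, 16])
```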

@gramalingam gramalingam merged commit ff0a132 into main Jun 30, 2025
25 of 32 checks passed
@gramalingam gramalingam deleted the rama/scatternd branch June 30, 2025 19:24