
feat: support aten.index_select converter #2710

Merged 4 commits on Apr 12, 2024
Changes from 2 commits
25 changes: 25 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2782,3 +2782,28 @@ def aten_ops_roll(
args[1],
args_bounds_check(args, 2, []),
)


@dynamo_tensorrt_converter(torch.ops.aten.index_select.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
2: (TRTTensor,),
}
)
def aten_ops_index_select(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.index.index_select(
Collaborator: It seems that the index_select function could be put into select.py

Author: I moved index_select inside select.py. Thank you!

ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
args[2],
)
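For context (not part of the diff), a quick eager-mode reference for what aten.index_select.default computes, matching the argument order the converter forwards (args[0] = input, args[1] = dim, args[2] = index):

```python
import torch

# Eager reference: gather slices of `x` along `dim` at the positions in `idx`.
x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
idx = torch.tensor([0, 2], dtype=torch.int32)
out = torch.ops.aten.index_select.default(x, 0, idx)  # rows 0 and 2 of x
print(out.shape)  # torch.Size([2, 4])
```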
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -12,6 +12,7 @@
elementwise,
embedding,
grid,
index,
Collaborator: This can likely be removed - it seems to be causing a circular import error in CI

Author: Thanks! It seems I overlooked removing an unnecessary import.

linear,
matmul,
normalization,
24 changes: 24 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/index.py
@@ -0,0 +1,24 @@
from typing import Optional

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor


def index_select(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: int,
index: TRTTensor,
) -> TRTTensor:
# The axis parameter specifies the dimension along which to index.
gather_layer = ctx.net.add_gather(input, index, axis=dim)
Collaborator: dim likely needs to be corrected using get_positive_dim to ensure the value is positive for add_gather

Author: I have modified it. Thanks!


set_layer_name(gather_layer, target, f"{name}_gather", source_ir)

return gather_layer.get_output(0)
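A minimal sketch of the fix the reviewer suggests (the actual follow-up commit, which also moves the function into select.py, is not shown in this diff view), assuming get_positive_dim is importable from the dynamo converter utils:

```python
from typing import Optional

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor


def index_select(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: int,
    index: TRTTensor,
) -> TRTTensor:
    # add_gather expects a non-negative axis, so map e.g. dim=-1 to rank-1 first.
    dim = get_positive_dim(dim, len(input.shape))

    gather_layer = ctx.net.add_gather(input, index, axis=dim)
    set_layer_name(gather_layer, target, f"{name}_gather", source_ir)
    return gather_layer.get_output(0)
```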
38 changes: 38 additions & 0 deletions tests/py/dynamo/conversion/test_index_select_aten.py
@@ -0,0 +1,38 @@
import torch
import torch.nn as nn
from harness import DispatchTestCase
from parameterized import param, parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input


class TestIndexSelectConverter(DispatchTestCase):
@parameterized.expand(
[
("1d_input", (10,), 0, (1,)),
("2d_input_dim_0", (10, 3), 0, (0, 2)),
("2d_input_dim_1", (5, 10), 1, (1, 2, 3)),
("3d_input_dim_0", (10, 5, 10), 0, (0, 5)),
("3d_input_dim_2", (10, 5, 10), 2, (3, 3, 4)),
Collaborator: Add a test case for a negative dim input

Author: I have added a test case for a negative dim input and verified it. Thank you!

]
)
def test_index_select(self, _, source_shape, dim, indices_val):
class TestIndexSelect(torch.nn.Module):
def forward(self, source_tensor, indices_tensor):
return torch.ops.aten.index_select.default(
source_tensor, dim, indices_tensor
)

input = [
torch.randn(*source_shape, dtype=torch.float32),
torch.tensor([*indices_val], dtype=torch.int32),
]

self.run_test(
TestIndexSelect(),
input,
)


if __name__ == "__main__":
run_tests()
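As a side note on the negative-dim case requested above: the entry added in the later commit is not visible in this diff, but something like ("3d_input_dim_neg", (10, 5, 10), -1, (3, 3, 4)) would be a plausible parameterization. Plain PyTorch already treats a negative dim as counting from the last dimension, which is the behavior the converter test needs to match:

```python
import torch

# Sanity check in eager PyTorch (outside the TRT test harness): a negative dim
# selects along the same axis as its positive counterpart.
x = torch.randn(10, 5, 10)
idx = torch.tensor([3, 3, 4], dtype=torch.int32)
assert torch.equal(
    torch.ops.aten.index_select.default(x, -1, idx),
    torch.ops.aten.index_select.default(x, 2, idx),
)
```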