-
Notifications
You must be signed in to change notification settings - Fork 364
feat: support aten.index_select converter #2710
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
elementwise, | ||
embedding, | ||
grid, | ||
index, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can likely be removed - it seems to be causing a circular import error in CI There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! It seems I overlooked removing an unnecessary import. |
||
linear, | ||
matmul, | ||
normalization, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from typing import Optional | ||
|
||
from torch.fx.node import Target | ||
from torch_tensorrt.dynamo._SourceIR import SourceIR | ||
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext | ||
from torch_tensorrt.fx.converters.converter_utils import set_layer_name | ||
from torch_tensorrt.fx.types import TRTTensor | ||
|
||
|
||
def index_select( | ||
ctx: ConversionContext, | ||
target: Target, | ||
source_ir: Optional[SourceIR], | ||
name: str, | ||
input: TRTTensor, | ||
dim: int, | ||
index: TRTTensor, | ||
) -> TRTTensor: | ||
# The axis parameter specifies the dimension along which to index. | ||
gather_layer = ctx.net.add_gather(input, index, axis=dim) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have modified it. Thanks! |
||
|
||
set_layer_name(gather_layer, target, f"{name}_gather", source_ir) | ||
|
||
return gather_layer.get_output(0) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import torch | ||
import torch.nn as nn | ||
from harness import DispatchTestCase | ||
from parameterized import param, parameterized | ||
from torch.testing._internal.common_utils import run_tests | ||
from torch_tensorrt import Input | ||
|
||
|
||
class TestIndexSelectConverter(DispatchTestCase): | ||
@parameterized.expand( | ||
[ | ||
("1d_input", (10,), 0, (1,)), | ||
("2d_input_dim_0", (10, 3), 0, (0, 2)), | ||
("2d_input_dim_1", (5, 10), 1, (1, 2, 3)), | ||
("3d_input_dim_0", (10, 5, 10), 0, (0, 5)), | ||
("3d_input_dim_2", (10, 5, 10), 2, (3, 3, 4)), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a test case for a negative There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have added a test case for a negative |
||
] | ||
) | ||
def test_index_select(self, _, source_shape, dim, indices_val): | ||
class TestIndexSelect(torch.nn.Module): | ||
def forward(self, source_tensor, indices_tensor): | ||
return torch.ops.aten.index_select.default( | ||
source_tensor, dim, indices_tensor | ||
) | ||
|
||
input = [ | ||
torch.randn(*source_shape, dtype=torch.float32), | ||
torch.tensor([*indices_val], dtype=torch.int32), | ||
] | ||
|
||
self.run_test( | ||
TestIndexSelect(), | ||
input, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that the
index_select
function could be put intoselect.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I moved
index_select
insideselect.py
. Thank you!