feat: dynamic shape support for aten.select.int #2990
Changes from 2 commits
@@ -47,26 +47,16 @@ def select(
    if dynamic_shape:
        # Check whether slice target dim is dynamic shape dim
        assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
    index = index

    if index >= input.shape[dim]:
        raise RuntimeError(
            f"cannot have index greater than the dimension length! {input.shape[dim]}"
        )
    output_shape = list(input.shape)
    output_shape[dim] = 1
    if dynamic_shape > 0:
        output_shape = get_shape_with_dynamic_shape(
            ctx, target, source_ir, name, output_shape, input
        )

    index_value = np.array(index, dtype=np.int32)
Can `index` itself be dynamic (ITensor)?

Based on the PyTorch docs and the schema, the …
    indices_tensor = ctx.net.add_constant(
        index_value.shape, to_numpy(index_value)
    ).get_output(0)
    indices_tensor = ctx.net.add_constant(index_value.shape, index_value).get_output(0)
Can you use a `get_trt_tensor` call here?

Thanks for the suggestion!
    layer = ctx.net.add_gather(input, indices_tensor, dim)
    out = layer.get_output(0)
    if len(out.shape) != 1:
        layer = ctx.net.add_shuffle(out)

    return layer.get_output(0)
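Regarding the `get_trt_tensor` suggestion in the inline comment above, here is a minimal sketch of what that variant might look like. It assumes `get_trt_tensor` from `torch_tensorrt.dynamo.conversion.converter_utils` with a `(ctx, value, name)` signature; `build_select_indices` is an illustrative name, not a helper in the repo, and `ctx`, `input`, `dim`, `index`, and `name` stand in for the converter's own arguments.

```python
import numpy as np

from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor


def build_select_indices(ctx, input, dim, index, name):
    # Sketch: wrap the scalar index with get_trt_tensor instead of calling
    # ctx.net.add_constant directly; get_trt_tensor returns a TensorRT
    # ITensor (and would pass an ITensor argument through unchanged).
    index_value = np.array(index, dtype=np.int32)
    indices_tensor = get_trt_tensor(ctx, index_value, f"{name}_index")
    return ctx.net.add_gather(input, indices_tensor, dim)
```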
If the above asserts() are removed and a fully dynamic shape is used (e.g. (-1, -1, -1)), the test worked.
I'm wondering if select on a dynamic dim can be supported.
I have tested the test you mentioned and confirmed that using -1, -1, -1 also passes the test cases successfully.

However, I didn't remove the asserts() because the test cases fail when the `index` is out of range, meaning the `index` is larger than the corresponding dimension in the dynamic input shape. It seems that modifying the `index` the way the dimension is modified (`dim = get_positive_dim(cast(int, dim), ranks)`) when the `index` is greater than the size would solve the issue. Do you have an example that handles it this way?
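For context, the normalization that `get_positive_dim` applies to `dim` is roughly the following; the helper name and body here are a simplified sketch, not the repo's exact implementation. Whether an out-of-range `index` should be wrapped the same way is exactly the open question in this comment.

```python
def get_positive_dim_sketch(dim: int, rank: int) -> int:
    # Wrap a possibly negative axis (e.g. -1 for the last axis) into the
    # valid range [0, rank), mirroring what get_positive_dim does for dim.
    return dim % rank if dim < 0 else dim


# Example: with a rank-3 input, dim=-1 becomes 2.
assert get_positive_dim_sketch(-1, 3) == 2
```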
I think we cannot check an invalid index for a dynamic dim; the error will happen at runtime.
Maybe we can check the index only for static dims:
`if DYNAMIC_DIM != input.shape[dim] and index >= input.shape[dim]:`
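A minimal sketch of that suggested guard inside the converter, assuming `DYNAMIC_DIM` is the sentinel used for dynamic dimensions (typically -1):

```python
# Validate the index only when the target dimension is static; for a
# dynamic dimension its size is unknown at conversion time, so an
# out-of-range index can only surface at runtime.
if DYNAMIC_DIM != input.shape[dim] and index >= input.shape[dim]:
    raise RuntimeError(
        f"cannot have index greater than the dimension length! {input.shape[dim]}"
    )
```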
I also tested by removing the `assert()` and `raise RuntimeError()` statements and tested with (-1, -1, -1). The test case mentioned as a success case above passes, but the test case mentioned as a fail case fails. The reason is that in the fail case, the size of the 0-th dimension is 3, but the `index` selects a position beyond that (`index` = 3, i.e. the 4th position). The current converter does not handle the case where the `index` is larger than the size of the dimension specified by `dim`, as in the fail case. (Note: it might be possible to handle this using the slice layer, lt (less than), and div operations, but currently it is handled with an assert statement.) Therefore, I left the `assert()` and `raise RuntimeError()` statements in place.
Oh! I misunderstood. If the `index` is larger than the size of the dimension specified by `dim`, it won't work in PyTorch either, so we don't need to handle this case. Therefore, we don't need to consider test cases like the fail case mentioned above. This means that a dynamic shape is supported for all dimensions.

PyTorch example of correct usage of select.int
PyTorch example of an IndexError occurring with select.int
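The two collapsed examples referenced above correspond to standard `torch.select` behavior; a short sketch (not the exact snippets from the comment):

```python
import torch

x = torch.randn(3, 4, 5)

# Correct usage: index 2 is within the size (3) of dimension 0,
# so this returns the 4x5 slice at that position.
ok = torch.select(x, 0, 2)
print(ok.shape)  # torch.Size([4, 5])

# IndexError case: index 3 is out of range for dimension 0 of size 3,
# so PyTorch itself raises before any converter logic is reached.
try:
    torch.select(x, 0, 3)
except IndexError as err:
    print(err)
```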
@keehyuna Thanks for the suggestion!
Since PyTorch already raises an error when the index exceeds the input size, there’s no need for us to check this here. I removed those checks. Additionally, following your suggestion, I’ve added test cases to fully support dynamic shapes.
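A rough sketch of what such a dynamic-shape test case might look like. The `DispatchTestCase` harness, its `run_test_with_dynamic_shape` helper, and the import paths are assumptions modeled on the existing conversion tests, not the exact code added in this PR:

```python
import torch
import torch_tensorrt
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase  # assumed harness location in the conversion tests


class TestSelectDynamicShape(DispatchTestCase):
    def test_select_all_dims_dynamic(self):
        class SelectModule(torch.nn.Module):
            def forward(self, x):
                # Lowers to aten.select.int on dim 0 with index 1.
                return torch.select(x, 0, 1)

        # Every dimension is allowed to vary between min_shape and max_shape.
        input_specs = [
            torch_tensorrt.Input(
                min_shape=(2, 2, 2),
                opt_shape=(3, 3, 3),
                max_shape=(5, 5, 5),
                dtype=torch.float32,
            ),
        ]
        self.run_test_with_dynamic_shape(SelectModule(), input_specs)


if __name__ == "__main__":
    run_tests()
```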