Implement aten::index | feat(torchlib) (#862) #883
Changes from all commits
5bde197
3ff6ee2
4e54582
ce12f0e
ebbb1cf
5706948
7eadba1
bda7e1c
7c26905
64ed9f9
8826d63
1c037f4
cfa360b
169debc
9d941c5
@@ -444,6 +444,55 @@ def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
    yield opinfo_core.SampleInput(tensor, args=(output_size, kernel_size), kwargs=kwargs)


def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
    del op_info  # Unused
    del kwargs  # Unused
    make_arg = functools.partial(
        torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    s = 5
    index_1d = common_methods_invocations.index_variable(2, s, device=device)
    index_2d = common_methods_invocations.index_variable((s + 1, 2), s, device=device)
    index_3d = common_methods_invocations.index_variable((s + 2, s + 1, 2), s, device=device)
    test_args = [
Review comment (truncated): could
Reply: As above. Turns out some combinations are invalid to torch and it may be better to specify explicitly.
        ([index_1d],),
        ([None, index_1d],),
        ([None, None, None, index_1d],),
        ([index_1d, None],),
        ([index_1d, None, None],),
        # Extra index
        ([None, index_1d, None, index_1d],),
        ([index_1d, None, index_1d, None],),
        ([None, index_1d, index_1d, None],),
        ([index_2d],),
        ([None, index_2d],),
        ([None, None, None, index_2d],),
        ([index_2d, None],),
        ([index_2d, None, None],),
        # Extra index
        ([None, index_2d, None, index_2d],),
        ([index_2d, None, index_2d, None],),
        ([None, index_2d, index_2d, None],),
        ([index_3d],),
        ([None, index_3d],),
        ([None, None, None, index_3d],),
        ([index_3d, None],),
        ([index_3d, None, None],),
        # Extra index
        ([None, index_3d, None, index_3d],),
        ([index_3d, None, index_3d, None],),
        ([None, index_3d, index_3d, None],),
        # Mixed indices
        ([None, index_3d, index_1d, index_2d],),
        # All indices are not None
        ([index_2d, index_3d, index_1d],),
        ([index_2d, index_3d, index_1d, index_2d],),
    ]

    for args in test_args:
        yield opinfo_core.SampleInput(make_arg((s, s, s, s)), args=args)


def sample_inputs_native_dropout(
    op_info, device, dtype, requires_grad, *, valid_input_dim=None, **kwargs
):
@@ -616,6 +665,14 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_grad
        sample_inputs_func=sample_inputs_convolution,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten.index.Tensor",
        aten_name="index.Tensor",
        dtypes=common_dtype.all_types_and_complex_and(
            torch.bool, torch.float16, torch.bfloat16, torch.chalf
        ),
        sample_inputs_func=sample_inputs_index,
    ),
    opinfo_core.OpInfo(
        "ops.aten.layer_norm",
        aten_name="layer_norm",
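For context (not part of the diff): the OpInfo registered above feeds each sample as a 4-D tensor plus one list of optional index tensors, which is how aten::index.Tensor is called. A minimal sketch of that calling convention with illustrative shapes, assuming only stock PyTorch:

```python
import torch

# aten::index.Tensor takes `self` plus a list of optional index tensors,
# one slot per leading dimension; None means "keep that dimension whole",
# mirroring advanced indexing such as x[:, idx].
x = torch.arange(3 * 4 * 5).reshape(3, 4, 5)
idx = torch.tensor([0, 2])

out = torch.ops.aten.index.Tensor(x, [None, idx])
assert out.shape == (3, 2, 5)
assert torch.equal(out, x[:, idx])
```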
@@ -261,7 +261,8 @@ def convert_tensor_to_numpy(input: Any) -> Any:
    if isinstance(input, (tuple, list)):
        if len(input) == 0:
            return np.array((), dtype=np.int64)
        if isinstance(input[0], torch.Tensor):
        if any(isinstance(x, torch.Tensor) for x in input):
            # The list can be Optional[Tensor], e.g. [None, Tensor, None] etc.
            return [convert_tensor_to_numpy(x) for x in input]
        if isinstance(input[0], bool):
            return np.array(input, dtype=np.bool_)
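A quick illustration (not from the PR) of why the check above moved from input[0] to any(...): a leading None would otherwise skip the per-element conversion. The example list and the final comprehension only sketch the intended result:

```python
import numpy as np
import torch

indices = [None, torch.tensor([0, 2]), None]

# Old check: only the first element is inspected, so this list falls through.
assert not isinstance(indices[0], torch.Tensor)
# New check: any Tensor anywhere in the list triggers element-wise conversion.
assert any(isinstance(x, torch.Tensor) for x in indices)

# Sketch of the converted result: Tensors become ndarrays, None stays None.
converted = [x.numpy() if isinstance(x, torch.Tensor) else x for x in indices]
assert converted[0] is None and isinstance(converted[1], np.ndarray)
```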
@@ -276,10 +277,7 @@ def convert_tensor_to_numpy(input: Any) -> Any:


def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Converts kwargs to be compatible with ONNX Runtime.

    ONNX Runtime doesn't support torch.bool, so we convert them to torch.uint8.
justinchuby marked this conversation as resolved.
Review comment: for my knowledge, does ort support torch.bool now? or was this docstring outdated already?
Reply: I actually don't know if ORT supports bool (it should?), but I think this message was a mistake by copilot because we don't actually have this conversion logic as code. If we see issues with ORT I will make adjustments.
    """
    """Converts kwargs to be compatible with ONNX Runtime."""
    new_kwargs = {}
    for key, value in kwargs.items():
        if key == "device":
@@ -473,6 +471,10 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
                    input.value = subarg
                    sequence_input.append(input)
                    ort_inputs[input_name] = subarg
                else:
                    # Include non-numpy inputs as-is
                    # For example, it could be a None value that we want to keep
                    sequence_input.append(subarg)
Review comment: Let's put comment here explaining why this is needed in case we forget?
Reply: Done
Review comment: Should we consider elif None? It's easier to catch what we are not expecting.
Reply: Good point. I wonder if there are things we don't expect that can sneak in? Nested lists?
Reply: No idea. Just in case.
Reply: SG. Created #892 for now.
            onnxscript_args.append(sequence_input)
        else:
            onnxscript_args.append(arg)
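A sketch of the stricter variant discussed in the thread above (tracked in #892), not what this PR implements; the function name and the simplified ndarray branch are placeholders standing in for the surrounding loop body:

```python
import numpy as np


def append_sequence_element(sequence_input: list, subarg) -> None:
    """Stricter sequence-element handling, per the review suggestion."""
    if isinstance(subarg, np.ndarray):
        # Placeholder for the existing graph-input path in the diff above.
        sequence_input.append(subarg)
    elif subarg is None:
        # Keep None placeholders so Optional[Tensor] lists stay positionally aligned.
        sequence_input.append(subarg)
    else:
        # Fail loudly on anything unexpected, e.g. nested lists.
        raise TypeError(f"Unexpected sequence element type: {type(subarg)}")
```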
@@ -515,7 +517,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
    # Make sure the model is valid
    try:
        onnx.checker.check_model(onnx_model, full_check=True)
    except onnx.checker.ValidationError as e:
    except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
        raise AssertionError(
            f"ONNX model is invalid. Model:\n{onnxscript.proto2text(onnx_model)}"
        ) from e
Review comment: would itertools.{product,permutation,combination} help in ensuring all combinations are covered and make the code shorter?
Reply: Good point. Let me try that.
Reply: So I tested with itertools but realized some combinations are invalid so we cannot enumerate them all using itertools. For the sake of clarity I propose that we keep the current explicit tests. I also added more test cases and comments.
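For reference, a sketch of the itertools-based enumeration that was considered (the index names are placeholders for the real tensors). It shows how quickly the combination space grows, and, as noted above, some generated combinations are not valid inputs to torch, which is why the explicit test_args list was kept:

```python
import itertools

index_tensors = ["index_1d", "index_2d", "index_3d"]  # placeholders

# Every placement of an index tensor (or None) across four slots.
candidate_args = [
    list(combo)
    for combo in itertools.product([None, *index_tensors], repeat=4)
    if any(c is not None for c in combo)
]
print(len(candidate_args))  # 4**4 - 1 = 255 candidates, not all valid for torch
```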