Skip to content

Commit 5706948

Browse files
committed
1 parent ebbb1cf commit 5706948

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,8 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
452452
)
453453
s = 5
454454
index_1d = common_methods_invocations.index_variable(2, s, device=device)
455-
index_2d = common_methods_invocations.index_variable((2, s+1), s, device=device)
456-
index_3d = common_methods_invocations.index_variable((2, s+1, s+2), s, device=device)
455+
index_2d = common_methods_invocations.index_variable((s + 1, 2), s, device=device)
456+
index_3d = common_methods_invocations.index_variable((s + 2, s + 1, 2), s, device=device)
457457
test_args = [
458458
([index_1d],),
459459
([None, index_1d],),
@@ -463,6 +463,24 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
463463
([None, index_1d, None, index_1d],),
464464
([index_1d, None, index_1d, None],),
465465
([None, index_1d, index_1d, None],),
466+
([index_2d],),
467+
([None, index_2d],),
468+
([None, None, None, index_2d],),
469+
([index_2d, None],),
470+
([index_2d, None, None],),
471+
([None, index_2d, None, index_2d],),
472+
([index_2d, None, index_2d, None],),
473+
([None, index_2d, index_2d, None],),
474+
([index_3d],),
475+
([None, index_3d],),
476+
([None, None, None, index_3d],),
477+
([index_3d, None],),
478+
([index_3d, None, None],),
479+
([None, index_3d, None, index_3d],),
480+
([index_3d, None, index_3d, None],),
481+
([None, index_3d, index_3d, None],),
482+
# Mixed indices
483+
([None, index_3d, index_1d, index_2d],),
466484
]
467485

468486
for args in test_args:

0 commit comments

Comments
 (0)