Skip to content

Commit fbfcf66

Browse files
committed
wip
1 parent 576baac commit fbfcf66

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3026,10 +3026,15 @@ def aten_imag(self: TensorType) -> TensorType:
30263026
raise NotImplementedError()
30273027

30283028

3029+
@torch_op("aten::index.Tensor")
30293030
def aten_index(self: TTensor, indices: Sequence[INT64]) -> TTensor:
30303031
"""index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"""
30313032

3032-
return op.Gather(self, indices)
3033+
result = self
3034+
for i in range(op.SequenceLength(indices)):
3035+
result = op.Gather(result, op.SequenceAt(indices, i))
3036+
3037+
return result
30333038

30343039

30353040
def aten_index_add(

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,22 @@ def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
446446
yield opinfo_core.SampleInput(tensor, args=(output_size, kernel_size), kwargs=kwargs)
447447

448448

449+
def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
450+
del op_info # Unused
451+
del kwargs # Unused
452+
make_arg = functools.partial(
453+
torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
454+
)
455+
s = 5
456+
test_args = [
457+
([common_methods_invocations.index_variable(2, s, device=device)],),
458+
# ([torch.tensor()],)
459+
]
460+
461+
for args in test_args:
462+
yield opinfo_core.SampleInput(make_arg((s, s, s)), args=args)
463+
464+
449465
def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs):
450466
del op_info
451467
del kwargs
@@ -581,9 +597,17 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra
581597
skips=(),
582598
supports_out=False,
583599
),
600+
opinfo_core.OpInfo(
601+
"aten.index.Tensor",
602+
dtypes=common_dtype.all_types_and_complex_and(
603+
torch.bool, torch.float16, torch.bfloat16, torch.chalf
604+
),
605+
aten_name="index",
606+
op=torch.ops.aten.index.Tensor,
607+
sample_inputs_func=sample_inputs_index,
608+
),
584609
opinfo_core.OpInfo(
585610
"layer_norm",
586-
aliases=("layer_norm",),
587611
aten_name="layer_norm",
588612
dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16),
589613
sample_inputs_func=sample_inputs_layer_norm,

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,7 @@ def _where_input_wrangler(
601601
TorchLibOpInfo("gt", core_ops.aten_gt),
602602
# TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB
603603
# TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB
604+
TorchLibOpInfo("aten.index.Tensor", core_ops.aten_index),
604605
TorchLibOpInfo(
605606
"index_put_bool",
606607
core_ops.aten_index_put_bool,

0 commit comments

Comments
 (0)