Skip to content

Commit e990f23

Browse files
Add ops(as_strided) | feat(atenlib) (#545)
Co-authored-by: Justin Chu <[email protected]>
1 parent 85fee22 commit e990f23

File tree

2 files changed

+71
-3
lines changed

2 files changed

+71
-3
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,12 +503,74 @@ def aten_argwhere(self: TensorType) -> TensorType:
503503
raise NotImplementedError()
504504

505505

506+
@torch_op("aten::as_strided", trace_only=True)
506507
def aten_as_strided(
507-
self: TensorType, size: INT64, stride: INT64, storage_offset: Optional[INT64] = None
508-
) -> TensorType:
508+
self: TTensor, size: INT64, stride: INT64, storage_offset: int = 0
509+
) -> TTensor:
509510
"""as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)"""
510511

511-
raise NotImplementedError()
512+
rank = len(stride)
513+
return _aten_as_strided_onnx(self, size, stride, storage_offset, rank)
514+
515+
516+
@torch_op("aten::as_strided", private=True)
517+
def _aten_as_strided_onnx(
518+
self: TTensor, size: INT64, stride: INT64, storage_offset: int = 0, rank: int = 0
519+
) -> TTensor:
520+
# e.g. when size=[2,3,4], stride=[2,1,3], indices=[0]
521+
# i = 0
522+
# indices=[0], add_value=[0,3,6,9]
523+
# expand(shape=[4]) to [0,0,0,0]
524+
# then + add_value = [0,3,6,9]
525+
# i = 1
526+
# indices=[0,3,6,9], add_value=[0,1,2]
527+
# expand(shape=[3,4] to [[0,3,6,9],[0,3,6,9],[0,3,6,9]]
528+
# indices + add_value = [[0,3,6,9],[1,3,7,10],[2,5,8,11]]
529+
# i = 2
530+
# indices = [[0,3,6,9],[1,3,7,10],[2,5,8,11]], add_value=[0,2]
531+
# expand(shape=[2,3,4]) to [[[0,3,6,9],[1,3,7,10],[2,5,8,11]]],[[0,3,6,9],[1,3,7,10],[2,5,8,11]]]
532+
# indices + add_value = [[[0,3,6,9],[1,3,7,10],[2,5,8,11]]],[[2,5,8,11],[3,5,9,12],[4,7,10,13]]]
533+
neg_1 = op.Constant(value_ints=[-1])
534+
rank_tensor = op.Reshape(rank, neg_1) # should be 3
535+
# The final indices for op.Gather(data, indices), will be continually changed during the loop
536+
indices = op.Constant(value_int=0)
537+
one_seq = op.SequenceEmpty()
538+
for i in range(rank):
539+
# Get the index from back to front, should be 2,1,0 when to i=0,1,2
540+
j = rank - i - 1
541+
j_tensor = op.Reshape(j, neg_1)
542+
# Get size according to index_j, should be 4,3,2 when i=0,1,2
543+
size_dim_j = op.Gather(size, j_tensor, axis=0)
544+
# Get right size according to index_j, should be [4],[3,4],[2,3,4] when i=0,1,2
545+
size_after_j = op.Slice(size, j_tensor, rank_tensor)
546+
# Get stride according to index_j, should be 3,1,2 when i=0,1,2
547+
stride_dim_j = op.Gather(stride, j_tensor, axis=0)
548+
indices = op.Expand(indices, size_after_j)
549+
# When size[j]=4, stride[j]=3, then add_value = [0,1,2,3] * 3 = [0,3,6,9]
550+
# When size[j]=3, stride[j]=1, then add_value = [0,1,2] * 1 = [0,1,2]
551+
# When size[j]=2, stride[j]=2, then add_value = [0,1] * 2 = [0,2]
552+
add_value = op.Range(0, size_dim_j, 1) * stride_dim_j
553+
# Compute the shape for add_value for correct broadcasting
554+
if i == 0:
555+
# shape = [dim_size]
556+
shape = size_dim_j
557+
else:
558+
# shape = [dim_size, 1, 1, ...], the count of 1 euqal to i
559+
ones = op.ConcatFromSequence(one_seq, axis=0)
560+
shape = op.Concat(op.Cast(size_dim_j, to=FLOAT.dtype), ones, axis=0)
561+
shape = op.Cast(shape, to=INT64.dtype)
562+
563+
add_value = op.Reshape(add_value, shape)
564+
# Broadcasting add value to indices according to size and stride value
565+
indices = indices + add_value
566+
# Dims after dim_size to reshape(add_value), should be [1],[1,1],[1,1,1] when i=0,1,2
567+
one_seq = op.SequenceInsert(one_seq, op.Constant(value_floats=[1.0]))
568+
569+
self_flatten = op.Reshape(self, op.Constant(value_ints=[-1]))
570+
indices = op.Add(indices, storage_offset)
571+
result = op.Gather(self_flatten, indices)
572+
573+
return result
512574

513575

514576
def aten_as_strided_copy(

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,7 @@ def _where_input_wrangler(
675675
"arange": core_ops.aten_arange,
676676
"argmax": core_ops.aten_argmax,
677677
"argmin": core_ops.aten_argmin,
678+
"as_strided": core_ops.aten_as_strided,
678679
"clamp": core_ops.aten_clamp,
679680
"col2im": nn_ops.aten_col2im,
680681
"cumsum": core_ops.aten_cumsum,
@@ -751,6 +752,11 @@ def _where_input_wrangler(
751752
reason="fixme: ORT shape inference error",
752753
test_class_name="TestOutputConsistencyFullGraph",
753754
),
755+
xfail(
756+
"as_strided",
757+
variant_name="partial_views",
758+
reason="ONNX doesn't have partial view for tensor",
759+
),
754760
xfail(
755761
"chunk", reason="fixme: ORT error", test_class_name="TestOutputConsistencyFullGraph"
756762
),

0 commit comments

Comments
 (0)