Skip to content

Commit 0844d4d

Browse files
[TOSA] Cumsum fix init order (#4407)
1 parent 0b00ae8 commit 0844d4d

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ static Value emitInclusiveScanByPowersOfTwo(Value running,
6868

6969
SmallVector<int64_t, 3> sliceStart(3, 0);
7070
SmallVector<int64_t, 3> sliceSize = {outer, dimSize, inner};
71+
Value sliceStartConstShape =
72+
tosa::getTosaConstShape(rewriter, loc, sliceStart);
73+
Value sliceSizeConstShape = tosa::getTosaConstShape(rewriter, loc, sliceSize);
7174

7275
for (int64_t offset = 1; offset < dimSize; offset <<= 1) {
7376
SmallVector<int64_t, 6> padSpec = {0, 0, offset, 0, 0, 0};
@@ -79,11 +82,10 @@ static Value emitInclusiveScanByPowersOfTwo(Value running,
7982
padShape, zeroConst)
8083
.getResult();
8184

82-
Value shifted = tosa::SliceOp::create(
83-
rewriter, loc, nkcTy, padded,
84-
tosa::getTosaConstShape(rewriter, loc, sliceStart),
85-
tosa::getTosaConstShape(rewriter, loc, sliceSize))
86-
.getResult();
85+
Value shifted =
86+
tosa::SliceOp::create(rewriter, loc, nkcTy, padded,
87+
sliceStartConstShape, sliceSizeConstShape)
88+
.getResult();
8789

8890
running =
8991
tosa::AddOp::create(rewriter, loc, nkcTy, running, shifted).getResult();

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4445,10 +4445,10 @@ func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch
44454445
// CHECK: %[[RESHAPE_SHAPE:.*]] = tosa.const_shape {values = dense<[2, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
44464446
// CHECK: %[[RESHAPED:.*]] = tosa.reshape %[[IN]], %[[RESHAPE_SHAPE]] : (tensor<2x3xf32>, !tosa.shape<3>) -> tensor<2x3x1xf32>
44474447
// CHECK: %[[ZERO:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4448+
// CHECK-DAG: %[[SLICE_START:.*]] = tosa.const_shape {values = dense<0> : tensor<3xindex>} : () -> !tosa.shape<3>
4449+
// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[2, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
44484450
// CHECK: %[[PAD_SPEC:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 0, 0, 0]> : tensor<6xindex>} : () -> !tosa.shape<6>
44494451
// CHECK: %[[PADDED:.*]] = tosa.pad %[[RESHAPED]], %[[PAD_SPEC]], %[[ZERO]] : (tensor<2x3x1xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<2x4x1xf32>
4450-
// CHECK: %[[SLICE_START:.*]] = tosa.const_shape {values = dense<0> : tensor<3xindex>} : () -> !tosa.shape<3>
4451-
// CHECK: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[2, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
44524452
// CHECK: %[[SLICE:.*]] = tosa.slice %[[PADDED]], %[[SLICE_START]], %[[SLICE_SIZE]] : (tensor<2x4x1xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<2x3x1xf32>
44534453
// CHECK: %[[ACC1:.*]] = tosa.add %[[RESHAPED]], %[[SLICE]] : (tensor<2x3x1xf32>, tensor<2x3x1xf32>) -> tensor<2x3x1xf32>
44544454
// CHECK: %[[ACC2:.*]] = tosa.add %[[ACC1]], %{{.*}} : (tensor<2x3x1xf32>, tensor<2x3x1xf32>) -> tensor<2x3x1xf32>

0 commit comments

Comments
 (0)