@@ -68,6 +68,9 @@ static Value emitInclusiveScanByPowersOfTwo(Value running,
6868
6969 SmallVector<int64_t , 3 > sliceStart (3 , 0 );
7070 SmallVector<int64_t , 3 > sliceSize = {outer, dimSize, inner};
71+ Value sliceStartConstShape =
72+ tosa::getTosaConstShape (rewriter, loc, sliceStart);
73+ Value sliceSizeConstShape = tosa::getTosaConstShape (rewriter, loc, sliceSize);
7174
7275 for (int64_t offset = 1 ; offset < dimSize; offset <<= 1 ) {
7376 SmallVector<int64_t , 6 > padSpec = {0 , 0 , offset, 0 , 0 , 0 };
@@ -79,11 +82,10 @@ static Value emitInclusiveScanByPowersOfTwo(Value running,
7982 padShape, zeroConst)
8083 .getResult ();
8184
82- Value shifted = tosa::SliceOp::create (
83- rewriter, loc, nkcTy, padded,
84- tosa::getTosaConstShape (rewriter, loc, sliceStart),
85- tosa::getTosaConstShape (rewriter, loc, sliceSize))
86- .getResult ();
85+ Value shifted =
86+ tosa::SliceOp::create (rewriter, loc, nkcTy, padded,
87+ sliceStartConstShape, sliceSizeConstShape)
88+ .getResult ();
8789
8890 running =
8991 tosa::AddOp::create (rewriter, loc, nkcTy, running, shifted).getResult ();
0 commit comments