Commit 0b00ae8

[TOSA] Add cumsum legalization (#4402)
1 parent 791debb commit 0b00ae8

3 files changed: +162 -10 lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 117 additions & 1 deletion
@@ -48,6 +48,50 @@ namespace mlir::torch {
 
 namespace {
 
+// Runs an in-place inclusive prefix sum along the middle dimension (K) of
+// `running` using a binary lifting scheme. The input must have shape [N, K, C].
+// After the loop, `running` holds the cumsum result with respect to axis=1.
+static Value emitInclusiveScanByPowersOfTwo(Value running,
+                                            ConversionPatternRewriter &rewriter,
+                                            Location loc) {
+  auto nkcTy = cast<RankedTensorType>(running.getType());
+  SmallVector<int64_t> nkcShape(makeShapeTorchCompatible(nkcTy.getShape()));
+  int64_t outer = nkcShape[0];
+  int64_t dimSize = nkcShape[1];
+  int64_t inner = nkcShape[2];
+
+  auto zeroConstOr =
+      tosa::createZeroPointTensor(rewriter, loc, nkcTy.getElementType(), 0);
+  if (!zeroConstOr)
+    return nullptr;
+  Value zeroConst = *zeroConstOr;
+
+  SmallVector<int64_t, 3> sliceStart(3, 0);
+  SmallVector<int64_t, 3> sliceSize = {outer, dimSize, inner};
+
+  for (int64_t offset = 1; offset < dimSize; offset <<= 1) {
+    SmallVector<int64_t, 6> padSpec = {0, 0, offset, 0, 0, 0};
+    auto padShape = tosa::getTosaConstShape(rewriter, loc, padSpec);
+    SmallVector<int64_t> paddedShape = {outer, dimSize + offset, inner};
+    auto paddedTy = RankedTensorType::get(makeShapeLLVMCompatible(paddedShape),
+                                          nkcTy.getElementType());
+    Value padded = tosa::PadOp::create(rewriter, loc, paddedTy, running,
+                                       padShape, zeroConst)
+                       .getResult();
+
+    Value shifted = tosa::SliceOp::create(
+                        rewriter, loc, nkcTy, padded,
+                        tosa::getTosaConstShape(rewriter, loc, sliceStart),
+                        tosa::getTosaConstShape(rewriter, loc, sliceSize))
+                        .getResult();
+
+    running =
+        tosa::AddOp::create(rewriter, loc, nkcTy, running, shifted).getResult();
+  }
+
+  return running;
+}
+
 static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> originalShape,
                                          ArrayRef<int32_t> permutation) {
   SmallVector<int64_t> result;
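
For intuition, the loop above is a classic Hillis-Steele inclusive scan expressed with tosa.pad, tosa.slice, and tosa.add. A minimal NumPy sketch of the same computation (illustrative only, not part of the commit):

import numpy as np

def inclusive_scan_by_powers_of_two(running):
    # `running` has shape [N, K, C]; the scan runs along axis 1, mirroring the
    # pad/slice/add loop emitted by emitInclusiveScanByPowersOfTwo.
    n, k, c = running.shape
    offset = 1
    while offset < k:
        # tosa.pad: prepend `offset` zeros along the scanned axis.
        padded = np.pad(running, ((0, 0), (offset, 0), (0, 0)))
        # tosa.slice: keep the first K rows, i.e. the data shifted down by `offset`.
        shifted = padded[:, :k, :]
        # tosa.add: fold the shifted copy into the running sum.
        running = running + shifted
        offset <<= 1
    return running

x = np.arange(6, dtype=np.float32).reshape(1, 6, 1)
assert np.array_equal(inclusive_scan_by_powers_of_two(x), np.cumsum(x, axis=1))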
@@ -4574,7 +4618,7 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "dim out of range");
 
   SmallVector<int64_t> inputShape =
-      llvm::to_vector(makeShapeTorchCompatible(selfType.getShape()));
+      makeShapeTorchCompatible(selfType.getShape());
   const int64_t K = inputShape[dim];
 
   int64_t start;
@@ -9617,6 +9661,77 @@ LogicalResult ConvertAtenOp<AtenUnfoldOp>::matchAndRewrite(
   return success();
 }
 
+// Legalization for aten.cumsum
+template <>
+LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
+    AtenCumsumOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.getSelf();
+  auto selfType = dyn_cast<RankedTensorType>(self.getType());
+  if (!selfType || !selfType.hasStaticShape())
+    return rewriter.notifyMatchFailure(op,
+                                       "Only static tensor shapes supported");
+
+  auto loc = op->getLoc();
+
+  int64_t dim;
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(op, "dim must be constant");
+  dim = toPositiveDim(dim, selfType.getRank());
+  if (!isValidDim(dim, selfType.getRank()))
+    return rewriter.notifyMatchFailure(op, "dim out of range");
+
+  auto outTypeAny = getTypeConverter()->convertType(op.getType());
+  auto outType = dyn_cast<RankedTensorType>(outTypeAny);
+  if (!outType)
+    return rewriter.notifyMatchFailure(op, "expected ranked result type");
+
+  auto outElemTy = outType.getElementType();
+  auto castTy = RankedTensorType::get(selfType.getShape(), outElemTy);
+  Value selfCast = self;
+  if (selfType.getElementType() != outElemTy) {
+    auto maybeCast = tosa::tosaCastTensorToType(rewriter, self, castTy);
+    if (!maybeCast)
+      return rewriter.notifyMatchFailure(op, "failed to cast tensor to dtype");
+    selfCast = *maybeCast;
+  }
+
+  SmallVector<int64_t> inputShape =
+      makeShapeTorchCompatible(selfType.getShape());
+  int64_t dimSize = inputShape[dim];
+
+  int64_t outer = 1;
+  for (int64_t i = 0; i < dim; ++i)
+    outer *= inputShape[i];
+  int64_t inner = 1;
+  for (int64_t i = dim + 1, e = inputShape.size(); i < e; ++i)
+    inner *= inputShape[i];
+
+  // Collapse the tensor to [outer, dimSize, inner] so the scanned dimension
+  // is isolated. `outer` is the product of all dims before `dim`, and `inner`
+  // is the product of all dims after `dim`. This lets us run a simple binary
+  // lifting prefix-sum in 3D regardless of the original rank.
+  SmallVector<int64_t> nkcShape = {outer, dimSize, inner};
+  auto nkcTy =
+      RankedTensorType::get(makeShapeLLVMCompatible(nkcShape), outElemTy);
+
+  Value running =
+      tosa::ReshapeOp::create(rewriter, loc, nkcTy, selfCast,
+                              tosa::getTosaConstShape(rewriter, loc, nkcShape))
+          .getResult();
+
+  // Accumulate in-place: `running` always has shape [outer, dimSize, inner].
+  running = emitInclusiveScanByPowersOfTwo(running, rewriter, loc);
+
+  auto finalShape = outType.getShape();
+  auto result = tosa::ReshapeOp::create(
+      rewriter, loc, outType, running,
+      tosa::getTosaConstShape(rewriter, loc, finalShape));
+
+  rewriter.replaceOp(op, result.getResult());
+  return success();
+}
+
 template <typename OpTy>
 class ConvertCastEquivalentOp : public OpConversionPattern<OpTy> {
   using OpConversionPattern<OpTy>::OpConversionPattern;
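
The reshape bookkeeping around the scan (collapse to [outer, dimSize, inner], scan the middle axis, reshape back) can be sanity-checked on its own. A minimal NumPy sketch, illustrative only, with np.cumsum standing in for the emitted TOSA scan loop:

import numpy as np

def cumsum_via_3d_collapse(x, dim):
    # Collapse every dimension before `dim` into `outer` and every dimension
    # after it into `inner`, exactly as the pattern does before emitting the scan.
    shape = x.shape
    outer = int(np.prod(shape[:dim], dtype=np.int64))
    inner = int(np.prod(shape[dim + 1:], dtype=np.int64))
    running = x.reshape(outer, shape[dim], inner)
    running = np.cumsum(running, axis=1)  # stand-in for the pad/slice/add rounds
    return running.reshape(shape)

x = np.random.rand(2, 3, 4).astype(np.float32)
for dim in range(x.ndim):
    assert np.allclose(cumsum_via_3d_collapse(x, dim), np.cumsum(x, axis=dim))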
@@ -10242,6 +10357,7 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
   INSERT_ATENOP_PATTERN(AtenExpm1Op);
   INSERT_ATENOP_PATTERN(AtenTanOp);
   INSERT_ATENOP_PATTERN(AtenUnfoldOp);
+  INSERT_ATENOP_PATTERN(AtenCumsumOp);
   INSERT_ATENOP_PATTERN(AtenQuantizePerTensorOp);
 #undef INSERT_ATENOP_PATTERN

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 9 deletions
@@ -3618,7 +3618,6 @@
     "Conv_Transpose1dModule_basic",
     "Conv_Transpose1dStaticModule_basic",
     "IndexPutWithNoneAndBroadcastModule_basic",
-    "MaskedScatterStaticBasic_basic",
     "MaxUnpool3dModulePad0_basic",
     "MaxUnpool3dModule_basic",
     "MaxUnpool2dModule_basic",

@@ -3729,11 +3728,7 @@
     "ConvolutionModule3DGroups_basic",
     "ConvolutionModule3DGroupsStrided_basic",
     "ConvolutionModule3DGroupsDilated_basic",
-    "CumsumInputDtypeInt32Module_basic",
-    "CumsumWithDtypeModule_basic",
     "CumsumModule_basic",
-    "CumsumStaticModule_basic",
-    "CumsumStaticNegativeDimModule_basic",
     "CumprodModule_basic",
     "CumprodInputDtypeInt32Module_basic",
     "CumprodStaticModule_basic",

@@ -3812,10 +3807,6 @@
     "LinalgNormKeepDimComplexModule_basic",
     "LinalgVectorNormComplexModule_basic",
     "LinspaceEmptyModule_basic",
-    "LogCumsumExpModule_basic",
-    "LogCumsumExpStaticNegativeDimModule_basic",
-    "LogCumsumExpStaticFloat64DtypeModule_basic",
-    "MaskedScatterStaticBasic_basic",
     "MaxPool1dWithIndicesModule_basic",
    "MaxPool1dWithIndicesCeilModeModule_basic",
    "MaxPool1dCeilModeTrueModule_basic",

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 45 additions & 0 deletions
@@ -4439,6 +4439,51 @@ func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch
   return %0 : !torch.vtensor<[2,3],f16>
 }
 
+// CHECK-LABEL: func.func @torch.aten.cumsum.basic(
+// CHECK-SAME:    %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
+// CHECK:         %[[IN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
+// CHECK:         %[[RESHAPE_SHAPE:.*]] = tosa.const_shape {values = dense<[2, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK:         %[[RESHAPED:.*]] = tosa.reshape %[[IN]], %[[RESHAPE_SHAPE]] : (tensor<2x3xf32>, !tosa.shape<3>) -> tensor<2x3x1xf32>
+// CHECK:         %[[ZERO:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK:         %[[PAD_SPEC:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 0, 0, 0]> : tensor<6xindex>} : () -> !tosa.shape<6>
+// CHECK:         %[[PADDED:.*]] = tosa.pad %[[RESHAPED]], %[[PAD_SPEC]], %[[ZERO]] : (tensor<2x3x1xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<2x4x1xf32>
+// CHECK:         %[[SLICE_START:.*]] = tosa.const_shape {values = dense<0> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK:         %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[2, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK:         %[[SLICE:.*]] = tosa.slice %[[PADDED]], %[[SLICE_START]], %[[SLICE_SIZE]] : (tensor<2x4x1xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<2x3x1xf32>
+// CHECK:         %[[ACC1:.*]] = tosa.add %[[RESHAPED]], %[[SLICE]] : (tensor<2x3x1xf32>, tensor<2x3x1xf32>) -> tensor<2x3x1xf32>
+// CHECK:         %[[ACC2:.*]] = tosa.add %[[ACC1]], %{{.*}} : (tensor<2x3x1xf32>, tensor<2x3x1xf32>) -> tensor<2x3x1xf32>
+// CHECK:         %[[FINAL:.*]] = tosa.const_shape {values = dense<[2, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK:         %[[OUT:.*]] = tosa.reshape %[[ACC2]], %[[FINAL]] : (tensor<2x3x1xf32>, !tosa.shape<2>) -> tensor<2x3xf32>
+// CHECK:         %[[TORCH:.*]] = torch_c.from_builtin_tensor %[[OUT]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
+// CHECK:         return %[[TORCH]] : !torch.vtensor<[2,3],f32>
+// CHECK:       }
+func.func @torch.aten.cumsum.basic(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
+  %dim = torch.constant.int 1
+  %none = torch.constant.none
+  %0 = torch.aten.cumsum %arg0, %dim, %none : !torch.vtensor<[2,3],f32>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f32>
+  return %0 : !torch.vtensor<[2,3],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.cumsum.si32(
+// CHECK-SAME:    %[[ARG:.*]]: !torch.vtensor<[3,2],si32>) -> !torch.vtensor<[3,2],si32> {
+// CHECK:         %[[IN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[3,2],si32> -> tensor<3x2xi32>
+// CHECK:         %[[RESHAPE:.*]] = tosa.reshape %[[IN]], %{{.*}} : (tensor<3x2xi32>, !tosa.shape<3>) -> tensor<{{.*}}xi32>
+// CHECK:         %[[ZERO:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK:         %[[PAD:.*]] = tosa.pad %[[RESHAPE]], %{{.*}}, %[[ZERO]] : (tensor<{{.*}}xi32>, !tosa.shape<6>, tensor<1xi32>) -> tensor<{{.*}}xi32>
+// CHECK:         %[[SLICE:.*]] = tosa.slice %[[PAD]], %{{.*}}, %{{.*}} : (tensor<{{.*}}xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<{{.*}}xi32>
+// CHECK:         %[[ACC1:.*]] = tosa.add %[[RESHAPE]], %[[SLICE]] : (tensor<{{.*}}xi32>, tensor<{{.*}}xi32>) -> tensor<{{.*}}xi32>
+// CHECK:         %[[ACC2:.*]] = tosa.add %[[ACC1]], %{{.*}} : (tensor<{{.*}}xi32>, tensor<{{.*}}xi32>) -> tensor<{{.*}}xi32>
+// CHECK:         %[[FINAL:.*]] = tosa.reshape %[[ACC2]], %{{.*}} : (tensor<{{.*}}xi32>, !tosa.shape<2>) -> tensor<3x2xi32>
+// CHECK:         %[[TORCH:.*]] = torch_c.from_builtin_tensor %[[FINAL]] : tensor<3x2xi32> -> !torch.vtensor<[3,2],si32>
+// CHECK:         return %[[TORCH]] : !torch.vtensor<[3,2],si32>
+// CHECK:       }
+func.func @torch.aten.cumsum.si32(%arg0: !torch.vtensor<[3,2],si32>) -> !torch.vtensor<[3,2],si32> {
+  %dim = torch.constant.int 0
+  %none = torch.constant.none
+  %0 = torch.aten.cumsum %arg0, %dim, %none : !torch.vtensor<[3,2],si32>, !torch.int, !torch.none -> !torch.vtensor<[3,2],si32>
+  return %0 : !torch.vtensor<[3,2],si32>
+}
+
 // -----
 func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{
   %c1 = torch.constant.int 1
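
Both new tests scan a dimension of size 3 (dim 1 of [2,3] and dim 0 of [3,2]), so the lowering unrolls exactly two pad/slice/add rounds, offsets 1 and 2, which is why each CHECK block expects two tosa.add results (%[[ACC1]] and %[[ACC2]]). A small Python sanity check, not part of the test file:

def scan_rounds(k):
    # Number of pad/slice/add rounds the loop emits for a scan of length k:
    # the offsets 1, 2, 4, ... strictly below k.
    rounds, offset = 0, 1
    while offset < k:
        rounds += 1
        offset <<= 1
    return rounds

assert scan_rounds(3) == 2  # matches the two tosa.add ops checked in both tests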
