@@ -48,6 +48,50 @@ namespace mlir::torch {

namespace {

+// Runs an in-place inclusive prefix sum along the middle dimension (K) of
+// `running` using a binary lifting scheme. The input must have shape [N, K, C].
+// After the loop, `running` holds the cumsum result with respect to axis=1.
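+// Illustrative trace (values invented for exposition): for K = 4 and a single
+// column [a, b, c, d], offset 1 produces [a, a+b, b+c, c+d] and offset 2
+// produces [a, a+b, a+b+c, a+b+c+d], so only ceil(log2(K)) rounds are needed.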
+static Value emitInclusiveScanByPowersOfTwo(Value running,
+                                            ConversionPatternRewriter &rewriter,
+                                            Location loc) {
+  auto nkcTy = cast<RankedTensorType>(running.getType());
+  SmallVector<int64_t> nkcShape(makeShapeTorchCompatible(nkcTy.getShape()));
+  int64_t outer = nkcShape[0];
+  int64_t dimSize = nkcShape[1];
+  int64_t inner = nkcShape[2];
+
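+  // tosa.pad fills new elements from a scalar pad-value tensor, so materialize
+  // a zero of the element type here and bail out if that fails (e.g. for an
+  // unsupported element type).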
+  auto zeroConstOr =
+      tosa::createZeroPointTensor(rewriter, loc, nkcTy.getElementType(), 0);
+  if (!zeroConstOr)
+    return nullptr;
+  Value zeroConst = *zeroConstOr;
+
+  SmallVector<int64_t, 3> sliceStart(3, 0);
+  SmallVector<int64_t, 3> sliceSize = {outer, dimSize, inner};
+
+  for (int64_t offset = 1; offset < dimSize; offset <<= 1) {
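+    // Shift `running` forward by `offset` along K: pad `offset` zeros at the
+    // front of dim 1 (the pad spec is a before/after pair per dimension), then
+    // slice the original extent back out so element k sees running[k - offset].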
+    SmallVector<int64_t, 6> padSpec = {0, 0, offset, 0, 0, 0};
+    auto padShape = tosa::getTosaConstShape(rewriter, loc, padSpec);
+    SmallVector<int64_t> paddedShape = {outer, dimSize + offset, inner};
+    auto paddedTy = RankedTensorType::get(makeShapeLLVMCompatible(paddedShape),
+                                          nkcTy.getElementType());
+    Value padded = tosa::PadOp::create(rewriter, loc, paddedTy, running,
+                                       padShape, zeroConst)
+                       .getResult();
+
+    Value shifted = tosa::SliceOp::create(
+                        rewriter, loc, nkcTy, padded,
+                        tosa::getTosaConstShape(rewriter, loc, sliceStart),
+                        tosa::getTosaConstShape(rewriter, loc, sliceSize))
+                        .getResult();
+
+    running =
+        tosa::AddOp::create(rewriter, loc, nkcTy, running, shifted).getResult();
+  }
+
+  return running;
+}
+
static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> originalShape,
                                         ArrayRef<int32_t> permutation) {
  SmallVector<int64_t> result;
@@ -4574,7 +4618,7 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
    return rewriter.notifyMatchFailure(op, "dim out of range");

  SmallVector<int64_t> inputShape =
-      llvm::to_vector(makeShapeTorchCompatible(selfType.getShape()));
+      makeShapeTorchCompatible(selfType.getShape());
  const int64_t K = inputShape[dim];

  int64_t start;
@@ -9617,6 +9661,77 @@ LogicalResult ConvertAtenOp<AtenUnfoldOp>::matchAndRewrite(
  return success();
}

+// Legalization for aten.cumsum
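+// The input is reshaped to [outer, dimSize, inner], scanned along the middle
+// axis with ceil(log2(dimSize)) pad/slice/add rounds, and reshaped back, so
+// the whole lowering uses only tosa.reshape, tosa.pad, tosa.slice, and
+// tosa.add (plus a cast when an explicit dtype is requested).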
+template <>
+LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
+    AtenCumsumOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.getSelf();
+  auto selfType = dyn_cast<RankedTensorType>(self.getType());
+  if (!selfType || !selfType.hasStaticShape())
+    return rewriter.notifyMatchFailure(op,
+                                       "Only static tensor shapes supported");
+
+  auto loc = op->getLoc();
+
+  int64_t dim;
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(op, "dim must be constant");
+  dim = toPositiveDim(dim, selfType.getRank());
+  if (!isValidDim(dim, selfType.getRank()))
+    return rewriter.notifyMatchFailure(op, "dim out of range");
+
+  auto outTypeAny = getTypeConverter()->convertType(op.getType());
+  auto outType = dyn_cast<RankedTensorType>(outTypeAny);
+  if (!outType)
+    return rewriter.notifyMatchFailure(op, "expected ranked result type");
+
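+  // aten.cumsum carries an optional dtype; the converted result type already
+  // reflects it, so cast the input up front and accumulate in the result's
+  // element type rather than the input's.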
+  auto outElemTy = outType.getElementType();
+  auto castTy = RankedTensorType::get(selfType.getShape(), outElemTy);
+  Value selfCast = self;
+  if (selfType.getElementType() != outElemTy) {
+    auto maybeCast = tosa::tosaCastTensorToType(rewriter, self, castTy);
+    if (!maybeCast)
+      return rewriter.notifyMatchFailure(op, "failed to cast tensor to dtype");
+    selfCast = *maybeCast;
+  }
+
+  SmallVector<int64_t> inputShape =
+      makeShapeTorchCompatible(selfType.getShape());
+  int64_t dimSize = inputShape[dim];
+
+  int64_t outer = 1;
+  for (int64_t i = 0; i < dim; ++i)
+    outer *= inputShape[i];
+  int64_t inner = 1;
+  for (int64_t i = dim + 1, e = inputShape.size(); i < e; ++i)
+    inner *= inputShape[i];
+
+  // Collapse the tensor to [outer, dimSize, inner] so the scanned dimension
+  // is isolated. `outer` is the product of all dims before `dim`, and `inner`
+  // is the product after `dim`. This lets us run a simple binary lifting
+  // prefix-sum in 3D regardless of the original rank.
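+  // Illustrative shapes (invented for exposition): a [2, 3, 4, 5] input with
+  // dim = 2 collapses to outer = 2 * 3 = 6, dimSize = 4, inner = 5.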
+  SmallVector<int64_t> nkcShape = {outer, dimSize, inner};
+  auto nkcTy =
+      RankedTensorType::get(makeShapeLLVMCompatible(nkcShape), outElemTy);
+
+  Value running =
+      tosa::ReshapeOp::create(rewriter, loc, nkcTy, selfCast,
+                              tosa::getTosaConstShape(rewriter, loc, nkcShape))
+          .getResult();
+
+  // Accumulate in place: `running` always has shape [outer, dimSize, inner].
+  // The helper returns a null Value if it could not materialize the zero pad
+  // constant, so guard against that before reshaping.
+  running = emitInclusiveScanByPowersOfTwo(running, rewriter, loc);
+  if (!running)
+    return rewriter.notifyMatchFailure(op, "failed to create zero pad value");
+
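+  // Expand the scanned [outer, dimSize, inner] tensor back out to the
+  // converted result shape.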
+  auto finalShape = outType.getShape();
+  auto result = tosa::ReshapeOp::create(
+      rewriter, loc, outType, running,
+      tosa::getTosaConstShape(rewriter, loc, finalShape));
+
+  rewriter.replaceOp(op, result.getResult());
+  return success();
+}
+
template <typename OpTy>
class ConvertCastEquivalentOp : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
@@ -10242,6 +10357,7 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
  INSERT_ATENOP_PATTERN(AtenExpm1Op);
  INSERT_ATENOP_PATTERN(AtenTanOp);
  INSERT_ATENOP_PATTERN(AtenUnfoldOp);
+  INSERT_ATENOP_PATTERN(AtenCumsumOp);
  INSERT_ATENOP_PATTERN(AtenQuantizePerTensorOp);
#undef INSERT_ATENOP_PATTERN
