Skip to content

Commit 32c3dec

Browse files
[mlir][vector] Modernize vector.transpose op (#72594)
* Declare arguments/results with `let` statements.
* Rename `transp` to `permutation`.
* Change the type of `transp` (now `permutation`) from `I64ArrayAttr` to `DenseI64ArrayAttr` (provides direct access to `ArrayRef<int64_t>` instead of `ArrayAttr`).
1 parent abcbca2 commit 32c3dec

File tree

11 files changed

+60
-83
lines changed

11 files changed

+60
-83
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2436,14 +2436,13 @@ def Vector_TransposeOp :
24362436
Vector_Op<"transpose", [Pure,
24372437
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
24382438
PredOpTrait<"operand and result have same element type",
2439-
TCresVTEtIsSameAsOpBase<0, 0>>]>,
2440-
Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$transp)>,
2441-
Results<(outs AnyVectorOfAnyRank:$result)> {
2439+
TCresVTEtIsSameAsOpBase<0, 0>>]> {
24422440
let summary = "vector transpose operation";
24432441
let description = [{
24442442
Takes a n-D vector and returns the transposed n-D vector defined by
24452443
the permutation of ranks in the n-sized integer array attribute (in case
24462444
of 0-D vectors the array attribute must be empty).
2445+
24472446
In the operation
24482447

24492448
```mlir
@@ -2452,7 +2451,7 @@ def Vector_TransposeOp :
24522451
to vector<d_trans[0] x .. x d_trans[n-1] x f32>
24532452
```
24542453

2455-
the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1].
2454+
the `permutation` array [i_1, .., i_n] must be a permutation of [0, .., n-1].
24562455

24572456
Example:
24582457

@@ -2464,8 +2463,13 @@ def Vector_TransposeOp :
24642463
[c, f] ]
24652464
```
24662465
}];
2466+
2467+
let arguments = (ins AnyVectorOfAnyRank:$vector,
2468+
DenseI64ArrayAttr:$permutation);
2469+
let results = (outs AnyVectorOfAnyRank:$result);
2470+
24672471
let builders = [
2468-
OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$transp)>
2472+
OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$permutation)>
24692473
];
24702474
let extraClassDeclaration = [{
24712475
VectorType getSourceVectorType() {
@@ -2474,10 +2478,9 @@ def Vector_TransposeOp :
24742478
VectorType getResultVectorType() {
24752479
return ::llvm::cast<VectorType>(getResult().getType());
24762480
}
2477-
void getTransp(SmallVectorImpl<int64_t> &results);
24782481
}];
24792482
let assemblyFormat = [{
2480-
$vector `,` $transp attr-dict `:` type($vector) `to` type($result)
2483+
$vector `,` $permutation attr-dict `:` type($vector) `to` type($result)
24812484
}];
24822485
let hasCanonicalizer = 1;
24832486
let hasFolder = 1;

mlir/include/mlir/IR/AffineMap.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ class AffineMap {
103103
/// (i.e. `[1,1,2]` is an invalid permutation).
104104
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
105105
MLIRContext *context);
106+
static AffineMap getPermutationMap(ArrayRef<int64_t> permutation,
107+
MLIRContext *context);
106108

107109
/// Returns an affine map with `numDims` input dimensions and results
108110
/// specified by `targets`.

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -436,12 +436,9 @@ struct TransposeOpToArmSMELowering
436436
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
437437
return failure();
438438

439-
SmallVector<int64_t> transp;
440-
for (auto attr : transposeOp.getTransp())
441-
transp.push_back(cast<IntegerAttr>(attr).getInt());
442-
443439
// Bail unless this is a true 2-D matrix transpose.
444-
if (transp[0] != 1 || transp[1] != 0)
440+
ArrayRef<int64_t> permutation = transposeOp.getPermutation();
441+
if (permutation[0] != 1 || permutation[1] != 0)
445442
return failure();
446443

447444
OpBuilder::InsertionGuard g(rewriter);

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -473,13 +473,8 @@ struct CombineTransferReadOpTranspose final
473473
if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
474474
return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
475475

476-
SmallVector<int64_t, 2> perm;
477-
op.getTransp(perm);
478-
SmallVector<unsigned, 2> permU;
479-
for (int64_t o : perm)
480-
permU.push_back(unsigned(o));
481476
AffineMap permutationMap =
482-
AffineMap::getPermutationMap(permU, op.getContext());
477+
AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
483478
AffineMap newMap =
484479
permutationMap.compose(transferReadOp.getPermutationMap());
485480

mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
712712
VectorType newTy =
713713
origTy.cloneWith(origTy.getShape(), ext->getInElementType());
714714
Value newTranspose = rewriter.create<vector::TransposeOp>(
715-
op.getLoc(), newTy, ext->getIn(), op.getTransp());
715+
op.getLoc(), newTy, ext->getIn(), op.getPermutation());
716716
ext->recreateAndReplace(rewriter, op, newTranspose);
717717
return success();
718718
}

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,9 +1456,8 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
14561456

14571457
if (!nextTransposeOp)
14581458
return failure();
1459-
auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
1460-
AffineMap m = inversePermutation(
1461-
AffineMap::getPermutationMap(permutation, extractOp.getContext()));
1459+
AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1460+
nextTransposeOp.getPermutation(), extractOp.getContext()));
14621461
extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
14631462
return success();
14641463
}
@@ -5376,20 +5375,20 @@ LogicalResult TypeCastOp::verify() {
53765375
//===----------------------------------------------------------------------===//
53775376

53785377
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
5379-
Value vector, ArrayRef<int64_t> transp) {
5378+
Value vector, ArrayRef<int64_t> permutation) {
53805379
VectorType vt = llvm::cast<VectorType>(vector.getType());
53815380
SmallVector<int64_t, 4> transposedShape(vt.getRank());
53825381
SmallVector<bool, 4> transposedScalableDims(vt.getRank());
5383-
for (unsigned i = 0; i < transp.size(); ++i) {
5384-
transposedShape[i] = vt.getShape()[transp[i]];
5385-
transposedScalableDims[i] = vt.getScalableDims()[transp[i]];
5382+
for (unsigned i = 0; i < permutation.size(); ++i) {
5383+
transposedShape[i] = vt.getShape()[permutation[i]];
5384+
transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
53865385
}
53875386

53885387
result.addOperands(vector);
53895388
result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
53905389
transposedScalableDims));
5391-
result.addAttribute(TransposeOp::getTranspAttrName(result.name),
5392-
builder.getI64ArrayAttr(transp));
5390+
result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
5391+
builder.getDenseI64ArrayAttr(permutation));
53935392
}
53945393

53955394
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
@@ -5401,13 +5400,12 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
54015400

54025401
// Eliminate identity transpose ops. This happens when the dimensions of the
54035402
// input vector remain in their original order after the transpose operation.
5404-
SmallVector<int64_t, 4> transp;
5405-
getTransp(transp);
5403+
ArrayRef<int64_t> perm = getPermutation();
54065404

54075405
// Check if the permutation of the dimensions contains sequential values:
54085406
// {0, 1, 2, ...}.
5409-
for (int64_t i = 0, e = transp.size(); i < e; i++) {
5410-
if (transp[i] != i)
5407+
for (int64_t i = 0, e = perm.size(); i < e; i++) {
5408+
if (perm[i] != i)
54115409
return {};
54125410
}
54135411

@@ -5421,20 +5419,19 @@ LogicalResult vector::TransposeOp::verify() {
54215419
if (vectorType.getRank() != rank)
54225420
return emitOpError("vector result rank mismatch: ") << rank;
54235421
// Verify transposition array.
5424-
auto transpAttr = getTransp().getValue();
5425-
int64_t size = transpAttr.size();
5422+
ArrayRef<int64_t> perm = getPermutation();
5423+
int64_t size = perm.size();
54265424
if (rank != size)
54275425
return emitOpError("transposition length mismatch: ") << size;
54285426
SmallVector<bool, 8> seen(rank, false);
5429-
for (const auto &ta : llvm::enumerate(transpAttr)) {
5430-
int64_t i = llvm::cast<IntegerAttr>(ta.value()).getInt();
5431-
if (i < 0 || i >= rank)
5432-
return emitOpError("transposition index out of range: ") << i;
5433-
if (seen[i])
5434-
return emitOpError("duplicate position index: ") << i;
5435-
seen[i] = true;
5436-
if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
5437-
return emitOpError("dimension size mismatch at: ") << i;
5427+
for (const auto &ta : llvm::enumerate(perm)) {
5428+
if (ta.value() < 0 || ta.value() >= rank)
5429+
return emitOpError("transposition index out of range: ") << ta.value();
5430+
if (seen[ta.value()])
5431+
return emitOpError("duplicate position index: ") << ta.value();
5432+
seen[ta.value()] = true;
5433+
if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
5434+
return emitOpError("dimension size mismatch at: ") << ta.value();
54385435
}
54395436
return success();
54405437
}
@@ -5452,13 +5449,6 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
54525449

54535450
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
54545451
PatternRewriter &rewriter) const override {
5455-
// Wrapper around vector::TransposeOp::getTransp() for cleaner code.
5456-
auto getPermutation = [](vector::TransposeOp transpose) {
5457-
SmallVector<int64_t, 4> permutation;
5458-
transpose.getTransp(permutation);
5459-
return permutation;
5460-
};
5461-
54625452
// Composes two permutations: result[i] = permutation1[permutation2[i]].
54635453
auto composePermutations = [](ArrayRef<int64_t> permutation1,
54645454
ArrayRef<int64_t> permutation2) {
@@ -5475,12 +5465,11 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
54755465
return failure();
54765466

54775467
SmallVector<int64_t, 4> permutation = composePermutations(
5478-
getPermutation(parentTransposeOp), getPermutation(transposeOp));
5468+
parentTransposeOp.getPermutation(), transposeOp.getPermutation());
54795469
// Replace 'transposeOp' with a new transpose operation.
54805470
rewriter.replaceOpWithNewOp<vector::TransposeOp>(
54815471
transposeOp, transposeOp.getResult().getType(),
5482-
parentTransposeOp.getVector(),
5483-
vector::getVectorSubscriptAttr(rewriter, permutation));
5472+
parentTransposeOp.getVector(), permutation);
54845473
return success();
54855474
}
54865475
};
@@ -5539,8 +5528,7 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
55395528

55405529
// Get the transpose permutation and apply it to the vector.create_mask or
55415530
// vector.constant_mask operands.
5542-
SmallVector<int64_t> permutation;
5543-
transpOp.getTransp(permutation);
5531+
ArrayRef<int64_t> permutation = transpOp.getPermutation();
55445532

55455533
if (createMaskOp) {
55465534
auto maskOperands = createMaskOp.getOperands();
@@ -5583,9 +5571,7 @@ class FoldTransposeWithNonScalableUnitDimsToShapeCast final
55835571
PatternRewriter &rewriter) const override {
55845572
Value input = transpOp.getVector();
55855573
VectorType resType = transpOp.getResultVectorType();
5586-
5587-
SmallVector<int64_t> permutation;
5588-
transpOp.getTransp(permutation);
5574+
ArrayRef<int64_t> permutation = transpOp.getPermutation();
55895575

55905576
if (resType.getRank() == 2 &&
55915577
((resType.getShape().front() == 1 &&
@@ -5611,10 +5597,6 @@ void vector::TransposeOp::getCanonicalizationPatterns(
56115597
FoldTransposeWithNonScalableUnitDimsToShapeCast>(context);
56125598
}
56135599

5614-
void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
5615-
populateFromInt64AttrArray(getTransp(), results);
5616-
}
5617-
56185600
//===----------------------------------------------------------------------===//
56195601
// ConstantMaskOp
56205602
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
327327
VectorType resType = op.getResultVectorType();
328328

329329
// Set up convenience transposition table.
330-
SmallVector<int64_t> transp;
331-
for (auto attr : op.getTransp())
332-
transp.push_back(cast<IntegerAttr>(attr).getInt());
330+
ArrayRef<int64_t> transp = op.getPermutation();
333331

334332
if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
335333
succeeded(isTranspose2DSlice(op)))

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,7 @@ struct CombineContractABTranspose final
212212
if (!transposeOp)
213213
continue;
214214
AffineMap permutationMap = AffineMap::getPermutationMap(
215-
extractVector<unsigned>(transposeOp.getTransp()),
216-
contractOp.getContext());
215+
transposeOp.getPermutation(), contractOp.getContext());
217216
map = inversePermutation(permutationMap).compose(map);
218217
*operand = transposeOp.getVector();
219218
changed = true;
@@ -279,13 +278,13 @@ struct CombineContractResultTranspose final
279278

280279
// Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
281280
// To index into A in contract, we need revert(f)(g(C)) -> A.
282-
auto accTMap = AffineMap::getPermutationMap(
283-
extractVector<unsigned>(accTOp.getTransp()), context);
281+
auto accTMap =
282+
AffineMap::getPermutationMap(accTOp.getPermutation(), context);
284283

285284
// Contract performs g(C) -> D. Result transpose performs h(D) -> E.
286285
// To index into E in contract, we need h(g(C)) -> E.
287-
auto resTMap = AffineMap::getPermutationMap(
288-
extractVector<unsigned>(resTOp.getTransp()), context);
286+
auto resTMap =
287+
AffineMap::getPermutationMap(resTOp.getPermutation(), context);
289288
auto combinedResMap = resTMap.compose(contractMap);
290289

291290
// The accumulator and result share the same indexing map. So they should be
@@ -490,15 +489,15 @@ struct ReorderElementwiseOpsOnTranspose final
490489

491490
// Make sure all operands are transpose/constant ops and collect their
492491
// transposition maps.
493-
SmallVector<ArrayAttr> transposeMaps;
492+
SmallVector<ArrayRef<int64_t>> transposeMaps;
494493
transposeMaps.reserve(op->getNumOperands());
495494
// Record the initial type before transposition. We'll use its shape later.
496495
// Any type will do here as we will check all transpose maps are the same.
497496
VectorType srcType;
498497
for (Value operand : op->getOperands()) {
499498
auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
500499
if (transposeOp) {
501-
transposeMaps.push_back(transposeOp.getTransp());
500+
transposeMaps.push_back(transposeOp.getPermutation());
502501
srcType = transposeOp.getSourceVectorType();
503502
} else if (!matchPattern(operand, m_Constant())) {
504503
return failure();
@@ -517,7 +516,7 @@ struct ReorderElementwiseOpsOnTranspose final
517516

518517
// If there are constant operands, we need to insert inverse transposes for
519518
// them. Calculate the inverse order first.
520-
auto order = extractVector<unsigned>(transposeMaps.front());
519+
auto order = transposeMaps.front();
521520
SmallVector<int64_t> invOrder(order.size());
522521
for (int i = 0, e = order.size(); i < e; ++i)
523522
invOrder[order[i]] = i;
@@ -532,8 +531,7 @@ struct ReorderElementwiseOpsOnTranspose final
532531
srcType.getShape(),
533532
cast<VectorType>(operand.getType()).getElementType());
534533
srcValues.push_back(rewriter.create<vector::TransposeOp>(
535-
operand.getLoc(), vectorType, operand,
536-
rewriter.getI64ArrayAttr(invOrder)));
534+
operand.getLoc(), vectorType, operand, invOrder));
537535
}
538536
}
539537

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,8 +537,7 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
537537
// Prepare the result vector;
538538
Value result = rewriter.create<arith::ConstantOp>(
539539
loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
540-
SmallVector<int64_t> permutation;
541-
transposeOp.getTransp(permutation);
540+
ArrayRef<int64_t> permutation = transposeOp.getPermutation();
542541

543542
// Unroll the computation.
544543
for (SmallVector<int64_t> elementOffsets :

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,11 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
8787
if (srcGtOneDims.size() != 2)
8888
return failure();
8989

90-
SmallVector<int64_t> transp;
91-
for (auto attr : op.getTransp())
92-
transp.push_back(cast<IntegerAttr>(attr).getInt());
93-
9490
// Check whether the two source vector dimensions that are greater than one
9591
// must be transposed with each other so that we can apply one of the 2-D
9692
// transpose pattens. Otherwise, these patterns are not applicable.
97-
if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
93+
if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1],
94+
op.getPermutation()))
9895
return failure();
9996

10097
return std::pair<int, int>(srcGtOneDims[0], srcGtOneDims[1]);

mlir/lib/IR/AffineMap.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,12 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
254254
assert(permutationMap.isPermutation() && "Invalid permutation vector");
255255
return permutationMap;
256256
}
257+
AffineMap AffineMap::getPermutationMap(ArrayRef<int64_t> permutation,
258+
MLIRContext *context) {
259+
SmallVector<unsigned> perm = llvm::map_to_vector(
260+
permutation, [](int64_t i) { return static_cast<unsigned>(i); });
261+
return AffineMap::getPermutationMap(perm, context);
262+
}
257263

258264
AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims,
259265
ArrayRef<unsigned> targets,

0 commit comments

Comments
 (0)