Skip to content

Commit d20fbc9

Browse files
authored
[MLIR][NVGPU] Introduce nvgpu.wargroup.mma.store Op for Hopper GPUs (llvm#65441)
This PR introduces a new Op called `warpgroup.mma.store` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing the fragmented result(s) `nvgpu.warpgroup.accumulator` produced by `warpgroup.mma` to the given memref. An example of a fragmented matrix is given here : https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d The `warpgroup.mma.store` does the following: 1) Takes one or more `nvgpu.warpgroup.accumulator` type (fragmented results matrix) 2) Calculates indexes per thread in warp-group and stores the data into the given memref. Here's an example usage: ``` // A warpgroup performs GEMM, results in fragmented matrix %result1, %result2 = nvgpu.warpgroup.mma ... // Stores the fragmented result to memref nvgpu.warpgroup.mma.store [%result1, %result2], %matrixD : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>> to memref<128x128xf32,3> ```
1 parent c7d6d62 commit d20fbc9

File tree

6 files changed

+300
-2
lines changed

6 files changed

+300
-2
lines changed

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

+20
Original file line numberDiff line numberDiff line change
@@ -728,4 +728,24 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
728728
let hasVerifier = 1;
729729
}
730730

731+
def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
732+
let description = [{
733+
The `nvgpu.warpgroup.mma.store` op performs the store of fragmented result
734+
in $matrixD to the given memref.
735+
736+
[See the details of register fragment layout for accumulator matrix D]
737+
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
738+
739+
Note that the op must be run by a warp group.
740+
}];
741+
742+
let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
743+
Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
744+
745+
let assemblyFormat = [{
746+
`[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
747+
}];
748+
let hasVerifier = 1;
749+
}
750+
731751
#endif // NVGPU

mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRNVGPUToNVVM
1717
MLIRLLVMDialect
1818
MLIRNVGPUDialect
1919
MLIRNVVMDialect
20+
MLIRArithDialect
2021
MLIRPass
2122
MLIRSCFTransforms
2223
MLIRTransforms

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

+115-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
1212
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1313
#include "mlir/Conversion/LLVMCommon/Pattern.h"
14+
#include "mlir/Dialect/Arith/IR/Arith.h"
1415
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1516
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1617
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -394,8 +395,8 @@ struct ConvertNVGPUToNVVMPass
394395
using Base::Base;
395396

396397
void getDependentDialects(DialectRegistry &registry) const override {
397-
registry
398-
.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect>();
398+
registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
399+
arith::ArithDialect>();
399400
}
400401

401402
void runOnOperation() override {
@@ -436,6 +437,7 @@ struct ConvertNVGPUToNVVMPass
436437
populateNVGPUToNVVMConversionPatterns(converter, patterns);
437438
LLVMConversionTarget target(getContext());
438439
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
440+
target.addLegalDialect<::mlir::arith::ArithDialect>();
439441
target.addLegalDialect<::mlir::memref::MemRefDialect>();
440442
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
441443
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
@@ -1434,6 +1436,116 @@ struct NVGPUWarpgroupMmaOpLowering
14341436
}
14351437
};
14361438

1439+
struct NVGPUWarpgroupMmaStoreOpLowering
1440+
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
1441+
using ConvertOpToLLVMPattern<
1442+
nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1443+
1444+
/// This function stores a fragmented register matrix owned by a warp group
1445+
/// (128 threads) into a memref. Each thread has 64 registers, each the size
1446+
/// of a struct.
1447+
/// Here is what each thread (T) holds; each `d` is a struct value with a
1448+
/// number.
1449+
///
1450+
/// Threads in the warp-group (128 threads) and what they own in matrixD:
1451+
/// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
1452+
/// 32-63 Warp-1 -> MatrixD[16:31][0:N]
1453+
/// 64-95 Warp-2 -> MatrixD[32:47][0:N]
1454+
/// 96-127 Warp-3 -> MatrixD[48:63][0:N]
1455+
///
1456+
/// Matrix-D:
1457+
/// +______________________________________________________________________+
1458+
/// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1459+
/// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1460+
/// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1461+
/// ..| .........|.........|.........|.........|........|...........|........|
1462+
/// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1463+
/// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1464+
/// ..| .........|.........|.........|.........|........|...........|........|
1465+
/// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1466+
/// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1467+
/// ..| .........|.........|.........|.........|........|...........|........|
1468+
/// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1469+
/// ..| .........|.........|.........|.........|........|...........|........|
1470+
/// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1471+
/// ..| .........|.........|.........|.........|........|...........|........|
1472+
/// +______________________________________________________________________+
1473+
///
1474+
/// \param rewriter: The pattern rewriter.
1475+
/// \param matrixD: Result of the warp-group MMA operation (fragmented
1476+
/// matrix). It is held by a thread as a struct with 64 elements.
1477+
/// \param dstMemref: The memref where the registers will be stored.
1478+
/// \param offset: the offset within the memref where the registers will be
1479+
/// stored.
1480+
void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1481+
TypedValue<MemRefType> dstMemref,
1482+
int offset) const {
1483+
Type i32 = b.getI32Type();
1484+
1485+
auto makeConst = [&](int32_t index) -> Value {
1486+
return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
1487+
};
1488+
Value c1 = makeConst(1);
1489+
Value c2 = makeConst(2);
1490+
Value c4 = makeConst(4);
1491+
Value c8 = makeConst(8);
1492+
Value c16 = makeConst(16);
1493+
Value warpSize = makeConst(kWarpSize);
1494+
1495+
auto makeMul = [&](Value lhs, Value rhs) -> Value {
1496+
return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
1497+
};
1498+
auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1499+
return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1500+
};
1501+
1502+
Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
1503+
Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
1504+
Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
1505+
Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
1506+
Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
1507+
1508+
auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1509+
TypedValue<::mlir::MemRefType> memref) {
1510+
Type it = b.getIndexType();
1511+
Value idx = b.create<arith::IndexCastOp>(it, x);
1512+
Value idy0 = b.create<arith::IndexCastOp>(it, y);
1513+
Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
1514+
Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
1515+
Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
1516+
b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
1517+
b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
1518+
};
1519+
1520+
Value tj = makeMul(lane4modId, c2);
1521+
Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1522+
if (offset)
1523+
ti = makeAdd(ti, makeConst(offset));
1524+
for (int i = 0; i < 2; ++i) {
1525+
Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1526+
for (int j = 0; j < 16; ++j) {
1527+
Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1528+
int sIndex = i * 2 + j * 4;
1529+
makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
1530+
}
1531+
}
1532+
}
1533+
1534+
LogicalResult
1535+
matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1536+
ConversionPatternRewriter &rewriter) const override {
1537+
int offset = 0;
1538+
ImplicitLocOpBuilder lb(op->getLoc(), rewriter);
1539+
for (Value matrixD : adaptor.getMatrixD()) {
1540+
auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
1541+
storeFragmentedMatrix(lb, matrixD, op.getDstMemref(), offset);
1542+
offset += structType.getBody().size();
1543+
}
1544+
rewriter.eraseOp(op);
1545+
return success();
1546+
}
1547+
};
1548+
14371549
} // namespace
14381550

14391551
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1450,6 +1562,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
14501562
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
14511563
NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
14521564
NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1565+
NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
14531566
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
14541567
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
14551568
NVGPUMmaSparseSyncLowering>(converter);

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
1414
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
15+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1516
#include "mlir/IR/Builders.h"
1617
#include "mlir/IR/BuiltinAttributes.h"
1718
#include "mlir/IR/BuiltinTypes.h"
@@ -529,6 +530,39 @@ LogicalResult WarpgroupMmaOp::verify() {
529530
return success();
530531
}
531532

533+
LogicalResult WarpgroupMmaStoreOp::verify() {
534+
MemRefType dstMemrefType = getDstMemref().getType();
535+
VectorType firstVtype = getMatrixD()
536+
.front()
537+
.getType()
538+
.cast<WarpgroupAccumulatorType>()
539+
.getFragmented();
540+
541+
int64_t totalFirstDimension = 0;
542+
for (Value result : getMatrixD()) {
543+
VectorType vtype =
544+
result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
545+
if (vtype != firstVtype)
546+
return emitOpError() << "all fragmented types must be the same";
547+
// Limitation
548+
if (!vtype.getElementType().isF32()) {
549+
return emitOpError()
550+
<< "hit a limitation: only f32 results for the time being";
551+
}
552+
totalFirstDimension += vtype.getDimSize(0);
553+
}
554+
if (totalFirstDimension != dstMemrefType.getDimSize(0) ||
555+
firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
556+
return emitOpError() << "results [" << totalFirstDimension << "]["
557+
<< firstVtype.getDimSize(1)
558+
<< "] values. However, destination memref["
559+
<< dstMemrefType.getDimSize(0) << "]["
560+
<< dstMemrefType.getDimSize(1)
561+
<< "] does not have same size as results";
562+
}
563+
return success();
564+
}
565+
532566
//===----------------------------------------------------------------------===//
533567
// TableGen'd dialect, type, and op definitions
534568
//===----------------------------------------------------------------------===//

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

+129
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,135 @@ func.func @warpgroup_mma_128_128_64(
772772
return
773773
}
774774

775+
// CHECK-LABEL: @warpgroup_mma_store(
776+
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
777+
func.func @warpgroup_mma_store(
778+
%result1 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
779+
%result2 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
780+
%matrixD: memref<128x128xf32,3>) {
781+
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
782+
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
783+
// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
784+
// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
785+
// CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
786+
// CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32
787+
// CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32
788+
// CHECK: %[[WarpSize:.+]] = llvm.mlir.constant(32 : i32) : i32
789+
790+
// ### Store {d0, d1} of each thread ###
791+
792+
// CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32
793+
// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[WarpSize]] : i32
794+
// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[WarpSize]] : i32
795+
// CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]] : i32
796+
// CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]] : i32
797+
// CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]] : i32
798+
// CHECK: %[[S14:.+]] = llvm.mul %[[S10]], %[[S7]] : i32
799+
// CHECK: %[[S15:.+]] = llvm.add %[[S11]], %[[S14]] : i32
800+
// CHECK: %[[S16:.+]] = llvm.mlir.constant(0 : i32) : i32
801+
// CHECK: %[[S17:.+]] = llvm.mul %[[S16]], %[[S4]] : i32
802+
// CHECK: %[[S18:.+]] = llvm.add %[[S15]], %[[S17]] : i32
803+
// CHECK: %[[S19:.+]] = llvm.mlir.constant(0 : i32) : i32
804+
// CHECK: %[[S20:.+]] = llvm.mul %[[S19]], %[[S4]] : i32
805+
// CHECK: %[[S21:.+]] = llvm.add %[[S13]], %[[S20]] : i32
806+
// CHECK: %[[S22:.+]] = arith.index_cast %[[S18]] : i32 to index
807+
// CHECK: %[[S23:.+]] = arith.index_cast %[[S21]] : i32 to index
808+
// CHECK: %[[S24:.+]] = llvm.add %[[S21]], %[[S6]] : i32
809+
// CHECK: %[[S25:.+]] = arith.index_cast %[[S24]] : i32 to index
810+
// CHECK: %[[S26:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct
811+
// CHECK: %[[S27:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct
812+
// CHECK: memref.store %[[S26]], %[[arg2]][%[[S22]], %[[S23]]] : memref<128x128xf32, 3>
813+
// CHECK: memref.store %[[S27]], %[[arg2]][%[[S22]], %[[S25]]] : memref<128x128xf32, 3>
814+
815+
// ### Store {d2, d3} of each thread ###
816+
817+
// CHECK: %[[S28:.+]] = llvm.mlir.constant(1 : i32) : i32
818+
// CHECK: %[[S29:.+]] = llvm.mul %[[S28]], %[[S4]] : i32
819+
// CHECK: %[[S30:.+]] = llvm.add %[[S13]], %[[S29]] : i32
820+
// CHECK: %[[S31:.+]] = arith.index_cast %[[S18]] : i32 to index
821+
// CHECK: %[[S32:.+]] = arith.index_cast %[[S30]] : i32 to index
822+
// CHECK: %[[S33:.+]] = llvm.add %[[S30]], %[[S6]] : i32
823+
// CHECK: %[[S34:.+]] = arith.index_cast %[[S33]] : i32 to index
824+
// CHECK: %[[S35:.+]] = llvm.extractvalue %[[S0]][4] : !llvm.struct<
825+
// CHECK: %[[S36:.+]] = llvm.extractvalue %[[S0]][5] : !llvm.struct<
826+
// CHECK: memref.store %[[S35]], %[[arg2]][%[[S31]], %[[S32]]] : memref<128x128xf32, 3>
827+
// CHECK: memref.store %[[S36]], %[[arg2]][%[[S31]], %[[S34]]] : memref<128x128xf32, 3>
828+
829+
// ### Store {d4, d5} of each thread ###
830+
831+
// CHECK: %[[S37:.+]] = llvm.mlir.constant(2 : i32) : i32
832+
// CHECK: %[[S38:.+]] = llvm.mul %[[S37]], %[[S4]] : i32
833+
// CHECK: %[[S39:.+]] = llvm.add %[[S13]], %[[S38]] : i32
834+
// CHECK: %[[S40:.+]] = arith.index_cast %[[S18]] : i32 to index
835+
// CHECK: %[[S41:.+]] = arith.index_cast %[[S39]] : i32 to index
836+
// CHECK: %[[S42:.+]] = llvm.add %[[S39]], %[[S6]] : i32
837+
// CHECK: %[[S43:.+]] = arith.index_cast %[[S42]] : i32 to index
838+
// CHECK: %[[S44:.+]] = llvm.extractvalue %[[S0]][8] : !llvm.struct<
839+
// CHECK: %[[S45:.+]] = llvm.extractvalue %[[S0]][9] : !llvm.struct<
840+
// CHECK: memref.store %[[S44]], %[[arg2]][%[[S40]], %[[S41]]] : memref<128x128xf32, 3>
841+
// CHECK: memref.store %[[S45]], %[[arg2]][%[[S40]], %[[S43]]] : memref<128x128xf32, 3>
842+
843+
// ### Store {d6, d7} of each thread ###
844+
845+
// CHECK: %[[S46:.+]] = llvm.mlir.constant(3 : i32) : i32
846+
// CHECK: %[[S47:.+]] = llvm.mul %[[S46]], %[[S4]] : i32
847+
// CHECK: %[[S48:.+]] = llvm.add %[[S13]], %[[S47]] : i32
848+
// CHECK: %[[S49:.+]] = arith.index_cast %[[S18]] : i32 to index
849+
// CHECK: %[[S50:.+]] = arith.index_cast %[[S48]] : i32 to index
850+
// CHECK: %[[S51:.+]] = llvm.add %[[S48]], %[[S6]] : i32
851+
// CHECK: %[[S52:.+]] = arith.index_cast %[[S51]] : i32 to index
852+
// CHECK: %[[S53:.+]] = llvm.extractvalue %[[S0]][12] : !llvm.struct<
853+
// CHECK: %[[S54:.+]] = llvm.extractvalue %[[S0]][13] : !llvm.struct<
854+
// CHECK: memref.store %[[S53]], %[[arg2]][%[[S49]], %[[S50]]] : memref<128x128xf32, 3>
855+
// CHECK: memref.store %[[S54]], %[[arg2]][%[[S49]], %[[S52]]] : memref<128x128xf32, 3>
856+
857+
// Pattern continues similarly 28x times until {... d62, d63}
858+
859+
// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
860+
// CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
861+
862+
// ### Store {d64, d65} of each thread ###
863+
864+
// CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
865+
// CHECK: %[[S312:.+]] = llvm.mlir.constant(2 : i32) : i32
866+
// CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32
867+
// CHECK: %[[S313:.+]] = llvm.mlir.constant(8 : i32) : i32
868+
// CHECK: %[[S316:.+]] = llvm.mlir.constant(16 : i32) : i32
869+
// CHECK: %[[WS2:.+]] = llvm.mlir.constant(32 : i32) : i32
870+
// CHECK: %[[S317:.+]] = nvvm.read.ptx.sreg.tid.x : i32
871+
// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[WS2]] : i32
872+
// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[WS2]] : i32
873+
// CHECK: %[[S320:.+]] = llvm.udiv %[[S318]], %[[S311]] : i32
874+
// CHECK: %[[S321:.+]] = llvm.urem %[[S318]], %[[S311]] : i32
875+
// CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S312]] : i32
876+
// CHECK: %[[S323:.+]] = llvm.mul %[[S319]], %[[S316]] : i32
877+
// CHECK: %[[S324:.+]] = llvm.add %[[S320]], %[[S323]] : i32
878+
// CHECK: %[[S325:.+]] = llvm.mlir.constant(64 : i32) : i32
879+
// CHECK: %[[S326:.+]] = llvm.add %[[S324]], %[[S325]] : i32
880+
// CHECK: %[[S327:.+]] = llvm.mlir.constant(0 : i32) : i32
881+
// CHECK: %[[S328:.+]] = llvm.mul %[[S327]], %[[S313]] : i32
882+
// CHECK: %[[S329:.+]] = llvm.add %[[S326]], %[[S328]] : i32
883+
// CHECK: %[[S330:.+]] = llvm.mlir.constant(0 : i32) : i32
884+
// CHECK: %[[S331:.+]] = llvm.mul %[[S330]], %[[S313]] : i32
885+
// CHECK: %[[S332:.+]] = llvm.add %[[S322]], %[[S331]] : i32
886+
// CHECK: %[[S333:.+]] = arith.index_cast %[[S329]] : i32 to index
887+
// CHECK: %[[S334:.+]] = arith.index_cast %[[S332]] : i32 to index
888+
// CHECK: %[[S335:.+]] = llvm.add %[[S332]], %[[S315]] : i32
889+
// CHECK: %[[S336:.+]] = arith.index_cast %[[S335]] : i32 to index
890+
// CHECK: %[[S337:.+]] = llvm.extractvalue %[[S1]][0]
891+
// CHECK: %[[S338:.+]] = llvm.extractvalue %[[S1]][1]
892+
// CHECK: memref.store %[[S337]], %[[arg2]][%[[S333]], %[[S334]]] : memref<128x128xf32, 3>
893+
// CHECK: memref.store %[[S338]], %[[arg2]][%[[S333]], %[[S336]]] : memref<128x128xf32, 3>
894+
895+
// Pattern continues similarly 31x times until {... d126, d127}
896+
897+
nvgpu.warpgroup.mma.store [%result1, %result2], %matrixD :
898+
!nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
899+
!nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
900+
to memref<128x128xf32,3>
901+
return
902+
}
903+
775904
transform.sequence failures(propagate) {
776905
^bb1(%arg1: !transform.any_op):
777906
%0 = transform.structured.match ops{["func.func"]} in %arg1

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

+1
Original file line numberDiff line numberDiff line change
@@ -5308,6 +5308,7 @@ cc_library(
53085308
":LLVMCommonConversion",
53095309
":LLVMDialect",
53105310
":MemRefDialect",
5311+
":MLIRArithDialect",
53115312
":NVGPUDialect",
53125313
":NVVMDialect",
53135314
":Pass",

0 commit comments

Comments
 (0)