18
18
#include " mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19
19
#include " mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
20
20
#include " mlir/Dialect/SPIRV/IR/SPIRVOps.h"
21
+ #include " mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
21
22
#include " mlir/Dialect/SPIRV/IR/TargetAndABI.h"
22
23
#include " mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
24
+ #include " mlir/IR/BuiltinAttributes.h"
25
+ #include " mlir/IR/BuiltinTypes.h"
23
26
#include " mlir/IR/TypeUtilities.h"
27
+ #include " llvm/ADT/StringSwitch.h"
24
28
25
- namespace mlir ::nv {
26
- namespace {
29
+ #include < cassert>
27
30
31
+ namespace mlir {
28
32
// / Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
29
33
/// when the elementwise op directly supports cooperative matrix types.
30
34
/// Returns false if it cannot.
31
35
// /
32
36
// / See SPV_NV_cooperative_matrix for supported elementwise ops.
33
37
static bool createElementwiseOp (ConversionPatternRewriter &builder,
34
- gpu::SubgroupMmaElementwiseOp op,
35
- spirv::CooperativeMatrixNVType coopType,
38
+ gpu::SubgroupMmaElementwiseOp op, Type coopType,
36
39
ValueRange operands) {
40
+ assert ((isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
41
+ coopType)));
42
+
37
43
switch (op.getOpType ()) {
38
44
case gpu::MMAElementwiseOp::ADDF:
39
45
builder.replaceOpWithNewOp <spirv::FAddOp>(op, coopType, operands);
@@ -71,6 +77,110 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
71
77
return false ;
72
78
}
73
79
80
+ // ===----------------------------------------------------------------------===//
81
+ // SPV_KHR_cooperative_matrix
82
+ // ===----------------------------------------------------------------------===//
83
+
84
+ namespace khr {
85
+ namespace {
86
+
87
+ // / Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
88
+ // / dialect.
89
+ struct WmmaLoadOpToSPIRVLowering final
90
+ : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
91
+ using OpConversionPattern::OpConversionPattern;
92
+
93
+ LogicalResult
94
+ matchAndRewrite (gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
95
+ ConversionPatternRewriter &rewriter) const override {
96
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
97
+ Location loc = op->getLoc ();
98
+
99
+ auto retType = cast<gpu::MMAMatrixType>(op.getRes ().getType ());
100
+ MemRefType memrefType = op.getSrcMemref ().getType ();
101
+ Value bufferPtr =
102
+ spirv::getElementPtr (typeConverter, memrefType, adaptor.getSrcMemref (),
103
+ adaptor.getIndices (), loc, rewriter);
104
+
105
+ auto coopType =
106
+ typeConverter.convertType <spirv::CooperativeMatrixType>(retType);
107
+ if (!coopType)
108
+ return rewriter.notifyMatchFailure (op, " type conversion failed" );
109
+
110
+ int64_t stride = op.getLeadDimension ().getSExtValue ();
111
+ IntegerType i32Type = rewriter.getI32Type ();
112
+ auto strideValue = rewriter.create <spirv::ConstantOp>(
113
+ loc, i32Type, IntegerAttr::get (i32Type, stride));
114
+
115
+ bool isColMajor = op.getTranspose ().value_or (false );
116
+ auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
117
+ : spirv::CooperativeMatrixLayoutKHR::RowMajor;
118
+
119
+ rewriter.replaceOpWithNewOp <spirv::KHRCooperativeMatrixLoadOp>(
120
+ op, coopType, bufferPtr, strideValue, layout);
121
+ return success ();
122
+ }
123
+ };
124
+
125
+ // / Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
126
+ // / dialect.
127
+ struct WmmaStoreOpToSPIRVLowering final
128
+ : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
129
+ using OpConversionPattern::OpConversionPattern;
130
+
131
+ LogicalResult
132
+ matchAndRewrite (gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
133
+ ConversionPatternRewriter &rewriter) const override {
134
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
135
+ Location loc = op->getLoc ();
136
+
137
+ auto memrefType = cast<MemRefType>(op.getDstMemref ().getType ());
138
+ Value bufferPtr =
139
+ spirv::getElementPtr (typeConverter, memrefType, adaptor.getDstMemref (),
140
+ adaptor.getIndices (), loc, rewriter);
141
+
142
+ int64_t stride = op.getLeadDimension ().getSExtValue ();
143
+ IntegerType i32Type = rewriter.getI32Type ();
144
+ auto strideValue = rewriter.create <spirv::ConstantOp>(
145
+ loc, i32Type, IntegerAttr::get (i32Type, stride));
146
+
147
+ bool isColMajor = op.getTranspose ().value_or (false );
148
+ auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
149
+ : spirv::CooperativeMatrixLayoutKHR::RowMajor;
150
+
151
+ rewriter.replaceOpWithNewOp <spirv::KHRCooperativeMatrixStoreOp>(
152
+ op, bufferPtr, adaptor.getSrc (), strideValue, layout);
153
+ return success ();
154
+ }
155
+ };
156
+
157
+ // / Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
158
+ // / dialect.
159
+ struct WmmaMmaOpToSPIRVLowering final
160
+ : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
161
+ using OpConversionPattern::OpConversionPattern;
162
+
163
+ LogicalResult
164
+ matchAndRewrite (gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
165
+ OpAdaptor adaptor,
166
+ ConversionPatternRewriter &rewriter) const override {
167
+ rewriter.replaceOpWithNewOp <spirv::KHRCooperativeMatrixMulAddOp>(
168
+ subgroupMmaComputeOp, adaptor.getOpA (), adaptor.getOpB (),
169
+ adaptor.getOpC ());
170
+ return success ();
171
+ }
172
+ };
173
+
174
+ } // namespace
175
+ } // namespace khr
176
+
177
+ // ===----------------------------------------------------------------------===//
178
+ // SPV_NV_cooperative_matrix
179
+ // ===----------------------------------------------------------------------===//
180
+
181
+ namespace nv {
182
+ namespace {
183
+
74
184
// / Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
75
185
// / dialect.
76
186
struct WmmaLoadOpToSPIRVLowering final
@@ -247,7 +357,8 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
247
357
};
248
358
249
359
} // namespace
250
- } // namespace mlir::nv
360
+ } // namespace nv
361
+ } // namespace mlir
251
362
252
363
mlir::spirv::CooperativeMatrixNVType
253
364
mlir::convertMMAToSPIRVCoopMatrixNVType (gpu::MMAMatrixType type) {
@@ -257,6 +368,30 @@ mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
257
368
elementType, spirv::Scope::Subgroup, retTypeShape[0 ], retTypeShape[1 ]);
258
369
}
259
370
371
+ mlir::spirv::CooperativeMatrixType
372
+ mlir::convertMMAToSPIRVCoopMatrixType (gpu::MMAMatrixType type) {
373
+ ArrayRef<int64_t > retTypeShape = type.getShape ();
374
+ Type elementType = type.getElementType ();
375
+
376
+ auto use =
377
+ llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand ())
378
+ .Case (" AOp" , spirv::CooperativeMatrixUseKHR::MatrixA)
379
+ .Case (" BOp" , spirv::CooperativeMatrixUseKHR::MatrixB)
380
+ .Default (spirv::CooperativeMatrixUseKHR::MatrixAcc);
381
+
382
+ return spirv::CooperativeMatrixType::get (elementType, retTypeShape[0 ],
383
+ retTypeShape[1 ],
384
+ spirv::Scope::Subgroup, use);
385
+ }
386
+
387
+ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns (
388
+ SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
389
+ using namespace mlir ;
390
+ MLIRContext *context = patterns.getContext ();
391
+ patterns.add <khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
392
+ khr::WmmaStoreOpToSPIRVLowering>(converter, context);
393
+ }
394
+
260
395
void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns (
261
396
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
262
397
using namespace mlir ;
0 commit comments