Skip to content

Commit ed4daea

Browse files
authored
[mlir][spirv][gpu] Add conversion for load/store/mad coop matrix ops (#66311)
This is plugged in as an alternative lowering path in the gpu to spirv dialect conversion. Add custom op builders for coop matrix ops to make the create functions nicer to work with and less error-prone. The latter is accomplished by following the op syntax and also requiring stride to be a constant op to avoid confusion around the order of arguments. The remaining lowering patterns will be added in a future patch.
1 parent f66cd9e commit ed4daea

File tree

7 files changed

+278
-11
lines changed

7 files changed

+278
-11
lines changed

mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,21 @@ class MMAMatrixType;
3030
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
3131
RewritePatternSet &patterns);
3232

33+
/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
34+
/// using the KHR Cooperative Matrix extension.
35+
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
36+
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
37+
3338
/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
3439
/// using the NV Cooperative Matrix extension.
3540
void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
3641
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
3742

43+
/// Returns a KHR cooperative matrix type corresponding to the MMAMatrixType
44+
/// `type`.
45+
spirv::CooperativeMatrixType
46+
convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type);
47+
3848
/// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
3949
/// `type`.
4050
spirv::CooperativeMatrixNVType

mlir/include/mlir/Conversion/Passes.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,11 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
567567
let options = [
568568
Option<"use64bitIndex", "use-64bit-index",
569569
"bool", /*default=*/"false",
570-
"Use 64-bit integers to convert index types">
570+
"Use 64-bit integers to convert index types">,
571+
Option<"useCoopMatrixNV", "use-coop-matrix-nv",
572+
"bool", /*default=*/"true",
573+
"Use the NV cooperative matrix extension instead of the KHR extension"
574+
" to lower GPU WMMA ops">,
571575
];
572576
}
573577

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,15 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
146146
let results = (outs
147147
SPIRV_AnyCooperativeMatrix:$result
148148
);
149+
150+
let builders = [
151+
OpBuilder<(ins "Type":$result, "Value":$pointer,
152+
"spirv::ConstantOp":$stride,
153+
"spirv::CooperativeMatrixLayoutKHR":$layout), [{
154+
build($_builder, $_state, result, pointer, layout, stride,
155+
spirv::MemoryAccessAttr{});
156+
}]>
157+
];
149158
}
150159

151160
// -----
@@ -226,6 +235,15 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
226235
);
227236

228237
let results = (outs);
238+
239+
let builders = [
240+
OpBuilder<(ins "Value":$pointer, "Value":$object,
241+
"spirv::ConstantOp":$stride,
242+
"spirv::CooperativeMatrixLayoutKHR":$layout), [{
243+
build($_builder, $_state, pointer, object, layout, stride,
244+
spirv::MemoryAccessAttr{});
245+
}]>
246+
];
229247
}
230248

231249
// -----
@@ -332,6 +350,13 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMul
332350
let results = (outs
333351
SPIRV_AnyCooperativeMatrix:$result
334352
);
353+
354+
let builders = [
355+
OpBuilder<(ins "Value":$a, "Value":$b, "Value":$c), [{
356+
build($_builder, $_state, a, b, c,
357+
spirv::CooperativeMatrixOperandsKHRAttr{});
358+
}]>
359+
];
335360
}
336361

337362
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,25 @@ void GPUToSPIRVPass::runOnOperation() {
8686
SPIRVConversionOptions options;
8787
options.use64bitIndex = this->use64bitIndex;
8888
SPIRVTypeConverter typeConverter(targetAttr, options);
89-
typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
90-
return convertMMAToSPIRVCoopMatrixNVType(type);
89+
90+
typeConverter.addConversion([useNV = this->useCoopMatrixNV.getValue()](
91+
gpu::MMAMatrixType type) -> Type {
92+
if (useNV)
93+
return convertMMAToSPIRVCoopMatrixNVType(type);
94+
95+
return convertMMAToSPIRVCoopMatrixType(type);
9196
});
97+
9298
RewritePatternSet patterns(context);
9399
populateGPUToSPIRVPatterns(typeConverter, patterns);
94-
populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
95-
patterns);
100+
if (this->useCoopMatrixNV) {
101+
populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
102+
patterns);
103+
} else {
104+
populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
105+
patterns);
106+
}
107+
96108
// TODO: Change SPIR-V conversion to be progressive and remove the following
97109
// patterns.
98110
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);

mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

Lines changed: 140 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,28 @@
1818
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1919
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
2020
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
21+
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
2122
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
2223
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
24+
#include "mlir/IR/BuiltinAttributes.h"
25+
#include "mlir/IR/BuiltinTypes.h"
2326
#include "mlir/IR/TypeUtilities.h"
27+
#include "llvm/ADT/StringSwitch.h"
2428

25-
namespace mlir::nv {
26-
namespace {
29+
#include <cassert>
2730

31+
namespace mlir {
2832
/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
2933
/// when the elementwise op directly supports with cooperative matrix type.
3034
/// Returns false if cannot.
3135
///
3236
/// See SPV_NV_cooperative_matrix for supported elementwise ops.
3337
static bool createElementwiseOp(ConversionPatternRewriter &builder,
34-
gpu::SubgroupMmaElementwiseOp op,
35-
spirv::CooperativeMatrixNVType coopType,
38+
gpu::SubgroupMmaElementwiseOp op, Type coopType,
3639
ValueRange operands) {
40+
assert((isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
41+
coopType)));
42+
3743
switch (op.getOpType()) {
3844
case gpu::MMAElementwiseOp::ADDF:
3945
builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
@@ -71,6 +77,110 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
7177
return false;
7278
}
7379

80+
//===----------------------------------------------------------------------===//
81+
// SPV_KHR_cooperative_matrix
82+
//===----------------------------------------------------------------------===//
83+
84+
namespace khr {
85+
namespace {
86+
87+
/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
88+
/// dialect.
89+
struct WmmaLoadOpToSPIRVLowering final
90+
: OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
91+
using OpConversionPattern::OpConversionPattern;
92+
93+
LogicalResult
94+
matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
95+
ConversionPatternRewriter &rewriter) const override {
96+
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
97+
Location loc = op->getLoc();
98+
99+
auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
100+
MemRefType memrefType = op.getSrcMemref().getType();
101+
Value bufferPtr =
102+
spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
103+
adaptor.getIndices(), loc, rewriter);
104+
105+
auto coopType =
106+
typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
107+
if (!coopType)
108+
return rewriter.notifyMatchFailure(op, "type conversion failed");
109+
110+
int64_t stride = op.getLeadDimension().getSExtValue();
111+
IntegerType i32Type = rewriter.getI32Type();
112+
auto strideValue = rewriter.create<spirv::ConstantOp>(
113+
loc, i32Type, IntegerAttr::get(i32Type, stride));
114+
115+
bool isColMajor = op.getTranspose().value_or(false);
116+
auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
117+
: spirv::CooperativeMatrixLayoutKHR::RowMajor;
118+
119+
rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
120+
op, coopType, bufferPtr, strideValue, layout);
121+
return success();
122+
}
123+
};
124+
125+
/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
126+
/// dialect.
127+
struct WmmaStoreOpToSPIRVLowering final
128+
: OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
129+
using OpConversionPattern::OpConversionPattern;
130+
131+
LogicalResult
132+
matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
133+
ConversionPatternRewriter &rewriter) const override {
134+
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
135+
Location loc = op->getLoc();
136+
137+
auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
138+
Value bufferPtr =
139+
spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
140+
adaptor.getIndices(), loc, rewriter);
141+
142+
int64_t stride = op.getLeadDimension().getSExtValue();
143+
IntegerType i32Type = rewriter.getI32Type();
144+
auto strideValue = rewriter.create<spirv::ConstantOp>(
145+
loc, i32Type, IntegerAttr::get(i32Type, stride));
146+
147+
bool isColMajor = op.getTranspose().value_or(false);
148+
auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
149+
: spirv::CooperativeMatrixLayoutKHR::RowMajor;
150+
151+
rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
152+
op, bufferPtr, adaptor.getSrc(), strideValue, layout);
153+
return success();
154+
}
155+
};
156+
157+
/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
158+
/// dialect.
159+
struct WmmaMmaOpToSPIRVLowering final
160+
: OpConversionPattern<gpu::SubgroupMmaComputeOp> {
161+
using OpConversionPattern::OpConversionPattern;
162+
163+
LogicalResult
164+
matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
165+
OpAdaptor adaptor,
166+
ConversionPatternRewriter &rewriter) const override {
167+
rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
168+
subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
169+
adaptor.getOpC());
170+
return success();
171+
}
172+
};
173+
174+
} // namespace
175+
} // namespace khr
176+
177+
//===----------------------------------------------------------------------===//
178+
// SPV_NV_cooperative_matrix
179+
//===----------------------------------------------------------------------===//
180+
181+
namespace nv {
182+
namespace {
183+
74184
/// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
75185
/// dialect.
76186
struct WmmaLoadOpToSPIRVLowering final
@@ -247,7 +357,8 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
247357
};
248358

249359
} // namespace
250-
} // namespace mlir::nv
360+
} // namespace nv
361+
} // namespace mlir
251362

252363
mlir::spirv::CooperativeMatrixNVType
253364
mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
@@ -257,6 +368,30 @@ mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
257368
elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
258369
}
259370

371+
mlir::spirv::CooperativeMatrixType
372+
mlir::convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type) {
373+
ArrayRef<int64_t> retTypeShape = type.getShape();
374+
Type elementType = type.getElementType();
375+
376+
auto use =
377+
llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
378+
.Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
379+
.Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
380+
.Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
381+
382+
return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
383+
retTypeShape[1],
384+
spirv::Scope::Subgroup, use);
385+
}
386+
387+
void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
388+
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
389+
using namespace mlir;
390+
MLIRContext *context = patterns.getContext();
391+
patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
392+
khr::WmmaStoreOpToSPIRVLowering>(converter, context);
393+
}
394+
260395
void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
261396
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
262397
using namespace mlir;
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=false" --cse \
2+
// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s
3+
4+
module attributes {
5+
gpu.container_module,
6+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
7+
[Shader, CooperativeMatrixKHR, Float16],
8+
[SPV_KHR_storage_buffer_storage_class, SPV_KHR_cooperative_matrix]>,
9+
#spirv.resource_limits<>>} {
10+
11+
gpu.module @kernels {
12+
// CHECK-LABEL: spirv.func @gpu_wmma_load_op
13+
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
14+
gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
15+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
16+
%i = arith.constant 16 : index
17+
%j = arith.constant 16 : index
18+
// CHECK: %[[STRIDE:.+]] = spirv.Constant 32 : i32
19+
// CHECK: spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <RowMajor> :
20+
// CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
21+
%0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} :
22+
memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
23+
24+
// CHECK: spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <ColumnMajor> :
25+
// CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
26+
%1 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} :
27+
memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
28+
// CHECK: spirv.Return
29+
gpu.return
30+
}
31+
32+
// CHECK-LABEL: spirv.func @gpu_wmma_store_op
33+
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
34+
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
35+
gpu.func @gpu_wmma_store_op(%arg0: memref<32x32xf16, #spirv.storage_class<StorageBuffer>>,
36+
%arg1: !gpu.mma_matrix<16x16xf16, "COp">) kernel
37+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
38+
%i = arith.constant 16 : index
39+
%j = arith.constant 16 : index
40+
// CHECK: %[[STRIDE:.+]] = spirv.Constant 32 : i32
41+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <RowMajor> :
42+
// CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
43+
gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index} :
44+
!gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
45+
46+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <ColumnMajor> :
47+
// CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
48+
gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index, transpose} :
49+
!gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
50+
// CHECK: spirv.Return
51+
gpu.return
52+
}
53+
54+
// CHECK-LABEL: spirv.func @gpu_wmma_mma_op
55+
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
56+
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
57+
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
58+
gpu.func @gpu_wmma_mma_op(%A: !gpu.mma_matrix<16x16xf16, "AOp">,
59+
%B: !gpu.mma_matrix<16x16xf16, "BOp">,
60+
%C: !gpu.mma_matrix<16x16xf16, "COp">,
61+
%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
62+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
63+
// CHECK: %[[MAD:.*]] = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
64+
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>,
65+
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
66+
// CHECK-SAME: -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
67+
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">,
68+
!gpu.mma_matrix<16x16xf16, "BOp">
69+
-> !gpu.mma_matrix<16x16xf16, "COp">
70+
71+
%i = arith.constant 0 : index
72+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.+}}, %[[MAD]], %{{.+}}, <RowMajor>
73+
gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
74+
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
75+
// CHECK: spirv.Return
76+
gpu.return
77+
}
78+
79+
}
80+
}

mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// RUN: mlir-opt --convert-gpu-to-spirv --split-input-file --verify-diagnostics %s | FileCheck %s
1+
// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=true" \
2+
// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s
23

34
module attributes {
45
gpu.container_module,

0 commit comments

Comments
 (0)