-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[MLIR][NVGPU] Change name wgmma.descriptor
to warpgroup.descriptor
(NFC)
#67526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…` (NFC) NVGPU dialect is gaining large support for warpgroup level operations, and their names always starts with `warpgroup....`. This PR changes name of Op and type from `wgmma.descriptor` to `warpgroup.descriptor` for sake of consistency.
@llvm/pr-subscribers-mlir-nvgpu @llvm/pr-subscribers-mlir ChangesNVGPU dialect is gaining large support for warpgroup level operations, and their names always starts with This PR changes name of Op and type from Full diff: https://github.com/llvm/llvm-project/pull/67526.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 31b137160545772..3e657da52be5f72 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -174,7 +174,7 @@ def NVGPU_TensorMapDescriptor : NVGPU_Type<"TensorMapDescriptor", "tensormap.des
let assemblyFormat = "`<` struct(params) `>`";
}
-def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "wgmma.descriptor", []> {
+def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "warpgroup.descriptor", []> {
let summary = "Warpgroup matrix descriptor type";
let description = [{
The descriptor specifies the properties of the matrix in shared memory that
@@ -667,11 +667,12 @@ def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
let hasVerifier = 1;
}
-def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
- let summary = "Generate a wgmma matrix descriptor";
+def NVGPU_GenerateWarpgroupDescriptorOp : NVGPU_Op<"warpgroup.generate.descriptor", []> {
+ let summary = "Generate a warpgroup matrix descriptor";
let description = [{
- This Op builds a `nvgpu.wgmma.descriptor` that is used by warpgroup-level
- matrix multiply and accumulate.
+ This Op builds a `nvgpu.warpgroup.descriptor` that is used by
+ `nvgpu.warpgroup.mma` to perform warpgroup-level matrix multiply and
+ accumulate.
The descriptor specifies the properties of the matrix in shared memory that
is a multiplicand in the matrix multiply and accumulate operation.
@@ -702,9 +703,9 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
Example:
```mlir
- %r1,%r2 = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc1, %acc2:
- !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
- !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+ %r1,%r2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2:
+ !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
+ !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
->
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4d1f6641af6dca3..d4bca1d8c846576 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -967,13 +967,13 @@ struct NVGPUTmaAsyncLoadOpLowering
return success();
}
};
-struct NVGPUGenerateGmmaDescriptorLowering
- : public ConvertOpToLLVMPattern<nvgpu::GenerateGmmaDescriptorOp> {
+struct NVGPUGenerateWarpgroupDescriptorLowering
+ : public ConvertOpToLLVMPattern<nvgpu::GenerateWarpgroupDescriptorOp> {
using ConvertOpToLLVMPattern<
- nvgpu::GenerateGmmaDescriptorOp>::ConvertOpToLLVMPattern;
+ nvgpu::GenerateWarpgroupDescriptorOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor,
+ matchAndRewrite(nvgpu::GenerateWarpgroupDescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
@@ -1037,7 +1037,7 @@ struct NVGPUGenerateGmmaDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LLVM_DEBUG(DBGS() << "Generating wgmma.descriptor: "
+ LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
<< "leading_off:" << leadDimVal << "\t"
<< "stride_off :" << strideDimVal << "\t"
<< "base_offset:" << offsetVal << "\t"
@@ -1320,8 +1320,8 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
- NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor
- NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
+ NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
+ NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index dfec17986800417..eb8fc4b65bc89ad 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -367,10 +367,10 @@ LogicalResult TmaCreateDescriptorOp::verify() {
}
//===----------------------------------------------------------------------===//
-// NVGPU_GenerateGmmaDescriptorOp
+// NVGPU_GenerateWarpgroupDescriptorOp
//===----------------------------------------------------------------------===//
-LogicalResult GenerateGmmaDescriptorOp::verify() {
+LogicalResult GenerateWarpgroupDescriptorOp::verify() {
MemRefType memrefType = getTensor().getType();
MemRefType tensorMapType = getTensorMap().getType().getTensor();
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 8c2f8dbbd5ad9a3..3710b06288e2a7f 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -674,7 +674,7 @@ module @mymodule {
!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16,3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
memref.global "private" @dynamicShmem : memref<0xf16,3>
// CHECK-LABEL: func @create_wgmma_descriptor(
-func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>{
+func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>{
%dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
%lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
// CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
@@ -706,22 +706,22 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.desc
// CHECK: %[[S25:.+]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[S26:.+]] = llvm.shl %[[S7]], %[[S25]] : i64
// CHECK: %[[S27:.+]] = llvm.or %[[S24]], %[[S26]] : i64
- // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+ // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
// CHECK: return %[[ret]]
- %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
- func.return %descA : !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
+ %descA = nvgpu.warpgroup.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>
+ func.return %descA : !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>
}
// CHECK-LABEL: @warpgroup_mma_128_128_64(
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
func.func @warpgroup_mma_128_128_64(
- %descA: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
- %descB: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+ %descA: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
+ %descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
%acc1: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
%acc2: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>)
{
-// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>> to i64
-// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[arg3]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvvm.wgmma.fence.aligned
@@ -762,8 +762,8 @@ func.func @warpgroup_mma_128_128_64(
// CHECK: nvvm.wgmma.commit.group.sync.aligned
// CHECK: nvvm.wgmma.wait.group.sync.aligned 1
%wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB}:
- !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
- !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+ !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
+ !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
!nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
->
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index ff391e469815d74..66652070ec15f34 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -225,8 +225,8 @@ func.func @async_cp_size_invalid_f64(
// -----
!tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
-!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x121xf16, 3>>
+!tDescA = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.warpgroup.descriptor<tensor = memref<64x121xf16, 3>>
func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
// expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}}
@@ -237,8 +237,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t
// -----
!tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<128xf32>>
-!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>
+!tDescA = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>
func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
// expected-error @+1 {{'nvgpu.warpgroup.mma' op has matrices A, B, C and D, they must be 2 dimensional}}
%0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
@@ -247,8 +247,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t
// -----
!tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
-!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf32, 3>>
+!tDescA = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf32, 3>>
func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
// expected-error @+1 {{'nvgpu.warpgroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}}
%0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
@@ -258,8 +258,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t
// -----
!tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
-!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x512xf16, 3>>
+!tDescA = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.warpgroup.descriptor<tensor = memref<64x512xf16, 3>>
func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
// expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 512 ) != 2nd dim matrix-C ( 128 )}}
%0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
|
NVGPU dialect is gaining large support for warpgroup level operations, and their names always starts with
warpgroup....
.This PR changes name of Op and type from
wgmma.descriptor
towarpgroup.descriptor
for sake of consistency.