
[MLIR][NVGPU] Change name `wgmma.descriptor` to `warpgroup.descriptor` (NFC) #67526


Merged
1 commit merged into llvm:main on Oct 5, 2023

Conversation


@grypp grypp commented Sep 27, 2023

The NVGPU dialect is gaining broad support for warpgroup-level operations, and their names all start with `warpgroup...`.

This PR renames the op and type from `wgmma.descriptor` to `warpgroup.descriptor` for the sake of consistency.
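For illustration, this is how the renamed op and type read after the change; the snippet is adapted from the updated `nvgpu-to-nvvm.mlir` test in this PR, with `%lhsShmem` and `%tensorMap` standing in for values produced earlier in the function:

```mlir
// Build a warpgroup matrix descriptor from a shared-memory tile and a TMA
// descriptor; both the op and the result type now use the `warpgroup` prefix.
%descA = nvgpu.warpgroup.generate.descriptor %lhsShmem, %tensorMap
    : memref<128x64xf16, 3>,
      !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
    -> !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
```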

llvmbot commented Sep 27, 2023

@llvm/pr-subscribers-mlir-nvgpu
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Changes

The NVGPU dialect is gaining broad support for warpgroup-level operations, and their names all start with `warpgroup...`.

This PR renames the op and type from `wgmma.descriptor` to `warpgroup.descriptor` for the sake of consistency.


Full diff: https://github.com/llvm/llvm-project/pull/67526.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+9-8)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+7-7)
  • (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+2-2)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+11-11)
  • (modified) mlir/test/Dialect/NVGPU/invalid.mlir (+8-8)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 31b137160545772..3e657da52be5f72 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -174,7 +174,7 @@ def NVGPU_TensorMapDescriptor : NVGPU_Type<"TensorMapDescriptor", "tensormap.des
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
-def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "wgmma.descriptor", []> {
+def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "warpgroup.descriptor", []> {
   let summary = "Warpgroup matrix descriptor type";
   let description = [{
   The descriptor specifies the properties of the matrix in shared memory that 
@@ -667,11 +667,12 @@ def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
   let hasVerifier = 1;
 }
 
-def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
-  let summary = "Generate a wgmma matrix descriptor";
+def NVGPU_GenerateWarpgroupDescriptorOp : NVGPU_Op<"warpgroup.generate.descriptor", []> {
+  let summary = "Generate a warpgroup matrix descriptor";
   let description = [{
-  This Op builds a `nvgpu.wgmma.descriptor` that is used by warpgroup-level 
-  matrix multiply and accumulate.
+  This Op builds a `nvgpu.warpgroup.descriptor` that is used by 
+  `nvgpu.warpgroup.mma` to perform warpgroup-level matrix multiply and 
+  accumulate.
 
   The descriptor specifies the properties of the matrix in shared memory that 
   is a multiplicand in the matrix multiply and accumulate operation. 
@@ -702,9 +703,9 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
 
     Example:
     ```mlir
-      %r1,%r2 = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc1, %acc2: 
-                 !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, 
-                 !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, 
+      %r1,%r2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2: 
+                 !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
+                 !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
                  !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
                  !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
                  -> 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4d1f6641af6dca3..d4bca1d8c846576 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -967,13 +967,13 @@ struct NVGPUTmaAsyncLoadOpLowering
     return success();
   }
 };
-struct NVGPUGenerateGmmaDescriptorLowering
-    : public ConvertOpToLLVMPattern<nvgpu::GenerateGmmaDescriptorOp> {
+struct NVGPUGenerateWarpgroupDescriptorLowering
+    : public ConvertOpToLLVMPattern<nvgpu::GenerateWarpgroupDescriptorOp> {
   using ConvertOpToLLVMPattern<
-      nvgpu::GenerateGmmaDescriptorOp>::ConvertOpToLLVMPattern;
+      nvgpu::GenerateWarpgroupDescriptorOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor,
+  matchAndRewrite(nvgpu::GenerateWarpgroupDescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     Location loc = op->getLoc();
@@ -1037,7 +1037,7 @@ struct NVGPUGenerateGmmaDescriptorLowering
     // // [0,14)   start_address
     dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
 
-    LLVM_DEBUG(DBGS() << "Generating wgmma.descriptor: "
+    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
                       << "leading_off:" << leadDimVal << "\t"
                       << "stride_off :" << strideDimVal << "\t"
                       << "base_offset:" << offsetVal << "\t"
@@ -1320,8 +1320,8 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUTmaAsyncLoadOpLowering,           // nvgpu.tma.async.load
       NVGPUTmaCreateDescriptorOpLowering,    // nvgpu.tma.create.descriptor
       NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
-      NVGPUGenerateGmmaDescriptorLowering,   // nvgpu.wgmma.generate.descriptor
-      NVGPUWarpgroupMmaOpLowering,           // nvgpu.warpgroup.mma
+      NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
+      NVGPUWarpgroupMmaOpLowering,              // nvgpu.warpgroup.mma
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index dfec17986800417..eb8fc4b65bc89ad 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -367,10 +367,10 @@ LogicalResult TmaCreateDescriptorOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// NVGPU_GenerateGmmaDescriptorOp
+// NVGPU_GenerateWarpgroupDescriptorOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult GenerateGmmaDescriptorOp::verify() {
+LogicalResult GenerateWarpgroupDescriptorOp::verify() {
   MemRefType memrefType = getTensor().getType();
   MemRefType tensorMapType = getTensorMap().getType().getTensor();
 
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 8c2f8dbbd5ad9a3..3710b06288e2a7f 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -674,7 +674,7 @@ module @mymodule {
 !tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16,3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
 memref.global "private" @dynamicShmem : memref<0xf16,3>
 // CHECK-LABEL: func @create_wgmma_descriptor(
-func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>{
+func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>{
   %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
   %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
     // CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
@@ -706,22 +706,22 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.desc
     // CHECK: %[[S25:.+]] = llvm.mlir.constant(0 : i64) : i64
     // CHECK: %[[S26:.+]] = llvm.shl %[[S7]], %[[S25]]  : i64
     // CHECK: %[[S27:.+]] = llvm.or %[[S24]], %[[S26]]  : i64
-    // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>> 
+    // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> 
     // CHECK: return %[[ret]]
-  %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
-  func.return %descA : !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
+  %descA = nvgpu.warpgroup.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>
+  func.return %descA : !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>
 }
 
 // CHECK-LABEL: @warpgroup_mma_128_128_64(  
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
 func.func @warpgroup_mma_128_128_64(
-      %descA: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, 
-      %descB: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, 
+      %descA: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
+      %descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
       %acc1: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
       %acc2: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>) 
 {
-// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>> to i64
-// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
 // CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
 // CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[arg3]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
 // CHECK: nvvm.wgmma.fence.aligned
@@ -762,8 +762,8 @@ func.func @warpgroup_mma_128_128_64(
 // CHECK: nvvm.wgmma.commit.group.sync.aligned
 // CHECK: nvvm.wgmma.wait.group.sync.aligned 1  
   %wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB}: 
-      !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, 
-      !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, 
+      !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
+      !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
       !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
       !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>> 
       -> 
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index ff391e469815d74..66652070ec15f34 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -225,8 +225,8 @@ func.func @async_cp_size_invalid_f64(
 // -----
 
 !tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
-!tDescA  = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB  = !nvgpu.wgmma.descriptor<tensor = memref<64x121xf16, 3>>
+!tDescA  = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB  = !nvgpu.warpgroup.descriptor<tensor = memref<64x121xf16, 3>>
 
 func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
   // expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}}  
@@ -237,8 +237,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t
 // -----
 
 !tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<128xf32>>
-!tDescA  = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB  = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>
+!tDescA  = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB  = !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>
 func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
   // expected-error @+1 {{'nvgpu.warpgroup.mma' op has matrices A, B, C and D, they must be 2 dimensional}}  
   %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
@@ -247,8 +247,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t
 
 // -----
 !tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
-!tDescA  = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB  = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf32, 3>>
+!tDescA  = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB  = !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf32, 3>>
 func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
   // expected-error @+1 {{'nvgpu.warpgroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}}  
   %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
@@ -258,8 +258,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t
 // -----
 
 !tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
-!tDescA  = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB  = !nvgpu.wgmma.descriptor<tensor = memref<64x512xf16, 3>>
+!tDescA  = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB  = !nvgpu.warpgroup.descriptor<tensor = memref<64x512xf16, 3>>
 func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
   // expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 512 ) != 2nd dim matrix-C ( 128 )}}
   %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult

@grypp grypp merged commit 6dc7717 into llvm:main Oct 5, 2023
@grypp grypp deleted the changename branch October 5, 2023 07:01