diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td index 31b1371605457..3e657da52be5f 100644 --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -174,7 +174,7 @@ def NVGPU_TensorMapDescriptor : NVGPU_Type<"TensorMapDescriptor", "tensormap.des let assemblyFormat = "`<` struct(params) `>`"; } -def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "wgmma.descriptor", []> { +def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "warpgroup.descriptor", []> { let summary = "Warpgroup matrix descriptor type"; let description = [{ The descriptor specifies the properties of the matrix in shared memory that @@ -667,11 +667,12 @@ def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> { let hasVerifier = 1; } -def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> { - let summary = "Generate a wgmma matrix descriptor"; +def NVGPU_GenerateWarpgroupDescriptorOp : NVGPU_Op<"warpgroup.generate.descriptor", []> { + let summary = "Generate a warpgroup matrix descriptor"; let description = [{ - This Op builds a `nvgpu.wgmma.descriptor` that is used by warpgroup-level - matrix multiply and accumulate. + This Op builds a `nvgpu.warpgroup.descriptor` that is used by + `nvgpu.warpgroup.mma` to perform warpgroup-level matrix multiply and + accumulate. The descriptor specifies the properties of the matrix in shared memory that is a multiplicand in the matrix multiply and accumulate operation. @@ -702,9 +703,9 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> { Example: ```mlir - %r1,%r2 = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc1, %acc2: - !nvgpu.wgmma.descriptor>, - !nvgpu.wgmma.descriptor>, + %r1,%r2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2: + !nvgpu.warpgroup.descriptor>, + !nvgpu.warpgroup.descriptor>, !nvgpu.warpgroup.accumulator>, !nvgpu.warpgroup.accumulator> -> diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 4d1f6641af6dc..d4bca1d8c8465 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -967,13 +967,13 @@ struct NVGPUTmaAsyncLoadOpLowering return success(); } }; -struct NVGPUGenerateGmmaDescriptorLowering - : public ConvertOpToLLVMPattern { +struct NVGPUGenerateWarpgroupDescriptorLowering + : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< - nvgpu::GenerateGmmaDescriptorOp>::ConvertOpToLLVMPattern; + nvgpu::GenerateWarpgroupDescriptorOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor, + matchAndRewrite(nvgpu::GenerateWarpgroupDescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); @@ -1037,7 +1037,7 @@ struct NVGPUGenerateGmmaDescriptorLowering // // [0,14) start_address dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); - LLVM_DEBUG(DBGS() << "Generating wgmma.descriptor: " + LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " << "leading_off:" << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t" << "base_offset:" << offsetVal << "\t" @@ -1320,8 +1320,8 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx - NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor - NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma + NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor + NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering, NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering, NVGPUMmaSparseSyncLowering>(converter); diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index dfec179868004..eb8fc4b65bc89 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -367,10 +367,10 @@ LogicalResult TmaCreateDescriptorOp::verify() { } //===----------------------------------------------------------------------===// -// NVGPU_GenerateGmmaDescriptorOp +// NVGPU_GenerateWarpgroupDescriptorOp //===----------------------------------------------------------------------===// -LogicalResult GenerateGmmaDescriptorOp::verify() { +LogicalResult GenerateWarpgroupDescriptorOp::verify() { MemRefType memrefType = getTensor().getType(); MemRefType tensorMapType = getTensorMap().getType().getTensor(); diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index 8c2f8dbbd5ad9..3710b06288e2a 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -674,7 +674,7 @@ module @mymodule { !tensorMap = !nvgpu.tensormap.descriptor, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none> memref.global "private" @dynamicShmem : memref<0xf16,3> // CHECK-LABEL: func @create_wgmma_descriptor( -func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.descriptor>{ +func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.warpgroup.descriptor>{ %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3> %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3> // CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3> @@ -706,22 +706,22 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.desc // CHECK: %[[S25:.+]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[S26:.+]] = llvm.shl %[[S7]], %[[S25]] : i64 // CHECK: %[[S27:.+]] = llvm.or %[[S24]], %[[S26]] : i64 - // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.wgmma.descriptor> + // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.warpgroup.descriptor> // CHECK: return %[[ret]] - %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.wgmma.descriptor> - func.return %descA : !nvgpu.wgmma.descriptor> + %descA = nvgpu.warpgroup.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.warpgroup.descriptor> + func.return %descA : !nvgpu.warpgroup.descriptor> } // CHECK-LABEL: @warpgroup_mma_128_128_64( -// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator>) +// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator>) func.func @warpgroup_mma_128_128_64( - %descA: !nvgpu.wgmma.descriptor>, - %descB: !nvgpu.wgmma.descriptor>, + %descA: !nvgpu.warpgroup.descriptor>, + %descB: !nvgpu.warpgroup.descriptor>, %acc1: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, %acc2: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>) { -// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.wgmma.descriptor> to i64 -// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.wgmma.descriptor> to i64 +// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor> to i64 +// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor> to i64 // CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[arg3]] : !nvgpu.warpgroup.accumulator> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvvm.wgmma.fence.aligned @@ -762,8 +762,8 @@ func.func @warpgroup_mma_128_128_64( // CHECK: nvvm.wgmma.commit.group.sync.aligned // CHECK: nvvm.wgmma.wait.group.sync.aligned 1 %wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB}: - !nvgpu.wgmma.descriptor>, - !nvgpu.wgmma.descriptor>, + !nvgpu.warpgroup.descriptor>, + !nvgpu.warpgroup.descriptor>, !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>> -> diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir index ff391e469815d..66652070ec15f 100644 --- a/mlir/test/Dialect/NVGPU/invalid.mlir +++ b/mlir/test/Dialect/NVGPU/invalid.mlir @@ -225,8 +225,8 @@ func.func @async_cp_size_invalid_f64( // ----- !tResult = !nvgpu.warpgroup.accumulator> -!tDescA = !nvgpu.wgmma.descriptor> -!tDescB = !nvgpu.wgmma.descriptor> +!tDescA = !nvgpu.warpgroup.descriptor> +!tDescB = !nvgpu.warpgroup.descriptor> func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) { // expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}} @@ -237,8 +237,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t // ----- !tResult = !nvgpu.warpgroup.accumulator> -!tDescA = !nvgpu.wgmma.descriptor> -!tDescB = !nvgpu.wgmma.descriptor> +!tDescA = !nvgpu.warpgroup.descriptor> +!tDescB = !nvgpu.warpgroup.descriptor> func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) { // expected-error @+1 {{'nvgpu.warpgroup.mma' op has matrices A, B, C and D, they must be 2 dimensional}} %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult @@ -247,8 +247,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t // ----- !tResult = !nvgpu.warpgroup.accumulator> -!tDescA = !nvgpu.wgmma.descriptor> -!tDescB = !nvgpu.wgmma.descriptor> +!tDescA = !nvgpu.warpgroup.descriptor> +!tDescB = !nvgpu.warpgroup.descriptor> func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) { // expected-error @+1 {{'nvgpu.warpgroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}} %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult @@ -258,8 +258,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t // ----- !tResult = !nvgpu.warpgroup.accumulator> -!tDescA = !nvgpu.wgmma.descriptor> -!tDescB = !nvgpu.wgmma.descriptor> +!tDescA = !nvgpu.warpgroup.descriptor> +!tDescB = !nvgpu.warpgroup.descriptor> func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) { // expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 512 ) != 2nd dim matrix-C ( 128 )}} %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult