Skip to content

[mlir][AMDGPU] Add int4 intrinsics, mixed-type fp8 to handle gfx12 #128963

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 27, 2025

Conversation

krzysz00
Copy link
Contributor

  1. Extend the gfx12 FP8 support to allow mixed-type intrinsics (since they've been added), creating limited mixed-type support that mirrors MFMA
  2. Extend the amdgpu.wmma intrinsic lowering to correctly handle shorter vectors because gfx12 now has instructions that logically take a 4xi8, or, as far as LLVM's concerned, an i32. Similarly, there are 4xi4 inputs, which are an i16 (that must be zero-extended to i32).
  3. Correctly handle the ambiguities in the int4 intrinsics on gfx12, which can either be 16x16x16 or 16x16x32
  4. Add tests showing all WMMAs being lowered the way gfx12 expects (mirroring LLVM's tests)
  5. Add a verifier to prevent emiting ilegal instructions on gfx12.

1. Extend the gfx12 FP8 support to allow mixed-type intrinsics (since
they've been added), creating limited mixed-type support that mirrors
MFMA
2. Extend the `amdgpu.wmma` intrinsic lowering to correctly handle
shorter vectors because gfx12 now has instructions that logically take
a 4xi8, or, as far as LLVM's concerned, an i32. Similarly, there are
4xi4 inputs, which are an i16 (that must be zero-extended to i32).
3. Correctly handle the ambiguities in the int4 intrinsics on gfx12,
which can either be 16x16x16 or 16x16x32
4. Add tests showing all WMMAs being lowered the way gfx12
expects (mirroring LLVM's tests)
5. Add a verifier to prevent emiting ilegal instructions on gfx12.
@llvmbot
Copy link
Member

llvmbot commented Feb 26, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-gpu

Author: Krzysztof Drewniak (krzysz00)

Changes
  1. Extend the gfx12 FP8 support to allow mixed-type intrinsics (since they've been added), creating limited mixed-type support that mirrors MFMA
  2. Extend the amdgpu.wmma intrinsic lowering to correctly handle shorter vectors because gfx12 now has instructions that logically take a 4xi8, or, as far as LLVM's concerned, an i32. Similarly, there are 4xi4 inputs, which are an i16 (that must be zero-extended to i32).
  3. Correctly handle the ambiguities in the int4 intrinsics on gfx12, which can either be 16x16x16 or 16x16x32
  4. Add tests showing all WMMAs being lowered the way gfx12 expects (mirroring LLVM's tests)
  5. Add a verifier to prevent emiting ilegal instructions on gfx12.

Patch is 25.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128963.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+16-8)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+6-3)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+60-16)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+16)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir (+64-5)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir (+12-6)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+11-1)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index f795dd89b79a1..2cb60b3836416 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -552,9 +552,14 @@ def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
                               VectorOfLengthAndType<[4], [F64]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F16, BF16, I8, SI8, UI8, F8E4M3FN, F8E5M2]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
+                             [4, 8, 16],
+                             [F16, BF16,
+                              I8, SI8, UI8,
+                              I<4>, SI<4>, UI<4>,
+                              F8E4M3FN, F8E5M2]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
-                              VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
+                              VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
 
 def AMDGPU_MFMAOp :
     AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
@@ -615,8 +620,7 @@ def AMDGPU_MFMAOp :
 
 def AMDGPU_WMMAOp :
     AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
-                       AllTypesMatch<["sourceA", "sourceB"]>,
-                        Pure]>,
+                       Pure]>,
     Arguments<(ins
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
@@ -629,13 +633,17 @@ def AMDGPU_WMMAOp :
   let summary = "MLIR wrapper for RDNA3 wmma instructions";
   let description = [{
     The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
-    for various `wmma` instructions in the RDNA3 architecture, which perform
-    a 16x16 matrix multiplication for different data types.
+    for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
+    perform a 16x16 * 16x16 matrix multiplication for different data types.
+    Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
+    integer inputs.
 
-    When emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16 (or 16xbf16) vector
-    containing only 8 valid values:
+    On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
+    (or 16xbf16) vector containing only 8 valid values:
       - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
       - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
+    On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
+    all values are valid and the `subwordOffset` must be `0`, as it cannot be used.
 
     `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 673ea480ad3fa..18fec95f700c4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -410,8 +410,11 @@ def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16",
 def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
 def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
 // Available from gfx12
-def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
-def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
+def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
+def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_bf8", [1]>;
+def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
+def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
+def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
 
 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics (available in GFX950)
@@ -771,7 +774,7 @@ def ROCDL_CvtScaleF32Bf8Op :
     Arguments<(ins I32:$src, F32: $scale, I32:$byteSel)> {
   let summary = "Scale and convert bf8 to f32";
   let description = [{
-    Scale `src` by the exponent in `scale` then convert 8-bit bf8 value 
+    Scale `src` by the exponent in `scale` then convert 8-bit bf8 value
     from the `byteSel`th bit of `src` to fp32.
   }];
   let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b8574bbbee345..4718578703a15 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -403,8 +403,11 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
 /// Push an input operand. If it is a float type, nothing to do. If it is
 /// an integer type, then we need to also push its signdness (1 for signed, 0
 /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
-/// vector. We also need to convert bfloat inputs to i16 to account for the lack
-/// of bfloat support in the WMMA intrinsics themselves.
+/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
+/// We also need to convert bfloat inputs to i16 to account for the bfloat
+/// intrinsics having been defined before the AMD backend supported bfloat. We
+/// similarly need to pack 8-bit float types into integers as if they were i8
+/// (which they are for the backend's purposes).
 static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
@@ -413,12 +416,16 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                  SmallVector<Value, 4> &operands) {
   Type inputType = llvmInput.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
+  if (!vectorType) {
+    operands.push_back(llvmInput);
+    return;
+  }
   Type elemType = vectorType.getElementType();
 
   if (elemType.isBF16())
     llvmInput = rewriter.create<LLVM::BitcastOp>(
         loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
-  if (!elemType.isInteger(8)) {
+  if (elemType.getIntOrFloatBitWidth() > 8) {
     operands.push_back(llvmInput);
     return;
   }
@@ -427,25 +434,33 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
   // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
   // fp8/int8 information is lost during the conversion process.
   auto mlirInputType = cast<VectorType>(mlirInput.getType());
-  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
-  if (isInputInt8) {
+  bool isInputInteger = mlirInputType.getElementType().isInteger();
+  if (isInputInteger) {
     // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
     bool localIsUnsigned = isUnsigned;
-    if (elemType.isUnsignedInteger(8)) {
+    if (elemType.isUnsignedInteger()) {
       localIsUnsigned = true;
-    } else if (elemType.isSignedInteger(8)) {
+    } else if (elemType.isSignedInteger()) {
       localIsUnsigned = false;
     }
     Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
     operands.push_back(sign);
   }
 
-  int64_t numBytes = vectorType.getNumElements();
+  int64_t numBits =
+      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
   Type i32 = rewriter.getI32Type();
-  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
-  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
+  Type intrinsicInType = numBits <= 32
+                             ? (Type)rewriter.getIntegerType(numBits)
+                             : (Type)VectorType::get(numBits / 32, i32);
+  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
   Value result = rewriter.createOrFold<LLVM::BitcastOp>(
-      loc, llvmVectorType32bits, llvmInput);
+      loc, llvmIntrinsicInType, llvmInput);
+  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
+  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
+  // Add in the zeros here.
+  if (numBits < 32)
+    result = rewriter.create<LLVM::ZExtOp>(loc, i32, result);
   operands.push_back(result);
 }
 
@@ -454,7 +469,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
 /// since the same numbers of VGPRs is used, we need to decide if to store the
 /// result in the upper 16 bits of the VGPRs or in the lower part. To store the
 /// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
-/// be stored it in the upper part
+/// be stored it in the upper part. The subwordOffset must not be set for gfx12,
+/// as the instructions have been changed to return fewer registers instead.
 static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                   Location loc,
                                   const TypeConverter *typeConverter,
@@ -617,8 +633,10 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                   Chipset chipset) {
   auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
+  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
   auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
   auto elemSourceType = sourceVectorType.getElementType();
+  auto elemBSourceType = sourceBVectorType.getElementType();
   auto elemDestType = destVectorType.getElementType();
 
   if (elemSourceType.isF16() && elemDestType.isF32())
@@ -631,10 +649,33 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
   if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
     return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
-  if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
+  if (chipset.majorVersion == 11) {
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+  }
+  if (chipset.majorVersion >= 12) {
+    if (isa<Float8E4M3FNType>(elemSourceType) &&
+        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
+    if (isa<Float8E4M3FNType>(elemSourceType) &&
+        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
+    if (isa<Float8E5M2Type>(elemSourceType) &&
+        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
+    if (isa<Float8E5M2Type>(elemSourceType) &&
+        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
+      bool isWave64 = destVectorType.getNumElements() == 4;
+      // This is the ambiguous case. 8 inputs to the wave64 version means that
+      // we want the 16x16x32 version, but for wave32 they mean the short form.
+      bool has8Inputs = sourceVectorType.getNumElements() == 8;
+      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
+        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+    }
+  }
   return std::nullopt;
 }
 
@@ -712,6 +753,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     if (!maybeIntrinsic.has_value())
       return op.emitOpError("no intrinsic matching WMMA on the given chipset");
 
+    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
+      return op.emitOpError("subwordOffset not supported on gfx12+");
+
     OperationState loweredOp(loc, *maybeIntrinsic);
     loweredOp.addTypes(rawOutType);
 
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 271ca382e2f0b..4641fbb280bcb 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -226,14 +226,23 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 LogicalResult WMMAOp::verify() {
   Type sourceAType = getSourceA().getType();
+  Type sourceBType = getSourceB().getType();
   Type destType = getDestC().getType();
 
   VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
+  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
   VectorType destVectorType = dyn_cast<VectorType>(destType);
 
   Type sourceAElemType = sourceVectorAType.getElementType();
+  Type sourceBElemType = sourceVectorBType.getElementType();
   Type destElemType = destVectorType.getElementType();
 
+  if (sourceVectorAType.getNumElements() !=
+      sourceVectorBType.getNumElements()) {
+    return emitOpError("source vectors have different lengths: ")
+           << sourceVectorAType << " vs. " << sourceVectorBType;
+  }
+
   bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
   bool isSrcFloat =
       isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
@@ -247,6 +256,13 @@ LogicalResult WMMAOp::verify() {
     return emitOpError("Expected int sources with int destination");
   }
 
+  if (sourceAElemType != sourceBElemType &&
+      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
+        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+    return emitOpError(
+               "source element types much match (except for fp8) but have ")
+           << sourceAType << " and " << sourceBType;
+  }
   return success();
 }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 7b2b524d4af42..94a1b78d5f040 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -1,9 +1,68 @@
 // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
-func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>,  %arg2 : vector<8xf32>) {
-  // CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+// CHECK-LABEL: @wmma_to_rocdl
+func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
+                         %arg2 : vector<8xf32>, %arg3 : vector<4xf32>,
+                         %arg4 : vector<8xbf16>, %arg5 : vector<4xbf16>,
+                         %arg6 : vector<8xf8E4M3FN>, %arg7 : vector<4xf8E4M3FN>,
+                         %arg8 : vector<8xf8E5M2>, %arg9 : vector<4xf8E5M2>,
+                         %arg10 : vector<8xi8>, %arg11 : vector<4xi8>,
+                         %arg12 : vector<8xi32>, %arg13 : vector<4xi32>,
+                         %arg14 : vector<16xi4>, %arg15 : vector<8xi4>, %arg16 : vector<4xi4>) {
+  // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg1 * %arg1 + %arg3 : vector<4xf16>, vector<4xf16>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg4 * %arg4 + %arg2 : vector<8xbf16>, vector<8xbf16>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16>
+  amdgpu.wmma %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>, i1) -> vector<4xf16>
+  amdgpu.wmma %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16>
+
+  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
+  amdgpu.wmma %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16>
+  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>, i1) -> vector<4xi16>
+  amdgpu.wmma %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg6 * %arg6 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg7 * %arg7 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg6 * %arg8 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg7 * %arg9 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg8 * %arg8 + %arg2 : vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg9 * %arg9 + %arg3 : vector<4xf8E5M2>, vector<4xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg8 * %arg6 + %arg2 : vector<8xf8E5M2>, vector<8xf8E4M3FN>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32>
+
+  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
+
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32>
 
-  // CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
   func.return
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
index 7b1...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Feb 26, 2025

@llvm/pr-subscribers-mlir-amdgpu

Author: Krzysztof Drewniak (krzysz00)

Changes
  1. Extend the gfx12 FP8 support to allow mixed-type intrinsics (since they've been added), creating limited mixed-type support that mirrors MFMA
  2. Extend the amdgpu.wmma intrinsic lowering to correctly handle shorter vectors because gfx12 now has instructions that logically take a 4xi8, or, as far as LLVM's concerned, an i32. Similarly, there are 4xi4 inputs, which are an i16 (that must be zero-extended to i32).
  3. Correctly handle the ambiguities in the int4 intrinsics on gfx12, which can either be 16x16x16 or 16x16x32
  4. Add tests showing all WMMAs being lowered the way gfx12 expects (mirroring LLVM's tests)
  5. Add a verifier to prevent emiting ilegal instructions on gfx12.

Patch is 25.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128963.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+16-8)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+6-3)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+60-16)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+16)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir (+64-5)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir (+12-6)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+11-1)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index f795dd89b79a1..2cb60b3836416 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -552,9 +552,14 @@ def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
                               VectorOfLengthAndType<[4], [F64]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F16, BF16, I8, SI8, UI8, F8E4M3FN, F8E5M2]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
+                             [4, 8, 16],
+                             [F16, BF16,
+                              I8, SI8, UI8,
+                              I<4>, SI<4>, UI<4>,
+                              F8E4M3FN, F8E5M2]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
-                              VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
+                              VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
 
 def AMDGPU_MFMAOp :
     AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
@@ -615,8 +620,7 @@ def AMDGPU_MFMAOp :
 
 def AMDGPU_WMMAOp :
     AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
-                       AllTypesMatch<["sourceA", "sourceB"]>,
-                        Pure]>,
+                       Pure]>,
     Arguments<(ins
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
@@ -629,13 +633,17 @@ def AMDGPU_WMMAOp :
   let summary = "MLIR wrapper for RDNA3 wmma instructions";
   let description = [{
     The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
-    for various `wmma` instructions in the RDNA3 architecture, which perform
-    a 16x16 matrix multiplication for different data types.
+    for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
+    perform a 16x16 * 16x16 matrix multiplication for different data types.
+    Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
+    integer inputs.
 
-    When emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16 (or 16xbf16) vector
-    containing only 8 valid values:
+    On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
+    (or 16xbf16) vector containing only 8 valid values:
       - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
       - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
+    On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
+    all values are valid and the `subwordOffset` must be `0`, as it cannot be used.
 
     `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 673ea480ad3fa..18fec95f700c4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -410,8 +410,11 @@ def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16",
 def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
 def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
 // Available from gfx12
-def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
-def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
+def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
+def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_bf8", [1]>;
+def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
+def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
+def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
 
 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics (available in GFX950)
@@ -771,7 +774,7 @@ def ROCDL_CvtScaleF32Bf8Op :
     Arguments<(ins I32:$src, F32: $scale, I32:$byteSel)> {
   let summary = "Scale and convert bf8 to f32";
   let description = [{
-    Scale `src` by the exponent in `scale` then convert 8-bit bf8 value 
+    Scale `src` by the exponent in `scale` then convert 8-bit bf8 value
     from the `byteSel`th bit of `src` to fp32.
   }];
   let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b8574bbbee345..4718578703a15 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -403,8 +403,11 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
 /// Push an input operand. If it is a float type, nothing to do. If it is
 /// an integer type, then we need to also push its signdness (1 for signed, 0
 /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
-/// vector. We also need to convert bfloat inputs to i16 to account for the lack
-/// of bfloat support in the WMMA intrinsics themselves.
+/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
+/// We also need to convert bfloat inputs to i16 to account for the bfloat
+/// intrinsics having been defined before the AMD backend supported bfloat. We
+/// similarly need to pack 8-bit float types into integers as if they were i8
+/// (which they are for the backend's purposes).
 static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
@@ -413,12 +416,16 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                  SmallVector<Value, 4> &operands) {
   Type inputType = llvmInput.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
+  if (!vectorType) {
+    operands.push_back(llvmInput);
+    return;
+  }
   Type elemType = vectorType.getElementType();
 
   if (elemType.isBF16())
     llvmInput = rewriter.create<LLVM::BitcastOp>(
         loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
-  if (!elemType.isInteger(8)) {
+  if (elemType.getIntOrFloatBitWidth() > 8) {
     operands.push_back(llvmInput);
     return;
   }
@@ -427,25 +434,33 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
   // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
   // fp8/int8 information is lost during the conversion process.
   auto mlirInputType = cast<VectorType>(mlirInput.getType());
-  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
-  if (isInputInt8) {
+  bool isInputInteger = mlirInputType.getElementType().isInteger();
+  if (isInputInteger) {
     // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
     bool localIsUnsigned = isUnsigned;
-    if (elemType.isUnsignedInteger(8)) {
+    if (elemType.isUnsignedInteger()) {
       localIsUnsigned = true;
-    } else if (elemType.isSignedInteger(8)) {
+    } else if (elemType.isSignedInteger()) {
       localIsUnsigned = false;
     }
     Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
     operands.push_back(sign);
   }
 
-  int64_t numBytes = vectorType.getNumElements();
+  int64_t numBits =
+      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
   Type i32 = rewriter.getI32Type();
-  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
-  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
+  Type intrinsicInType = numBits <= 32
+                             ? (Type)rewriter.getIntegerType(numBits)
+                             : (Type)VectorType::get(numBits / 32, i32);
+  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
   Value result = rewriter.createOrFold<LLVM::BitcastOp>(
-      loc, llvmVectorType32bits, llvmInput);
+      loc, llvmIntrinsicInType, llvmInput);
+  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
+  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
+  // Add in the zeros here.
+  if (numBits < 32)
+    result = rewriter.create<LLVM::ZExtOp>(loc, i32, result);
   operands.push_back(result);
 }
 
@@ -454,7 +469,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
 /// since the same numbers of VGPRs is used, we need to decide if to store the
 /// result in the upper 16 bits of the VGPRs or in the lower part. To store the
 /// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
-/// be stored it in the upper part
+/// be stored it in the upper part. The subwordOffset must not be set for gfx12,
+/// as the instructions have been changed to return fewer registers instead.
 static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                   Location loc,
                                   const TypeConverter *typeConverter,
@@ -617,8 +633,10 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                   Chipset chipset) {
   auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
+  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
   auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
   auto elemSourceType = sourceVectorType.getElementType();
+  auto elemBSourceType = sourceBVectorType.getElementType();
   auto elemDestType = destVectorType.getElementType();
 
   if (elemSourceType.isF16() && elemDestType.isF32())
@@ -631,10 +649,33 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
   if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
     return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
-  if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
+  if (chipset.majorVersion == 11) {
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+  }
+  if (chipset.majorVersion >= 12) {
+    if (isa<Float8E4M3FNType>(elemSourceType) &&
+        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
+    if (isa<Float8E4M3FNType>(elemSourceType) &&
+        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
+    if (isa<Float8E5M2Type>(elemSourceType) &&
+        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
+    if (isa<Float8E5M2Type>(elemSourceType) &&
+        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
+      bool isWave64 = destVectorType.getNumElements() == 4;
+      // This is the ambiguous case. 8 inputs to the wave64 version means that
+      // we want the 16x16x32 version, but for wave32 they mean the short form.
+      bool has8Inputs = sourceVectorType.getNumElements() == 8;
+      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
+        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+    }
+  }
   return std::nullopt;
 }
 
@@ -712,6 +753,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     if (!maybeIntrinsic.has_value())
       return op.emitOpError("no intrinsic matching WMMA on the given chipset");
 
+    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
+      return op.emitOpError("subwordOffset not supported on gfx12+");
+
     OperationState loweredOp(loc, *maybeIntrinsic);
     loweredOp.addTypes(rawOutType);
 
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 271ca382e2f0b..4641fbb280bcb 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -226,14 +226,23 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 LogicalResult WMMAOp::verify() {
   Type sourceAType = getSourceA().getType();
+  Type sourceBType = getSourceB().getType();
   Type destType = getDestC().getType();
 
   VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
+  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
   VectorType destVectorType = dyn_cast<VectorType>(destType);
 
   Type sourceAElemType = sourceVectorAType.getElementType();
+  Type sourceBElemType = sourceVectorBType.getElementType();
   Type destElemType = destVectorType.getElementType();
 
+  if (sourceVectorAType.getNumElements() !=
+      sourceVectorBType.getNumElements()) {
+    return emitOpError("source vectors have different lengths: ")
+           << sourceVectorAType << " vs. " << sourceVectorBType;
+  }
+
   bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
   bool isSrcFloat =
       isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
@@ -247,6 +256,13 @@ LogicalResult WMMAOp::verify() {
     return emitOpError("Expected int sources with int destination");
   }
 
+  if (sourceAElemType != sourceBElemType &&
+      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
+        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+    return emitOpError(
+               "source element types much match (except for fp8) but have ")
+           << sourceAType << " and " << sourceBType;
+  }
   return success();
 }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 7b2b524d4af42..94a1b78d5f040 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -1,9 +1,68 @@
 // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
-func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>,  %arg2 : vector<8xf32>) {
-  // CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+// CHECK-LABEL: @wmma_to_rocdl
+func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
+                         %arg2 : vector<8xf32>, %arg3 : vector<4xf32>,
+                         %arg4 : vector<8xbf16>, %arg5 : vector<4xbf16>,
+                         %arg6 : vector<8xf8E4M3FN>, %arg7 : vector<4xf8E4M3FN>,
+                         %arg8 : vector<8xf8E5M2>, %arg9 : vector<4xf8E5M2>,
+                         %arg10 : vector<8xi8>, %arg11 : vector<4xi8>,
+                         %arg12 : vector<8xi32>, %arg13 : vector<4xi32>,
+                         %arg14 : vector<16xi4>, %arg15 : vector<8xi4>, %arg16 : vector<4xi4>) {
+  // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg1 * %arg1 + %arg3 : vector<4xf16>, vector<4xf16>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg4 * %arg4 + %arg2 : vector<8xbf16>, vector<8xbf16>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16>
+  amdgpu.wmma %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>, i1) -> vector<4xf16>
+  amdgpu.wmma %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16>
+
+  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
+  amdgpu.wmma %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16>
+  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>, i1) -> vector<4xi16>
+  amdgpu.wmma %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg6 * %arg6 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg7 * %arg7 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg6 * %arg8 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg7 * %arg9 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg8 * %arg8 + %arg2 : vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg9 * %arg9 + %arg3 : vector<4xf8E5M2>, vector<4xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg8 * %arg6 + %arg2 : vector<8xf8E5M2>, vector<8xf8E4M3FN>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32>
+
+  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
+
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32>
 
-  // CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
   func.return
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
index 7b1...
[truncated]

Copy link
Contributor

@dhernandez0 dhernandez0 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

- If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
- If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the output can be f32 or i32 as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is in the context of the f16/bf16-outputting instructions

Type intrinsicInType = numBits <= 32
? (Type)rewriter.getIntegerType(numBits)
: (Type)VectorType::get(numBits / 32, i32);
auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
Value result = rewriter.createOrFold<LLVM::BitcastOp>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is an operand, it seems confusing to call it "result"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

result of the function, but yeah, could be argument

// This is the ambiguous case. 8 inputs to the wave64 version means that
// we want the 16x16x32 version, but for wave32 they mean the short form.
bool has8Inputs = sourceVectorType.getNumElements() == 8;
if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: if(isWave64 == has8Inputs)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I figured I'd use a somewhat verbose chunk of logic to make it clear what the cases are

@krzysz00 krzysz00 merged commit b31175a into llvm:main Feb 27, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants