
Commit b31175a

[mlir][AMDGPU] Add int4 intrinsics, mixed-type fp8 to handle gfx12 (#128963)
1. Extend the gfx12 FP8 support to allow mixed-type intrinsics (since they have been added), creating limited mixed-type support that mirrors MFMA.
2. Extend the `amdgpu.wmma` intrinsic lowering to correctly handle shorter vectors, because gfx12 now has instructions that logically take a 4xi8 or, as far as LLVM is concerned, an i32. Similarly, there are 4xi4 inputs, which are an i16 (that must be zero-extended to i32).
3. Correctly handle the ambiguity in the int4 intrinsics on gfx12, which can be either 16x16x16 or 16x16x32 (see the sketch below).
4. Add tests showing all WMMAs being lowered the way gfx12 expects (mirroring LLVM's tests).
5. Add a verifier to prevent emitting illegal instructions on gfx12.
1 parent 64ae0a1 commit b31175a
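To make point 3 of the commit message concrete: on gfx12, the destination vector length (together with the input length) selects between the two int4 intrinsics. A minimal sketch in the `amdgpu` dialect, mirroring the lowering test added below (SSA names here are placeholders):

  // 8 x i4 inputs with a vector<8xi32> accumulator (wave32 shape): lowers to rocdl.wmma.i32.16x16x16.iu4.
  amdgpu.wmma %a * %b + %acc32 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
  // The same 8 x i4 inputs with a vector<4xi32> accumulator (wave64 shape): lowers to rocdl.wmma.i32.16x16x32.iu4.
  amdgpu.wmma %a * %b + %acc64 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>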

7 files changed: 187 additions, 41 deletions

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 16 additions & 8 deletions
@@ -657,9 +657,14 @@ def MFMAOutTypes : AnyTypeOf<[F64,
                        VectorOfLengthAndType<[4, 16, 32], [I32]>,
                        VectorOfLengthAndType<[4], [F64]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F16, BF16, I8, SI8, UI8, F8E4M3FN, F8E5M2]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
+                       [4, 8, 16],
+                       [F16, BF16,
+                        I8, SI8, UI8,
+                        I<4>, SI<4>, UI<4>,
+                        F8E4M3FN, F8E5M2]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
-                             VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
+                             VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;

 def AMDGPU_MFMAOp :
   AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
@@ -720,8 +725,7 @@ def AMDGPU_MFMAOp :

 def AMDGPU_WMMAOp :
   AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
-                     AllTypesMatch<["sourceA", "sourceB"]>,
-                     Pure]>,
+                     Pure]>,
   Arguments<(ins
                  WMMAInTypes:$sourceA,
                  WMMAInTypes:$sourceB,
@@ -734,13 +738,17 @@ def AMDGPU_WMMAOp :
   let summary = "MLIR wrapper for RDNA3 wmma instructions";
   let description = [{
     The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
-    for various `wmma` instructions in the RDNA3 architecture, which perform
-    a 16x16 matrix multiplication for different data types.
+    for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
+    perform a 16x16 * 16x16 matrix multiplication for different data types.
+    Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
+    integer inputs.

-    When emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16 (or 16xbf16) vector
-    containing only 8 valid values:
+    On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
+    (or 16xbf16) vector containing only 8 valid values:
       - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
       - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
+    On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
+    all values are valid and the `subwordOffset` must be `0`, as it cannot be used.

     `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.

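With the relaxed `WMMAInTypes` and the removal of the `AllTypesMatch<["sourceA", "sourceB"]>` constraint above, mixed 8-bit float products become expressible. For example (this line is taken from the gfx12 lowering test further down; the argument names are the test's):

  // fp8 (f8E4M3FN) A operand multiplied with a bf8 (f8E5M2) B operand, accumulating into f32.
  amdgpu.wmma %arg6 * %arg8 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32>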
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 6 additions & 3 deletions
@@ -410,8 +410,11 @@ def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16",
 def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
 def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
 // Available from gfx12
-def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
-def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
+def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
+def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_bf8", [1]>;
+def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
+def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
+def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;

 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics (available in GFX950)
@@ -771,7 +774,7 @@ def ROCDL_CvtScaleF32Bf8Op :
   Arguments<(ins I32:$src, F32: $scale, I32:$byteSel)> {
   let summary = "Scale and convert bf8 to f32";
   let description = [{
-    Scale `src` by the exponent in `scale` then convert 8-bit bf8 value
+    Scale `src` by the exponent in `scale` then convert 8-bit bf8 value
     from the `byteSel`th bit of `src` to fp32.
   }];
   let assemblyFormat = [{

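As a rough sketch of how one of the new wrappers reads once lowering has run (operand names are placeholders; the operand and result types mirror the CHECK lines in the gfx12 test below):

  // Mixed fp8 x bf8 WMMA, wave32 shape: each 8xfp8 input is packed into a vector<2xi32>.
  %d = rocdl.wmma.f32.16x16x16.fp8_bf8 %a, %b, %c : (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>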
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 62 additions & 18 deletions
@@ -514,8 +514,11 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
 /// Push an input operand. If it is a float type, nothing to do. If it is
 /// an integer type, then we need to also push its signdness (1 for signed, 0
 /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
-/// vector. We also need to convert bfloat inputs to i16 to account for the lack
-/// of bfloat support in the WMMA intrinsics themselves.
+/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
+/// We also need to convert bfloat inputs to i16 to account for the bfloat
+/// intrinsics having been defined before the AMD backend supported bfloat. We
+/// similarly need to pack 8-bit float types into integers as if they were i8
+/// (which they are for the backend's purposes).
 static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
@@ -524,12 +527,16 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                  SmallVector<Value, 4> &operands) {
   Type inputType = llvmInput.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
+  if (!vectorType) {
+    operands.push_back(llvmInput);
+    return;
+  }
   Type elemType = vectorType.getElementType();

   if (elemType.isBF16())
     llvmInput = rewriter.create<LLVM::BitcastOp>(
         loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
-  if (!elemType.isInteger(8)) {
+  if (elemType.getIntOrFloatBitWidth() > 8) {
     operands.push_back(llvmInput);
     return;
   }
@@ -538,34 +545,43 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
   // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
   // fp8/int8 information is lost during the conversion process.
   auto mlirInputType = cast<VectorType>(mlirInput.getType());
-  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
-  if (isInputInt8) {
+  bool isInputInteger = mlirInputType.getElementType().isInteger();
+  if (isInputInteger) {
     // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
     bool localIsUnsigned = isUnsigned;
-    if (elemType.isUnsignedInteger(8)) {
+    if (elemType.isUnsignedInteger()) {
       localIsUnsigned = true;
-    } else if (elemType.isSignedInteger(8)) {
+    } else if (elemType.isSignedInteger()) {
       localIsUnsigned = false;
     }
     Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
     operands.push_back(sign);
   }

-  int64_t numBytes = vectorType.getNumElements();
+  int64_t numBits =
+      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
   Type i32 = rewriter.getI32Type();
-  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
-  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
-  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
-      loc, llvmVectorType32bits, llvmInput);
-  operands.push_back(result);
+  Type intrinsicInType = numBits <= 32
+                             ? (Type)rewriter.getIntegerType(numBits)
+                             : (Type)VectorType::get(numBits / 32, i32);
+  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
+  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
+      loc, llvmIntrinsicInType, llvmInput);
+  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
+  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
+  // Add in the zeros here.
+  if (numBits < 32)
+    castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput);
+  operands.push_back(castInput);
 }

 /// Push the output operand. For many cases this is only pushing the output in
 /// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
 /// since the same numbers of VGPRs is used, we need to decide if to store the
 /// result in the upper 16 bits of the VGPRs or in the lower part. To store the
 /// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
-/// be stored it in the upper part
+/// be stored it in the upper part. The subwordOffset must not be set for gfx12,
+/// as the instructions have been changed to return fewer registers instead.
 static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                   Location loc,
                                   const TypeConverter *typeConverter,
@@ -728,8 +744,10 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                   Chipset chipset) {
   auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
+  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
   auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
   auto elemSourceType = sourceVectorType.getElementType();
+  auto elemBSourceType = sourceBVectorType.getElementType();
   auto elemDestType = destVectorType.getElementType();

   if (elemSourceType.isF16() && elemDestType.isF32())
@@ -742,10 +760,33 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
   if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
     return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
-  if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
+  if (chipset.majorVersion == 11) {
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+  }
+  if (chipset.majorVersion >= 12) {
+    if (isa<Float8E4M3FNType>(elemSourceType) &&
+        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
+    if (isa<Float8E4M3FNType>(elemSourceType) &&
+        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
+    if (isa<Float8E5M2Type>(elemSourceType) &&
+        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
+    if (isa<Float8E5M2Type>(elemSourceType) &&
+        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
+      bool isWave64 = destVectorType.getNumElements() == 4;
+      // This is the ambiguous case. 8 inputs to the wave64 version means that
+      // we want the 16x16x32 version, but for wave32 they mean the short form.
+      bool has8Inputs = sourceVectorType.getNumElements() == 8;
+      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
+        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+    }
+  }
   return std::nullopt;
 }

@@ -823,6 +864,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     if (!maybeIntrinsic.has_value())
       return op.emitOpError("no intrinsic matching WMMA on the given chipset");

+    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
+      return op.emitOpError("subwordOffset not supported on gfx12+");
+
     OperationState loweredOp(loc, *maybeIntrinsic);
     loweredOp.addTypes(rawOutType);

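To illustrate the `numBits < 32` path above: a `vector<4xi4>` operand carries only 16 bits, so the lowering bitcasts it to an i16 and then zero-extends it to the i32 the intrinsic expects. A sketch of the resulting LLVM-dialect ops (SSA names are hypothetical):

  %packed = llvm.bitcast %srcA : vector<4xi4> to i16  // pack the four i4 values into 16 bits
  %widened = llvm.zext %packed : i16 to i32           // zero-extend to the i32 argument the intrinsic takes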
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 16 additions & 0 deletions
@@ -279,14 +279,23 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 LogicalResult WMMAOp::verify() {
   Type sourceAType = getSourceA().getType();
+  Type sourceBType = getSourceB().getType();
   Type destType = getDestC().getType();

   VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
+  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
   VectorType destVectorType = dyn_cast<VectorType>(destType);

   Type sourceAElemType = sourceVectorAType.getElementType();
+  Type sourceBElemType = sourceVectorBType.getElementType();
   Type destElemType = destVectorType.getElementType();

+  if (sourceVectorAType.getNumElements() !=
+      sourceVectorBType.getNumElements()) {
+    return emitOpError("source vectors have different lengths: ")
+           << sourceVectorAType << " vs. " << sourceVectorBType;
+  }
+
   bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
   bool isSrcFloat =
       isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
@@ -300,6 +309,13 @@ LogicalResult WMMAOp::verify() {
     return emitOpError("Expected int sources with int destination");
   }

+  if (sourceAElemType != sourceBElemType &&
+      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
+        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+    return emitOpError(
+               "source element types much match (except for fp8) but have ")
+           << sourceAType << " and " << sourceBType;
+  }
   return success();
 }

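For illustration, here are two ops the new verifier rejects (hypothetical IR, not part of the patch): source element types may only differ within the fp8/bf8 family, and the A and B vectors must have the same length.

  // Rejected by the element-type check: f16 A operand mixed with a bf16 B operand (neither is fp8/bf8).
  amdgpu.wmma %a * %b + %c : vector<8xf16>, vector<8xbf16>, vector<8xf32>
  // Rejected by the length check: vector<8xi8> A operand against a vector<4xi8> B operand.
  amdgpu.wmma %a * %d + %e : vector<8xi8>, vector<4xi8>, vector<8xi32>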
Lines changed: 64 additions & 5 deletions
@@ -1,9 +1,68 @@
 // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
-func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>, %arg2 : vector<8xf32>) {
-  // CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+// CHECK-LABEL: @wmma_to_rocdl
+func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
+                         %arg2 : vector<8xf32>, %arg3 : vector<4xf32>,
+                         %arg4 : vector<8xbf16>, %arg5 : vector<4xbf16>,
+                         %arg6 : vector<8xf8E4M3FN>, %arg7 : vector<4xf8E4M3FN>,
+                         %arg8 : vector<8xf8E5M2>, %arg9 : vector<4xf8E5M2>,
+                         %arg10 : vector<8xi8>, %arg11 : vector<4xi8>,
+                         %arg12 : vector<8xi32>, %arg13 : vector<4xi32>,
+                         %arg14 : vector<16xi4>, %arg15 : vector<8xi4>, %arg16 : vector<4xi4>) {
+  // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg1 * %arg1 + %arg3 : vector<4xf16>, vector<4xf16>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg4 * %arg4 + %arg2 : vector<8xbf16>, vector<8xbf16>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16>
+  amdgpu.wmma %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>, i1) -> vector<4xf16>
+  amdgpu.wmma %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16>
+
+  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
+  amdgpu.wmma %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16>
+  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>, i1) -> vector<4xi16>
+  amdgpu.wmma %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg6 * %arg6 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg7 * %arg7 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg6 * %arg8 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg7 * %arg9 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg8 * %arg8 + %arg2 : vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg9 * %arg9 + %arg3 : vector<4xf8E5M2>, vector<4xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg8 * %arg6 + %arg2 : vector<8xf8E5M2>, vector<8xf8E4M3FN>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32>
+
+  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
+
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32>

-  // CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
   func.return
 }
