@@ -514,8 +514,11 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
514
514
// / Push an input operand. If it is a float type, nothing to do. If it is
515
515
// / an integer type, then we need to also push its signdness (1 for signed, 0
516
516
// / for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
517
- // / vector. We also need to convert bfloat inputs to i16 to account for the lack
518
- // / of bfloat support in the WMMA intrinsics themselves.
517
+ // / vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
518
+ // / We also need to convert bfloat inputs to i16 to account for the bfloat
519
+ // / intrinsics having been defined before the AMD backend supported bfloat. We
520
+ // / similarly need to pack 8-bit float types into integers as if they were i8
521
+ // / (which they are for the backend's purposes).
519
522
static void wmmaPushInputOperand (ConversionPatternRewriter &rewriter,
520
523
Location loc,
521
524
const TypeConverter *typeConverter,
@@ -524,12 +527,16 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
524
527
SmallVector<Value, 4 > &operands) {
525
528
Type inputType = llvmInput.getType ();
526
529
auto vectorType = dyn_cast<VectorType>(inputType);
530
+ if (!vectorType) {
531
+ operands.push_back (llvmInput);
532
+ return ;
533
+ }
527
534
Type elemType = vectorType.getElementType ();
528
535
529
536
if (elemType.isBF16 ())
530
537
llvmInput = rewriter.create <LLVM::BitcastOp>(
531
538
loc, vectorType.clone (rewriter.getI16Type ()), llvmInput);
532
- if (! elemType.isInteger ( 8 ) ) {
539
+ if (elemType.getIntOrFloatBitWidth () > 8 ) {
533
540
operands.push_back (llvmInput);
534
541
return ;
535
542
}
@@ -538,34 +545,43 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
538
545
// for int8. This is because, in LLVM, fp8 type is converted to int8, so the
539
546
// fp8/int8 information is lost during the conversion process.
540
547
auto mlirInputType = cast<VectorType>(mlirInput.getType ());
541
- bool isInputInt8 = mlirInputType.getElementType ().isInteger (8 );
542
- if (isInputInt8 ) {
548
+ bool isInputInteger = mlirInputType.getElementType ().isInteger ();
549
+ if (isInputInteger ) {
543
550
// if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
544
551
bool localIsUnsigned = isUnsigned;
545
- if (elemType.isUnsignedInteger (8 )) {
552
+ if (elemType.isUnsignedInteger ()) {
546
553
localIsUnsigned = true ;
547
- } else if (elemType.isSignedInteger (8 )) {
554
+ } else if (elemType.isSignedInteger ()) {
548
555
localIsUnsigned = false ;
549
556
}
550
557
Value sign = createI1Constant (rewriter, loc, !localIsUnsigned);
551
558
operands.push_back (sign);
552
559
}
553
560
554
- int64_t numBytes = vectorType.getNumElements ();
561
+ int64_t numBits =
562
+ vectorType.getNumElements () * elemType.getIntOrFloatBitWidth ();
555
563
Type i32 = rewriter.getI32Type ();
556
- VectorType vectorType32bits = VectorType::get (numBytes * 8 / 32 , i32 );
557
- auto llvmVectorType32bits = typeConverter->convertType (vectorType32bits);
558
- Value result = rewriter.createOrFold <LLVM::BitcastOp>(
559
- loc, llvmVectorType32bits, llvmInput);
560
- operands.push_back (result);
564
+ Type intrinsicInType = numBits <= 32
565
+ ? (Type)rewriter.getIntegerType (numBits)
566
+ : (Type)VectorType::get (numBits / 32 , i32 );
567
+ auto llvmIntrinsicInType = typeConverter->convertType (intrinsicInType);
568
+ Value castInput = rewriter.createOrFold <LLVM::BitcastOp>(
569
+ loc, llvmIntrinsicInType, llvmInput);
570
+ // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
571
+ // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
572
+ // Add in the zeros here.
573
+ if (numBits < 32 )
574
+ castInput = rewriter.create <LLVM::ZExtOp>(loc, i32 , castInput);
575
+ operands.push_back (castInput);
561
576
}
562
577
563
578
// / Push the output operand. For many cases this is only pushing the output in
564
579
// / the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
565
580
// / since the same numbers of VGPRs is used, we need to decide if to store the
566
581
// / result in the upper 16 bits of the VGPRs or in the lower part. To store the
567
582
// / result in the lower 16 bits, set subwordOffset to 1, otherwise result will
568
- // / be stored it in the upper part
583
+ // / be stored it in the upper part. The subwordOffset must not be set for gfx12,
584
+ // / as the instructions have been changed to return fewer registers instead.
569
585
static void wmmaPushOutputOperand (ConversionPatternRewriter &rewriter,
570
586
Location loc,
571
587
const TypeConverter *typeConverter,
@@ -728,8 +744,10 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
728
744
static std::optional<StringRef> wmmaOpToIntrinsic (WMMAOp wmma,
729
745
Chipset chipset) {
730
746
auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA ().getType ());
747
+ auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB ().getType ());
731
748
auto destVectorType = dyn_cast<VectorType>(wmma.getDestC ().getType ());
732
749
auto elemSourceType = sourceVectorType.getElementType ();
750
+ auto elemBSourceType = sourceBVectorType.getElementType ();
733
751
auto elemDestType = destVectorType.getElementType ();
734
752
735
753
if (elemSourceType.isF16 () && elemDestType.isF32 ())
@@ -742,10 +760,33 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
742
760
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName ();
743
761
if (elemSourceType.isInteger (8 ) && elemDestType.isInteger (32 ))
744
762
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName ();
745
- if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32 ())
746
- return ROCDL::wmma_f32_16x16x16_fp8::getOperationName ();
747
- if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32 ())
748
- return ROCDL::wmma_f32_16x16x16_bf8::getOperationName ();
763
+ if (chipset.majorVersion == 11 ) {
764
+ if (elemSourceType.isInteger (4 ) && elemDestType.isInteger (32 ))
765
+ return ROCDL::wmma_i32_16x16x16_iu4::getOperationName ();
766
+ }
767
+ if (chipset.majorVersion >= 12 ) {
768
+ if (isa<Float8E4M3FNType>(elemSourceType) &&
769
+ isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32 ())
770
+ return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName ();
771
+ if (isa<Float8E4M3FNType>(elemSourceType) &&
772
+ isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32 ())
773
+ return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName ();
774
+ if (isa<Float8E5M2Type>(elemSourceType) &&
775
+ isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32 ())
776
+ return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName ();
777
+ if (isa<Float8E5M2Type>(elemSourceType) &&
778
+ isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32 ())
779
+ return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName ();
780
+ if (elemSourceType.isInteger (4 ) && elemDestType.isInteger (32 )) {
781
+ bool isWave64 = destVectorType.getNumElements () == 4 ;
782
+ // This is the ambiguous case. 8 inputs to the wave64 version means that
783
+ // we want the 16x16x32 version, but for wave32 they mean the short form.
784
+ bool has8Inputs = sourceVectorType.getNumElements () == 8 ;
785
+ if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
786
+ return ROCDL::wmma_i32_16x16x32_iu4::getOperationName ();
787
+ return ROCDL::wmma_i32_16x16x16_iu4::getOperationName ();
788
+ }
789
+ }
749
790
return std::nullopt;
750
791
}
751
792
@@ -823,6 +864,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
823
864
if (!maybeIntrinsic.has_value ())
824
865
return op.emitOpError (" no intrinsic matching WMMA on the given chipset" );
825
866
867
+ if (chipset.majorVersion >= 12 && op.getSubwordOffset () != 0 )
868
+ return op.emitOpError (" subwordOffset not supported on gfx12+" );
869
+
826
870
OperationState loweredOp (loc, *maybeIntrinsic);
827
871
loweredOp.addTypes (rawOutType);
828
872
0 commit comments