@@ -829,6 +829,144 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
   return success();
 }
 
+// AMDGPUDPPLowering converts the amdgpu.dpp operation into the corresponding
+// ROCDL instructions.
+struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
+  AMDGPUDPPLowering(LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // Convert the source operand to the corresponding LLVM types.
+    Location loc = DppOp.getLoc();
+    Value src = adaptor.getSrc();
+    Type srcType = src.getType();
+    auto llvmI32Type = typeConverter->convertType(rewriter.getI32Type());
+    auto llvmSrcIntType = typeConverter->convertType(
+        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
+
+    // If the source is a floating-point value, bitcast it to an integer of
+    // the same bit width.
+    if (llvm::isa<FloatType>(srcType)) {
+      src = rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, src);
+    }
+
+    // If the source is narrower than 32 bits, pack it into element 0 of a
+    // 32-bit-wide vector and bitcast that vector to i32.
+    if (srcType.getIntOrFloatBitWidth() < 32) {
+      auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
+          32 / srcType.getIntOrFloatBitWidth(), llvmSrcIntType));
+      Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
+      src = rewriter.create<LLVM::InsertElementOp>(
+          loc, undefVec, src, createI32Constant(rewriter, loc, 0));
+      src = rewriter.create<LLVM::BitcastOp>(loc, llvmI32Type, src);
+    }
+
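+    // The DPP mov created below produces an i32 result, so the widening done
+    // here is undone afterwards by truncating and bitcasting back to the
+    // original source type.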
+    // These DPP control values are taken from llvm/lib/Target/AMDGPU/SIDefines.h.
+    enum DppCtrl : unsigned {
+      ROW_SHL0 = 0x100,
+      ROW_SHR0 = 0x110,
+      ROW_ROR0 = 0x120,
+      WAVE_SHL1 = 0x130,
+      WAVE_ROL1 = 0x134,
+      WAVE_SHR1 = 0x138,
+      WAVE_ROR1 = 0x13C,
+      ROW_MIRROR = 0x140,
+      ROW_HALF_MIRROR = 0x141,
+      BCAST15 = 0x142,
+      BCAST31 = 0x143,
+    };
+
+    auto kind = DppOp.getKind();
+    auto permArgument = DppOp.getPermArgument();
+    uint32_t DppCtrl = 0;
+
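+    // Build the 32-bit DPP control word: quad_perm packs four 2-bit lane
+    // selectors, the row shifts/rotates add the requested amount to the
+    // matching ROW_* base value, and the wave and mirror variants use the
+    // fixed encodings from the enum above.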
+    switch (kind) {
+    case DPPPerm::quad_perm:
+      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
+        int32_t i = 0;
+        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
+          uint32_t num = elem.getInt();
+          DppCtrl |= num << (i * 2);
+          i++;
+        }
+      }
+      break;
+    case DPPPerm::row_shl:
+      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
+        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
+      }
+      break;
+    case DPPPerm::row_shr:
+      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
+        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
+      }
+      break;
+    case DPPPerm::row_ror:
+      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
+        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
+      }
+      break;
+    case DPPPerm::wave_shl:
+      DppCtrl = DppCtrl::WAVE_SHL1;
+      break;
+    case DPPPerm::wave_shr:
+      DppCtrl = DppCtrl::WAVE_SHR1;
+      break;
+    case DPPPerm::wave_rol:
+      DppCtrl = DppCtrl::WAVE_ROL1;
+      break;
+    case DPPPerm::wave_ror:
+      DppCtrl = DppCtrl::WAVE_ROR1;
+      break;
+    case DPPPerm::row_mirror:
+      DppCtrl = DppCtrl::ROW_MIRROR;
+      break;
+    case DPPPerm::row_half_mirror:
+      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
+      break;
+    case DPPPerm::row_bcast_15:
+      DppCtrl = DppCtrl::BCAST15;
+      break;
+    case DPPPerm::row_bcast_31:
+      DppCtrl = DppCtrl::BCAST31;
+      break;
+    }
+
+    // Read the row_mask, bank_mask, and bound_ctrl attributes from the op.
+    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
+    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
+    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
+
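+    // row_mask and bank_mask are 4-bit lane-enable masks; together with
+    // bound_ctrl they are passed through to the DPP intrinsic unchanged.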
+    // Create a ROCDL::DPPMovOp with the computed control word and masks.
+    auto dppMovOp = rewriter.create<ROCDL::DPPMovOp>(
+        loc, llvmI32Type, src, DppCtrl, rowMask, bankMask, boundCtrl);
+
+    // Truncate the result back to the source bit width and, for non-integer
+    // sources, bitcast it back to the original source type.
+    Value result = dppMovOp.getRes();
+    if (srcType.getIntOrFloatBitWidth() < 32) {
+      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
+    }
+
+    if (!llvm::isa<IntegerType>(srcType)) {
+      result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
+    }
+
+    // Replace the amdgpu.dpp op with the newly created ROCDL::DPPMovOp.
+    rewriter.replaceOp(DppOp, ValueRange(result));
+    return success();
+  }
+};
+
 struct ConvertAMDGPUToROCDLPass
     : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
   ConvertAMDGPUToROCDLPass() = default;
@@ -880,8 +1018,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
-          LDSBarrierOpLowering, MFMAOpLowering, WMMAOpLowering,
-          ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
+          AMDGPUDPPLowering, LDSBarrierOpLowering, MFMAOpLowering,
+          WMMAOpLowering, ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
           PackedStochRoundFp8OpLowering>(converter, chipset);
 }