Skip to content

Commit 2685c3c

Browse files
[mlir][AMDGPU] Implement AMDGPU DPP operation in MLIR.
Defined AMDGPU DPP operation in mlir to represent semantics. Introduced a new enumeration attribute for different permutations and allowed for different types of arguments.Implemented constant attribute handling for ROCDL::DPPMovOp operation. The operation now correctly accepts constant attributes for dppCtrl, rowMask, bankMask, boundCtrl, and passes them to the corresponding LLVM intrinsic.
1 parent d17db60 commit 2685c3c

File tree

5 files changed

+370
-2
lines changed

5 files changed

+370
-2
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,60 @@ def AMDGPU_RawBufferAtomicUminOp :
410410
let hasVerifier = 1;
411411
}
412412

413+
def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
414+
"The possible permutations for a DPP operation",
415+
[
416+
I32EnumAttrCase<"quad_perm", 0>,
417+
I32EnumAttrCase<"row_shl", 1>,
418+
I32EnumAttrCase<"row_shr", 2>,
419+
I32EnumAttrCase<"row_ror", 3>,
420+
I32EnumAttrCase<"wave_shl", 4>,
421+
I32EnumAttrCase<"wave_shr", 5>,
422+
I32EnumAttrCase<"wave_ror", 6>,
423+
I32EnumAttrCase<"wave_rol", 7>,
424+
I32EnumAttrCase<"row_mirror", 8>,
425+
I32EnumAttrCase<"row_half_mirror", 9>,
426+
I32EnumAttrCase<"row_bcast_15", 10>,
427+
I32EnumAttrCase<"row_bcast_31", 11>
428+
]> {
429+
let genSpecializedAttr = 0;
430+
let cppNamespace = "::mlir::amdgpu";
431+
}
432+
433+
def AMDGPU_DPPPermAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_DPPPerm,
434+
"dpp_perm">;
435+
436+
def AMDGPU_DPPOp : AMDGPU_Op<"dpp", [SameTypeOperands, AllTypesMatch<["result", "src"]>]>,
437+
Arguments<(ins AnyType:$src,
438+
AMDGPU_DPPPermAttr:$kind,
439+
OptionalAttr<AnyAttrOf<[I32Attr, ArrayAttr, UnitAttr]>>:$permArgument,
440+
DefaultValuedAttr<I32Attr, "0xf">:$row_mask,
441+
DefaultValuedAttr<I32Attr, "0xf">:$bank_mask,
442+
DefaultValuedAttr<BoolAttr, "false">:$bound_ctrl)> {
443+
let summary = "AMDGPU DPP operation";
444+
let description = [{
445+
This operation represents DPP functionality in a GPU program.
446+
DPP provides the following operations:
447+
- Full crossbar in a group of four (`quad_perm`)
448+
- Wavefront shift left by one lane (`wave_shl`)
449+
- Wavefront shift right by one lane (`wave_shr`)
450+
- Wavefront rotate right by one lane (`wave_ror`)
451+
- Wavefront rotate left by one lane (`wave_rol`)
452+
- Row shift left by 1–15 lanes (`row_shl`)
453+
- Row shift right by 1–15 lanes (`row_shr`)
454+
- Row rotate right by 1–15 lanes (`row_ror`)
455+
- Reverse within a row (`row_mirror`)
456+
- Reverse within a half-row (`row_half_mirror`)
457+
- Broadcast the 15th lane of each row to the next row (`row_bcast`)
458+
- Broadcast lane 31 to rows 2 and 3 (`row_bcast`)
459+
}];
460+
let results = (outs AnyType:$result);
461+
let assemblyFormat = [{
462+
$src $kind (`(` $permArgument^ `)`)? attr-dict `:` type($result)
463+
}];
464+
let hasVerifier = 1;
465+
}
466+
413467
def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
414468
let summary = "Barrier that includes a wait for LDS memory operations.";
415469
let description = [{

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,29 @@ def ROCDL_RawBufferAtomicUMinOp :
591591
let hasCustomAssemblyFormat = 1;
592592
}
593593

594+
// DPP Move intrinsic
595+
def ROCDL_DPPMovOp : ROCDL_IntrOp<"mov.dpp", [], [0],
596+
[AllTypesMatch<["res", "src"]>], 1>,
597+
Arguments<(ins LLVM_Type:$src, I32Attr:$dppCtrl, I32Attr:$rowMask,
598+
I32Attr:$bankMask, I1Attr:$boundCtrl)> {
599+
let results = (outs LLVM_Type:$res);
600+
let assemblyFormat = [{
601+
attr-dict $src `with` $dppCtrl `,` $rowMask `,` $bankMask `,` $boundCtrl `:` type($src)
602+
}];
603+
string llvmBuilder = [{
604+
auto vdataType = moduleTranslation.convertType(op.getSrc().getType());
605+
llvm::Value *args[] = {
606+
moduleTranslation.lookupValue(op.getSrc()),
607+
builder.getInt32(op.getDppCtrl()),
608+
builder.getInt32(op.getRowMask()),
609+
builder.getInt32(op.getBankMask()),
610+
builder.getInt1(op.getBoundCtrl())
611+
};
612+
$res = createIntrinsicCall(builder,
613+
llvm::Intrinsic::amdgcn_mov_dpp, args, {vdataType});
614+
}];
615+
}
616+
594617
//===---------------------------------------------------------------------===//
595618
// 8-bit float intrinsics
596619
//===---------------------------------------------------------------------===//

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 140 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,144 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
829829
return success();
830830
}
831831

832+
// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
833+
// operation into the corresponding ROCDL instructions.
834+
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
835+
AMDGPUDPPLowering(LLVMTypeConverter &converter, Chipset chipset)
836+
: ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
837+
Chipset chipset;
838+
839+
LogicalResult
840+
matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
841+
ConversionPatternRewriter &rewriter) const override {
842+
843+
// Convert the source operand to the corresponding LLVM type
844+
Location loc = DppOp.getLoc();
845+
Value src = adaptor.getSrc();
846+
Type srcType = src.getType();
847+
auto llvmI32Type = typeConverter->convertType(rewriter.getI32Type());
848+
auto llvmSrcIntType = typeConverter->convertType(
849+
rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
850+
851+
// If the source type is less or equal to i32 or f32, use bitcast to convert
852+
// it to i32.
853+
if (llvm::isa<FloatType>(srcType)) {
854+
src = rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, src);
855+
}
856+
857+
if (srcType.getIntOrFloatBitWidth() < 32) {
858+
auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
859+
32 / srcType.getIntOrFloatBitWidth(), llvmSrcIntType));
860+
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
861+
src = rewriter.create<LLVM::InsertElementOp>(
862+
loc, undefVec, src, createI32Constant(rewriter, loc, 0));
863+
src = rewriter.create<LLVM::BitcastOp>(loc, llvmI32Type, src);
864+
}
865+
866+
// This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
867+
enum DppCtrl : unsigned {
868+
ROW_SHL0 = 0x100,
869+
ROW_SHR0 = 0x110,
870+
ROW_ROR0 = 0x120,
871+
WAVE_SHL1 = 0x130,
872+
WAVE_ROL1 = 0x134,
873+
WAVE_SHR1 = 0x138,
874+
WAVE_ROR1 = 0x13C,
875+
ROW_MIRROR = 0x140,
876+
ROW_HALF_MIRROR = 0x141,
877+
BCAST15 = 0x142,
878+
BCAST31 = 0x143,
879+
};
880+
881+
auto kind = DppOp.getKind();
882+
auto permArgument = DppOp.getPermArgument();
883+
uint32_t DppCtrl = 0;
884+
885+
switch (kind) {
886+
887+
case DPPPerm::quad_perm:
888+
if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
889+
int32_t i = 0;
890+
for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
891+
uint32_t num = elem.getInt();
892+
DppCtrl |= num << (i * 2);
893+
i++;
894+
}
895+
}
896+
break;
897+
case DPPPerm::row_shl:
898+
if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
899+
DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
900+
}
901+
break;
902+
case DPPPerm::row_shr:
903+
if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
904+
DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
905+
}
906+
break;
907+
case DPPPerm::row_ror:
908+
if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
909+
DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
910+
}
911+
break;
912+
case DPPPerm::wave_shl:
913+
DppCtrl = DppCtrl::WAVE_SHL1;
914+
break;
915+
case DPPPerm::wave_shr:
916+
DppCtrl = DppCtrl::WAVE_SHR1;
917+
break;
918+
case DPPPerm::wave_rol:
919+
DppCtrl = DppCtrl::WAVE_ROL1;
920+
break;
921+
case DPPPerm::wave_ror:
922+
DppCtrl = DppCtrl::WAVE_ROR1;
923+
break;
924+
case DPPPerm::row_mirror:
925+
DppCtrl = DppCtrl::ROW_MIRROR;
926+
break;
927+
case DPPPerm::row_half_mirror:
928+
DppCtrl = DppCtrl::ROW_HALF_MIRROR;
929+
break;
930+
case DPPPerm::row_bcast_15:
931+
DppCtrl = DppCtrl::BCAST15;
932+
break;
933+
case DPPPerm::row_bcast_31:
934+
DppCtrl = DppCtrl::BCAST31;
935+
break;
936+
}
937+
938+
// Check for row_mask, bank_mask, bound_ctrl if they exist and create
939+
// constants
940+
auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask")
941+
.dyn_cast<IntegerAttr>()
942+
.getInt();
943+
auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask")
944+
.dyn_cast<IntegerAttr>()
945+
.getInt();
946+
bool boundCtrl = DppOp->getAttrOfType<IntegerAttr>("bound_ctrl")
947+
.dyn_cast<BoolAttr>()
948+
.getValue();
949+
950+
// create a ROCDL_DPPMovOp instruction with the appropriate attributes
951+
auto dppMovOp = rewriter.create<ROCDL::DPPMovOp>(
952+
loc, llvmI32Type, src, DppCtrl, rowMask, bankMask, boundCtrl);
953+
954+
Value result = dppMovOp.getRes();
955+
if (srcType.getIntOrFloatBitWidth() < 32) {
956+
result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
957+
}
958+
959+
if (!llvm::isa<IntegerType>(srcType)) {
960+
result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
961+
}
962+
963+
// We are replacing the AMDGPU_DPPOp instruction with the new
964+
// ROCDL_DPPMovOp instruction
965+
rewriter.replaceOp(DppOp, ValueRange(result));
966+
return success();
967+
}
968+
};
969+
832970
struct ConvertAMDGPUToROCDLPass
833971
: public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
834972
ConvertAMDGPUToROCDLPass() = default;
@@ -880,8 +1018,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
8801018
ROCDL::RawPtrBufferAtomicUminOp>,
8811019
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
8821020
ROCDL::RawPtrBufferAtomicCmpSwap>,
883-
LDSBarrierOpLowering, MFMAOpLowering, WMMAOpLowering,
884-
ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
1021+
AMDGPUDPPLowering, LDSBarrierOpLowering, MFMAOpLowering,
1022+
WMMAOpLowering, ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
8851023
PackedStochRoundFp8OpLowering>(converter, chipset);
8861024
}
8871025

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,68 @@ LogicalResult MFMAOp::verify() {
326326
return success();
327327
}
328328

329+
//===----------------------------------------------------------------------===//
330+
// DPPOp
331+
//===----------------------------------------------------------------------===//
332+
LogicalResult DPPOp::verify() {
333+
Type srcType = getSrc().getType();
334+
if (srcType.getIntOrFloatBitWidth() > 32) {
335+
return emitOpError("integer and floating point types larger than 32 bits "
336+
"are not supported");
337+
}
338+
339+
DPPPerm kind = getKind();
340+
Attribute permArgument = getPermArgument().value_or(Attribute{});
341+
342+
switch (kind) {
343+
344+
case DPPPerm::quad_perm: {
345+
auto quadPermAttr = permArgument.dyn_cast_or_null<ArrayAttr>();
346+
if (!quadPermAttr || quadPermAttr.size() != 4) {
347+
return emitOpError("quad_perm attribute must have exactly 4 elements");
348+
}
349+
for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
350+
uint32_t num = elem.getInt();
351+
if (num < 0 || num > 3) {
352+
return emitOpError(
353+
"Each element of quad_perm must be in the range [0, 3]");
354+
}
355+
}
356+
} break;
357+
358+
case DPPPerm::row_shl:
359+
case DPPPerm::row_shr:
360+
case DPPPerm::row_ror: {
361+
if (!permArgument) {
362+
return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
363+
"' value not specified");
364+
}
365+
if (auto intAttr = permArgument.dyn_cast<IntegerAttr>()) {
366+
uint32_t attrValue = intAttr.getInt();
367+
if (attrValue < 1 || attrValue > 15) {
368+
return emitOpError("Attribute value must be between 1 and 15");
369+
}
370+
}
371+
} break;
372+
373+
case DPPPerm::wave_shl:
374+
case DPPPerm::wave_shr:
375+
case DPPPerm::wave_rol:
376+
case DPPPerm::wave_ror:
377+
case DPPPerm::row_mirror:
378+
case DPPPerm::row_half_mirror:
379+
case DPPPerm::row_bcast_15:
380+
case DPPPerm::row_bcast_31: {
381+
if (permArgument && !permArgument.isa<UnitAttr>()) {
382+
return emitOpError("Expected unit attribute for permArgument, but found "
383+
"non-trivial argument");
384+
}
385+
break;
386+
}
387+
}
388+
return success();
389+
}
390+
329391
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
330392

331393
#define GET_ATTRDEF_CLASSES

0 commit comments

Comments
 (0)