Skip to content

Commit 870e48b

Browse files
[mlir][AMDGPU] Implement AMDGPU DPP operation in MLIR.
Defined the AMDGPU DPP operation in MLIR to represent DPP semantics. Introduced a new enumeration attribute for the different permutations and allowed for different types of arguments. Implemented constant attribute handling for the ROCDL::DPPUpdateOp operation. The operation now accepts constant attributes for dppCtrl, rowMask, bankMask, and boundCtrl, and passes them to the corresponding LLVM intrinsic.
1 parent 145aff6 commit 870e48b

File tree

6 files changed

+440
-3
lines changed

6 files changed

+440
-3
lines changed

llvm/include/llvm/IR/IntrinsicsAMDGPU.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2343,6 +2343,8 @@ def int_amdgcn_buffer_wbinvl1_vol :
23432343
// VI Intrinsics
23442344
//===----------------------------------------------------------------------===//
23452345

2346+
// The llvm.amdgcn.mov.dpp.i32 intrinsic represents the mov.dpp operation in AMDGPU.
2347+
// This operation is being deprecated and can be replaced with llvm.amdgcn.update.dpp.i32.
23462348
// llvm.amdgcn.mov.dpp.i32 <src> <dpp_ctrl> <row_mask> <bank_mask> <bound_ctrl>
23472349
def int_amdgcn_mov_dpp :
23482350
Intrinsic<[llvm_anyint_ty],
@@ -2352,6 +2354,10 @@ def int_amdgcn_mov_dpp :
23522354
ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>,
23532355
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, IntrNoCallback, IntrNoFree]>;
23542356

2357+
// The llvm.amdgcn.update.dpp.i32 intrinsic represents the update.dpp operation in AMDGPU.
2358+
// It takes an old value, a source operand, a DPP control operand, a row mask, a bank mask, and a bound control.
2359+
// This operation is equivalent to a sequence of v_mov_b32 operations.
2360+
// It is preferred over llvm.amdgcn.mov.dpp.i32 for future use.
23552361
// llvm.amdgcn.update.dpp.i32 <old> <src> <dpp_ctrl> <row_mask> <bank_mask> <bound_ctrl>
23562362
// Should be equivalent to:
23572363
// v_mov_b32 <dest> <old>

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,61 @@ def AMDGPU_RawBufferAtomicUminOp :
410410
let hasVerifier = 1;
411411
}
412412

413+
// Enumerates the permutation kinds an `amdgpu.dpp` operation can select.
// The twelve cases are numbered consecutively from 0 (`quad_perm`) to 11
// (`row_bcast_31`). These numbers identify the enum cases only; they are not
// the hardware dpp_ctrl encodings.
// `genSpecializedAttr = 0` is set here and the dialect attribute is declared
// separately through `EnumAttr`. The generated C++ enum is placed in the
// `::mlir::amdgpu` namespace.
def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
414+
"The possible permutations for a DPP operation",
415+
[
416+
I32EnumAttrCase<"quad_perm", 0>,
417+
I32EnumAttrCase<"row_shl", 1>,
418+
I32EnumAttrCase<"row_shr", 2>,
419+
I32EnumAttrCase<"row_ror", 3>,
420+
I32EnumAttrCase<"wave_shl", 4>,
421+
I32EnumAttrCase<"wave_shr", 5>,
422+
I32EnumAttrCase<"wave_ror", 6>,
423+
I32EnumAttrCase<"wave_rol", 7>,
424+
I32EnumAttrCase<"row_mirror", 8>,
425+
I32EnumAttrCase<"row_half_mirror", 9>,
426+
I32EnumAttrCase<"row_bcast_15", 10>,
427+
I32EnumAttrCase<"row_bcast_31", 11>
428+
]> {
429+
let genSpecializedAttr = 0;
430+
let cppNamespace = "::mlir::amdgpu";
431+
}
432+
433+
// Dialect attribute wrapping the `AMDGPU_DPPPerm` enum for the AMDGPU dialect,
// declared with the mnemonic `dpp_perm`.
def AMDGPU_DPPPermAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_DPPPerm,
434+
"dpp_perm">;
435+
436+
// `amdgpu.dpp`: data-parallel-primitive lane permutation. `old`, `src` and the
// result must all have the same type. The custom verifier checks the operand
// width and that `permArgument` matches the selected `kind`.
def AMDGPU_DPPOp : AMDGPU_Op<"dpp",
    [SameTypeOperands, AllTypesMatch<["result", "old", "src"]>]>,
  Arguments<(ins AnyType:$old,
                 AnyType:$src,
                 AMDGPU_DPPPermAttr:$kind,
                 OptionalAttr<AnyAttrOf<[I32Attr, ArrayAttr, UnitAttr]>>:$permArgument,
                 DefaultValuedAttr<I32Attr, "0xf">:$row_mask,
                 DefaultValuedAttr<I32Attr, "0xf">:$bank_mask,
                 DefaultValuedAttr<BoolAttr, "false">:$bound_ctrl)> {
  let summary = "AMDGPU DPP operation";
  let description = [{
    This operation represents DPP functionality in a GPU program.
    DPP provides the following operations:
    - Full crossbar in a group of four (`quad_perm`)
    - Wavefront shift left by one lane (`wave_shl`)
    - Wavefront shift right by one lane (`wave_shr`)
    - Wavefront rotate right by one lane (`wave_ror`)
    - Wavefront rotate left by one lane (`wave_rol`)
    - Row shift left by 1–15 lanes (`row_shl`)
    - Row shift right by 1–15 lanes (`row_shr`)
    - Row rotate right by 1–15 lanes (`row_ror`)
    - Reverse within a row (`row_mirror`)
    - Reverse within a half-row (`row_half_mirror`)
    - Broadcast the 15th lane of each row to the next row (`row_bcast_15`)
    - Broadcast lane 31 to rows 2 and 3 (`row_bcast_31`)

    `old`, `src` and the result all have the same type, which must be an
    integer or floating point type of at most 32 bits.

    `kind` selects the permutation and determines what `permArgument` must be:
    - `quad_perm`: an array of exactly 4 integers, each in the range [0, 3].
    - `row_shl`, `row_shr`, `row_ror`: an integer in the range [1, 15].
    - all other kinds: omitted (or a unit attribute).

    `row_mask` and `bank_mask` default to `0xf` and `bound_ctrl` defaults to
    `false`; when lowering to ROCDL they are forwarded unchanged to
    `rocdl.update.dpp`.

    Example:

    ```mlir
    %0 = amdgpu.dpp %old %src quad_perm([1, 0, 3, 2]) : i32
    %1 = amdgpu.dpp %old %src row_shl(1 : i32) { bound_ctrl = true } : f32
    %2 = amdgpu.dpp %old %src row_mirror : f16
    ```
  }];
  let results = (outs AnyType:$result);
  let assemblyFormat = [{
    $old $src $kind (`(` $permArgument^ `)`)? attr-dict `:` type($result)
  }];
  let hasVerifier = 1;
}
467+
413468
def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
414469
let summary = "Barrier that includes a wait for LDS memory operations.";
415470
let description = [{

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,30 @@ def ROCDL_RawBufferAtomicUMinOp :
554554
let hasCustomAssemblyFormat = 1;
555555
}
556556

557+
// DPP Update intrinsic
//
// ROCDL operation that is translated to a call of the
// `llvm::Intrinsic::amdgcn_update_dpp` intrinsic.
//   - `$old` and `$src` are the two value operands; `$res`, `$src` and `$old`
//     are constrained to the same type (AllTypesMatch).
//   - `$dppCtrl`, `$rowMask` and `$bankMask` are 32-bit integer attributes and
//     `$boundCtrl` is a 1-bit integer attribute. The llvmBuilder materializes
//     them as i32 / i1 constants and passes them as the trailing intrinsic
//     arguments, in that order.
//   - The intrinsic is overloaded on the translated type of `$src`.
558+
def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
559+
[AllTypesMatch<["res", "src", "old"]>], 1>,
560+
Arguments<(ins LLVM_Type:$old, LLVM_Type:$src, I32Attr:$dppCtrl, I32Attr:$rowMask,
561+
I32Attr:$bankMask, I1Attr:$boundCtrl)> {
562+
let results = (outs LLVM_Type:$res);
563+
let assemblyFormat = [{
564+
attr-dict $old `,` $src `with` $dppCtrl `,` $rowMask `,` $bankMask `,` $boundCtrl `:` type($src)
565+
}];
566+
string llvmBuilder = [{
567+
auto vdataType = moduleTranslation.convertType(op.getSrc().getType());
568+
llvm::Value *args[] = {
569+
moduleTranslation.lookupValue(op.getOld()),
570+
moduleTranslation.lookupValue(op.getSrc()),
571+
builder.getInt32(op.getDppCtrl()),
572+
builder.getInt32(op.getRowMask()),
573+
builder.getInt32(op.getBankMask()),
574+
builder.getInt1(op.getBoundCtrl())
575+
};
576+
$res = createIntrinsicCall(builder,
577+
llvm::Intrinsic::amdgcn_update_dpp, args, {vdataType});
578+
}];
579+
}
580+
557581
//===---------------------------------------------------------------------===//
558582
// 8-bit float intrinsics
559583
//===---------------------------------------------------------------------===//

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 145 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,147 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
844844
return success();
845845
}
846846

847+
// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
848+
// operation into the corresponding ROCDL instructions.
849+
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
850+
AMDGPUDPPLowering(LLVMTypeConverter &converter, Chipset chipset)
851+
: ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
852+
Chipset chipset;
853+
854+
LogicalResult
855+
matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
856+
ConversionPatternRewriter &rewriter) const override {
857+
858+
// Convert the source operand to the corresponding LLVM type
859+
Location loc = DppOp.getLoc();
860+
Value src = adaptor.getSrc();
861+
Value old = adaptor.getOld();
862+
Type srcType = src.getType();
863+
Type oldType = old.getType();
864+
auto llvmI32Type = typeConverter->convertType(rewriter.getI32Type());
865+
auto llvmSrcIntType = typeConverter->convertType(
866+
rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
867+
868+
// If the source type is less or equal to i32 or f32, use bitcast to convert
869+
// it to i32.
870+
auto convertOperand = [&](Value operand, Type operandType) {
871+
if (llvm::isa<FloatType>(operandType)) {
872+
operand =
873+
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
874+
}
875+
876+
if (operandType.getIntOrFloatBitWidth() < 32) {
877+
auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
878+
32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
879+
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
880+
operand = rewriter.create<LLVM::InsertElementOp>(
881+
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
882+
operand = rewriter.create<LLVM::BitcastOp>(loc, llvmI32Type, operand);
883+
}
884+
return operand;
885+
};
886+
887+
src = convertOperand(src, srcType);
888+
old = convertOperand(old, oldType);
889+
890+
// This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
891+
enum DppCtrl : unsigned {
892+
ROW_SHL0 = 0x100,
893+
ROW_SHR0 = 0x110,
894+
ROW_ROR0 = 0x120,
895+
WAVE_SHL1 = 0x130,
896+
WAVE_ROL1 = 0x134,
897+
WAVE_SHR1 = 0x138,
898+
WAVE_ROR1 = 0x13C,
899+
ROW_MIRROR = 0x140,
900+
ROW_HALF_MIRROR = 0x141,
901+
BCAST15 = 0x142,
902+
BCAST31 = 0x143,
903+
};
904+
905+
auto kind = DppOp.getKind();
906+
auto permArgument = DppOp.getPermArgument();
907+
uint32_t DppCtrl = 0;
908+
909+
switch (kind) {
910+
911+
case DPPPerm::quad_perm:
912+
if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
913+
int32_t i = 0;
914+
for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
915+
uint32_t num = elem.getInt();
916+
DppCtrl |= num << (i * 2);
917+
i++;
918+
}
919+
}
920+
break;
921+
case DPPPerm::row_shl:
922+
if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
923+
DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
924+
}
925+
break;
926+
case DPPPerm::row_shr:
927+
if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
928+
DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
929+
}
930+
break;
931+
case DPPPerm::row_ror:
932+
if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
933+
DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
934+
}
935+
break;
936+
case DPPPerm::wave_shl:
937+
DppCtrl = DppCtrl::WAVE_SHL1;
938+
break;
939+
case DPPPerm::wave_shr:
940+
DppCtrl = DppCtrl::WAVE_SHR1;
941+
break;
942+
case DPPPerm::wave_rol:
943+
DppCtrl = DppCtrl::WAVE_ROL1;
944+
break;
945+
case DPPPerm::wave_ror:
946+
DppCtrl = DppCtrl::WAVE_ROR1;
947+
break;
948+
case DPPPerm::row_mirror:
949+
DppCtrl = DppCtrl::ROW_MIRROR;
950+
break;
951+
case DPPPerm::row_half_mirror:
952+
DppCtrl = DppCtrl::ROW_HALF_MIRROR;
953+
break;
954+
case DPPPerm::row_bcast_15:
955+
DppCtrl = DppCtrl::BCAST15;
956+
break;
957+
case DPPPerm::row_bcast_31:
958+
DppCtrl = DppCtrl::BCAST31;
959+
break;
960+
}
961+
962+
// Check for row_mask, bank_mask, bound_ctrl if they exist and create
963+
// constants
964+
auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
965+
auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
966+
bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
967+
968+
// create a ROCDL_DPPMovOp instruction with the appropriate attributes
969+
auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
970+
loc, llvmI32Type, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
971+
972+
Value result = dppMovOp.getRes();
973+
if (srcType.getIntOrFloatBitWidth() < 32) {
974+
result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
975+
}
976+
977+
if (!llvm::isa<IntegerType>(srcType)) {
978+
result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
979+
}
980+
981+
// We are replacing the AMDGPU_DPPOp instruction with the new
982+
// ROCDL_DPPMovOp instruction
983+
rewriter.replaceOp(DppOp, ValueRange(result));
984+
return success();
985+
}
986+
};
987+
847988
struct ConvertAMDGPUToROCDLPass
848989
: public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
849990
ConvertAMDGPUToROCDLPass() = default;
@@ -895,9 +1036,10 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
8951036
ROCDL::RawPtrBufferAtomicUminOp>,
8961037
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
8971038
ROCDL::RawPtrBufferAtomicCmpSwap>,
898-
LDSBarrierOpLowering, SchedBarrierOpLowering, MFMAOpLowering,
899-
WMMAOpLowering, ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
900-
PackedStochRoundFp8OpLowering>(converter, chipset);
1039+
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1040+
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
1041+
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
1042+
chipset);
9011043
}
9021044

9031045
std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,68 @@ LogicalResult MFMAOp::verify() {
326326
return success();
327327
}
328328

329+
//===----------------------------------------------------------------------===//
330+
// DPPOp
331+
//===----------------------------------------------------------------------===//
332+
/// Verifies an `amdgpu.dpp` operation.
///
/// Checks that:
///   - the source is a scalar integer or floating point value of at most
///     32 bits,
///   - `quad_perm` carries an array of exactly four integers in [0, 3],
///   - `row_shl` / `row_shr` / `row_ror` carry an integer in [1, 15],
///   - every other permutation kind carries no argument (or a unit attribute).
///
/// Returns failure after emitting a diagnostic on the first violated rule.
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  // getIntOrFloatBitWidth() asserts on anything that is not a scalar integer
  // or float (e.g. vectors, memrefs, index), so diagnose those first.
  if (!srcType.isIntOrFloat()) {
    return emitOpError(
        "operands must be scalar integer or floating point types");
  }
  if (srcType.getIntOrFloatBitWidth() > 32) {
    return emitOpError("integer and floating point types larger than 32 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  // Null when no permutation argument was supplied.
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {

  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    // The array elements are unconstrained by ODS, so check their kind before
    // reading a value. The unsigned APInt comparison also rejects negative
    // values and never truncates.
    for (Attribute elem : quadPermAttr) {
      auto intElem = dyn_cast<IntegerAttr>(elem);
      if (!intElem || intElem.getValue().ugt(3)) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
  } break;

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    // An array or unit argument must be rejected here: the ROCDL lowering
    // unconditionally casts the argument to an IntegerAttr.
    auto intAttr = dyn_cast<IntegerAttr>(permArgument);
    if (!intAttr) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value must be an integer");
    }
    int64_t attrValue = intAttr.getInt();
    if (attrValue < 1 || attrValue > 15) {
      return emitOpError("Attribute value must be between 1 and 15");
    }
  } break;

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
390+
329391
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
330392

331393
#define GET_ATTRDEF_CLASSES

0 commit comments

Comments
 (0)