Skip to content

Commit d14a3b1

Browse files
[mlir][AMDGPU] "Added support for 64-bit operands in
ROCDL::DPPUpdateOp operation."
1 parent 870e48b commit d14a3b1

File tree

4 files changed

+60
-61
lines changed

4 files changed

+60
-61
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
572572
builder.getInt32(op.getRowMask()),
573573
builder.getInt32(op.getBankMask()),
574574
builder.getInt1(op.getBoundCtrl())
575-
};
575+
};
576576
$res = createIntrinsicCall(builder,
577577
llvm::Intrinsic::amdgcn_update_dpp, args, {vdataType});
578578
}];

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -861,25 +861,34 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
861861
Value old = adaptor.getOld();
862862
Type srcType = src.getType();
863863
Type oldType = old.getType();
864-
auto llvmI32Type = typeConverter->convertType(rewriter.getI32Type());
864+
Type llvmType = nullptr;
865+
if (srcType.getIntOrFloatBitWidth() < 32) {
866+
llvmType = rewriter.getI32Type();
867+
} else if (isa<FloatType>(srcType)) {
868+
llvmType = (srcType.getIntOrFloatBitWidth() == 32)
869+
? rewriter.getF32Type()
870+
: rewriter.getF64Type();
871+
} else if (isa<IntegerType>(srcType)) {
872+
llvmType = (srcType.getIntOrFloatBitWidth() == 32)
873+
? rewriter.getI32Type()
874+
: rewriter.getI64Type();
875+
}
865876
auto llvmSrcIntType = typeConverter->convertType(
866877
rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
867878

868-
// If the source type is less or equal to i32 or f32, use bitcast to convert
869-
// it to i32.
879+
// If the source type is less of 32, use bitcast to convert it to i32.
870880
auto convertOperand = [&](Value operand, Type operandType) {
871-
if (llvm::isa<FloatType>(operandType)) {
872-
operand =
873-
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
874-
}
875-
876-
if (operandType.getIntOrFloatBitWidth() < 32) {
881+
if (operandType.getIntOrFloatBitWidth() <= 16) {
882+
if (llvm::isa<FloatType>(operandType)) {
883+
operand =
884+
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
885+
}
877886
auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
878887
32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
879888
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
880889
operand = rewriter.create<LLVM::InsertElementOp>(
881890
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
882-
operand = rewriter.create<LLVM::BitcastOp>(loc, llvmI32Type, operand);
891+
operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
883892
}
884893
return operand;
885894
};
@@ -967,15 +976,14 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
967976

968977
// create a ROCDL_DPPMovOp instruction with the appropriate attributes
969978
auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
970-
loc, llvmI32Type, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
979+
loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
971980

972981
Value result = dppMovOp.getRes();
973982
if (srcType.getIntOrFloatBitWidth() < 32) {
974983
result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
975-
}
976-
977-
if (!llvm::isa<IntegerType>(srcType)) {
978-
result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
984+
if (!llvm::isa<IntegerType>(srcType)) {
985+
result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
986+
}
979987
}
980988

981989
// We are replacing the AMDGPU_DPPOp instruction with the new

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,8 @@ LogicalResult MFMAOp::verify() {
331331
//===----------------------------------------------------------------------===//
332332
LogicalResult DPPOp::verify() {
333333
Type srcType = getSrc().getType();
334-
if (srcType.getIntOrFloatBitWidth() > 32) {
335-
return emitOpError("integer and floating point types larger than 32 bits "
334+
if (srcType.getIntOrFloatBitWidth() > 64) {
335+
return emitOpError("integer and floating point types larger than 64 bits "
336336
"are not supported");
337337
}
338338

mlir/test/Conversion/AMDGPUToROCDL/dpp.mlir

Lines changed: 34 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,6 @@ func.func @quad_dpp(%arg0: i32, %arg1: i32) -> i32 {
1818
return %0 : i32
1919
}
2020

21-
func.func @quad_perm_dpp(%arg0: i32, %arg1: i32) -> i32 {
22-
// CHECK-LABEL: func @quad_perm_dpp
23-
// CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 15, false : i32
24-
// CHECK: return %0 : i32
25-
%0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) : i32
26-
return %0 : i32
27-
}
28-
2921
func.func @wave_shr_dpp(%arg0: i32, %arg1: i32) -> i32 {
3022
// CHECK-LABEL: func @wave_shr_dpp
3123
// CHECK: rocdl.update.dpp %arg0, %arg1 with 312, 10, 1, true : i32
@@ -34,25 +26,6 @@ func.func @wave_shr_dpp(%arg0: i32, %arg1: i32) -> i32 {
3426
return %0 : i32
3527
}
3628

37-
func.func @row_bcast_dpp(%arg0: i32, %arg1: i32) -> i32 {
38-
// CHECK-LABEL: func @row_bcast_dpp
39-
// CHECK: rocdl.update.dpp %arg0, %arg1 with 323, 4, 1, false : i32
40-
// CHECK: return %0 : i32
41-
%0 = amdgpu.dpp %arg0 %arg1 row_bcast_31 { row_mask = 0x4 : i32, bank_mask = 0x1 : i32} : i32
42-
return %0 : i32
43-
}
44-
45-
func.func @row_bcast_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
46-
// CHECK-LABEL: func @row_bcast_dpp_f32
47-
// CHECK: llvm.bitcast %arg1 : f32 to i32
48-
// CHECK: llvm.bitcast %arg0 : f32 to i32
49-
// CHECK: rocdl.update.dpp %1, %0 with 322, 15, 15, true : i32
50-
// CHECK: llvm.bitcast %2 : i32 to f32
51-
// CHECK: return %3 : f32
52-
%0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f32
53-
return %0 : f32
54-
}
55-
5629
func.func @row_half_mirror_update_dpp(%arg0: i32, %arg1: i32) -> i32 {
5730
// CHECK-LABEL: func @row_half_mirror_update_dpp
5831
// CHECK: rocdl.update.dpp %arg0, %arg1 with 321, 15, 1, false : i32
@@ -69,17 +42,46 @@ func.func @wave_rol_update_dpp(%arg0: i32, %arg1: i32) -> i32 {
6942
return %0 : i32
7043
}
7144

45+
func.func @row_bcast_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
46+
// CHECK-LABEL: func @row_bcast_dpp_f32
47+
// CHECK: rocdl.update.dpp %arg0, %arg1 with 322, 15, 15, true : f32
48+
// CHECK: return %0 : f32
49+
%0 = amdgpu.dpp %arg0 %arg1 row_bcast_15 { bound_ctrl = true } : f32
50+
return %0 : f32
51+
}
52+
7253
func.func @test_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
7354
// CHECK-LABEL: func @test_dpp_f32
74-
// CHECK: llvm.bitcast %arg1 : f32 to i32
75-
// CHECK: llvm.bitcast %arg0 : f32 to i32
76-
// CHECK: rocdl.update.dpp %1, %0 with 320, 1, 4, true : i32
77-
// CHECK: llvm.bitcast %2 : i32 to f32
78-
// CHECK: return %3 : f32
55+
// CHECK: rocdl.update.dpp %arg0, %arg1 with 320, 1, 4, true : f32
56+
// CHECK: return %0 : f32
7957
%0 = amdgpu.dpp %arg0 %arg1 row_mirror { row_mask = 0x1 : i32, bank_mask = 0x4 : i32, bound_ctrl = true } : f32
8058
return %0 : f32
8159
}
8260

61+
func.func @quad_perm_update_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
62+
// CHECK-LABEL: func @quad_perm_update_dpp_f32
63+
// CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 1, false : f32
64+
// CHECK: return %0 : f32
65+
%0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) { bank_mask = 0x1 : i32 } : f32
66+
return %0 : f32
67+
}
68+
69+
func.func @quad_perm_dpp(%arg0: i64, %arg1: i64) -> i64 {
70+
// CHECK-LABEL: func @quad_perm_dpp
71+
// CHECK: rocdl.update.dpp %arg0, %arg1 with 88, 15, 15, false : i64
72+
// CHECK: return %0 : i64
73+
%0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) : i64
74+
return %0 : i64
75+
}
76+
77+
func.func @row_bcast_dpp(%arg0: f64, %arg1: f64) -> f64 {
78+
// CHECK-LABEL: func @row_bcast_dpp
79+
// CHECK: rocdl.update.dpp %arg0, %arg1 with 323, 4, 1, false : f64
80+
// CHECK: return %0 : f64
81+
%0 = amdgpu.dpp %arg0 %arg1 row_bcast_31 { row_mask = 0x4 : i32, bank_mask = 0x1 : i32} : f64
82+
return %0 : f64
83+
}
84+
8385
func.func @test_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
8486
// CHECK-LABEL: func @test_dpp_f16
8587
// CHECK: llvm.bitcast %arg1 : f16 to i16
@@ -117,17 +119,6 @@ func.func @row_shl_dpp_i16(%arg0: i16, %arg1: i16) -> i16 {
117119
return %0 : i16
118120
}
119121

120-
func.func @quad_perm_update_dpp_f32(%arg0: f32, %arg1: f32) -> f32 {
121-
// CHECK-LABEL: func @quad_perm_update_dpp_f32
122-
// CHECK: llvm.bitcast %arg1 : f32 to i32
123-
// CHECK: llvm.bitcast %arg0 : f32 to i32
124-
// CHECK: rocdl.update.dpp %1, %0 with 88, 15, 1, false : i32
125-
// CHECK: llvm.bitcast %2 : i32 to f32
126-
// CHECK: return %3 : f32
127-
%0 = amdgpu.dpp %arg0 %arg1 quad_perm ( [0,2,1,1] ) { bank_mask = 0x1 : i32 } : f32
128-
return %0 : f32
129-
}
130-
131122
func.func @row_bcast_update_dpp_f16(%arg0: f16, %arg1: f16) -> f16 {
132123
// CHECK-LABEL: func @row_bcast_update_dpp_f16
133124
// CHECK: llvm.bitcast %arg1 : f16 to i16

0 commit comments

Comments
 (0)