Skip to content

Commit 9167cb0

Browse files
authored
Add new pass for math to rocdl. (llvm#93)
Add new pass for math to rocdl.
1 parent c01ecf8 commit 9167cb0

File tree

8 files changed

+221
-1
lines changed

8 files changed

+221
-1
lines changed

flang/lib/Optimizer/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_flang_library(FIRCodeGen
2626
MLIRMathToFuncs
2727
MLIRMathToLLVM
2828
MLIRMathToLibm
29+
MLIRMathToROCDL
2930
MLIROpenMPToLLVM
3031
MLIRBuiltinToLLVMIRTranslation
3132
MLIRLLVMToLLVMIRTranslation

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
3535
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
3636
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
37+
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
3738
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
3839
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
3940
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
@@ -3609,6 +3610,14 @@ class FIRToLLVMLowering
36093610
// as passes here.
36103611
mlir::OpPassManager mathConvertionPM("builtin.module");
36113612

3613+
bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
3614+
// If compiling for AMD target some math operations must be lowered to AMD
3615+
// GPU library calls, the rest can be converted to LLVM intrinsics, which
3616+
// is handled in the mathToLLVM conversion. The lowering to libm calls is
3617+
// not needed since all math operations are handled this way.
3618+
if (isAMDGCN)
3619+
mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
3620+
36123621
// Convert math::FPowI operations to inline implementation
36133622
// only if the exponent's width is greater than 32, otherwise,
36143623
// it will be lowered to LLVM intrinsic operation by a later conversion.
@@ -3648,7 +3657,8 @@ class FIRToLLVMLowering
36483657
pattern);
36493658
// Math operations that have not been converted yet must be converted
36503659
// to Libm.
3651-
mlir::populateMathToLibmConversionPatterns(pattern);
3660+
if (!isAMDGCN)
3661+
mlir::populateMathToLibmConversionPatterns(pattern);
36523662
mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern);
36533663
mlir::populateVectorToLLVMConversionPatterns(typeConverter, pattern);
36543664

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- MathToROCDL.h - Utils to convert from the complex dialect --------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
9+
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
10+
11+
#include <memory>
12+
13+
namespace mlir {
14+
class Pass;
15+
16+
#define GEN_PASS_DECL_CONVERTMATHTOROCDL
17+
#include "mlir/Conversion/Passes.h.inc"
18+
19+
} // namespace mlir
20+
21+
#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
4646
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
4747
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
48+
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
4849
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
4950
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
5051
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,22 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
721721
];
722722
}
723723

724+
//===----------------------------------------------------------------------===//
725+
// MathToROCDL
726+
//===----------------------------------------------------------------------===//
727+
728+
def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
729+
let summary = "Convert Math dialect to rocdl calls";
730+
let description = [{
731+
This pass converts supported Math ops to rocdl calls.
732+
}];
733+
let dependentDialects = [
734+
"func::FuncDialect",
735+
"math::MathDialect",
736+
"vector::VectorDialect",
737+
];
738+
}
739+
724740
//===----------------------------------------------------------------------===//
725741
// MathToSPIRV
726742
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_subdirectory(LLVMCommon)
3535
add_subdirectory(MathToFuncs)
3636
add_subdirectory(MathToLibm)
3737
add_subdirectory(MathToLLVM)
38+
add_subdirectory(MathToROCDL)
3839
add_subdirectory(MathToSPIRV)
3940
add_subdirectory(MemRefToEmitC)
4041
add_subdirectory(MemRefToLLVM)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
add_mlir_conversion_library(MLIRMathToROCDL
2+
MathToROCDL.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToROCDL
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRDialectUtils
15+
MLIRFuncDialect
16+
MLIRGPUToGPURuntimeTransforms
17+
MLIRMathDialect
18+
MLIRLLVMCommonConversion
19+
MLIRPass
20+
MLIRTransformUtils
21+
MLIRVectorDialect
22+
MLIRVectorUtils
23+
)
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
10+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
11+
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
13+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14+
#include "mlir/Dialect/Math/IR/Math.h"
15+
#include "mlir/Dialect/Utils/IndexingUtils.h"
16+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
17+
#include "mlir/IR/BuiltinDialect.h"
18+
#include "mlir/IR/PatternMatch.h"
19+
#include "mlir/Pass/Pass.h"
20+
#include "mlir/Transforms/DialectConversion.h"
21+
22+
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
23+
#include "../GPUCommon/GPUOpsLowering.h"
24+
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
25+
#include "../GPUCommon/OpToFuncCallLowering.h"
26+
27+
namespace mlir {
28+
#define GEN_PASS_DEF_CONVERTMATHTOROCDL
29+
#include "mlir/Conversion/Passes.h.inc"
30+
} // namespace mlir
31+
32+
using namespace mlir;
33+
34+
#define DEBUG_TYPE "math-to-rocdl"
35+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
36+
37+
template <typename OpTy>
38+
static void populateOpPatterns(LLVMTypeConverter &converter,
39+
RewritePatternSet &patterns, StringRef f32Func,
40+
StringRef f64Func) {
41+
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
42+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
43+
}
44+
45+
static void populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
46+
RewritePatternSet &patterns) {
47+
// Handled by mathToLLVM: math::AbsIOp
48+
// Handled by mathToLLVM: math::CopySignOp
49+
// Handled by mathToLLVM: math::CountLeadingZerosOp
50+
// Handled by mathToLLVM: math::CountTrailingZerosOp
51+
// Handled by mathToLLVM: math::CgPopOp
52+
// Handled by mathToLLVM: math::FmaOp
53+
// FIXME: math::IPowIOp
54+
// FIXME: math::FPowIOp
55+
// Handled by mathToLLVM: math::RoundEvenOp
56+
// Handled by mathToLLVM: math::RoundOp
57+
// Handled by mathToLLVM: math::TruncOp
58+
59+
populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
60+
"__ocml_fabs_f64");
61+
populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
62+
"__ocml_acos_f64");
63+
populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
64+
"__ocml_acosh_f64");
65+
populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
66+
"__ocml_asin_f64");
67+
populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
68+
"__ocml_asinh_f64");
69+
populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
70+
"__ocml_atan_f64");
71+
populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
72+
"__ocml_atanh_f64");
73+
populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
74+
"__ocml_atan2_f64");
75+
populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
76+
"__ocml_cbrt_f64");
77+
populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
78+
"__ocml_ceil_f64");
79+
populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
80+
"__ocml_cos_f64");
81+
populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
82+
"__ocml_cosh_f64");
83+
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
84+
"__ocml_sinh_f64");
85+
populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
86+
"__ocml_exp_f64");
87+
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
88+
"__ocml_exp2_f64");
89+
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
90+
"__ocml_expm1_f64");
91+
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
92+
"__ocml_floor_f64");
93+
// FIXME: Different pass or new op in math?
94+
// populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
95+
// "__ocml_fmod_f64");
96+
populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
97+
"__ocml_log_f64");
98+
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
99+
"__ocml_log10_f64");
100+
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
101+
"__ocml_log1p_f64");
102+
populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
103+
"__ocml_log2_f64");
104+
populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
105+
"__ocml_pow_f64");
106+
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
107+
"__ocml_rsqrt_f64");
108+
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
109+
"__ocml_sin_f64");
110+
populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
111+
"__ocml_sqrt_f64");
112+
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
113+
"__ocml_tanh_f64");
114+
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
115+
"__ocml_tan_f64");
116+
populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
117+
"__ocml_erf_f64");
118+
}
119+
120+
namespace {
121+
struct ConvertMathToROCDLPass
122+
: public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
123+
ConvertMathToROCDLPass() = default;
124+
void runOnOperation() override;
125+
};
126+
} // namespace
127+
128+
void ConvertMathToROCDLPass::runOnOperation() {
129+
auto m = getOperation();
130+
MLIRContext *ctx = m.getContext();
131+
132+
133+
RewritePatternSet patterns(&getContext());
134+
LowerToLLVMOptions options(ctx, DataLayout(m));
135+
LLVMTypeConverter converter(ctx, options);
136+
populateMathToROCDLConversionPatterns(converter, patterns);
137+
138+
ConversionTarget target(getContext());
139+
target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect,
140+
vector::VectorDialect, LLVM::LLVMDialect>();
141+
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
142+
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
143+
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
144+
LLVM::SqrtOp>();
145+
if (failed(applyPartialConversion(m, target, std::move(patterns))))
146+
signalPassFailure();
147+
}

0 commit comments

Comments
 (0)