diff --git a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h index 142f77c976ffc..21bd191aa9dc8 100644 --- a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h +++ b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h @@ -19,6 +19,7 @@ namespace mlir { class DialectRegistry; class LLVMTypeConverter; class RewritePatternSet; +class SymbolTable; /// Collect the default pattern to convert a FuncOp to the LLVM dialect. If /// `emitCWrappers` is set, the pattern will also produce functions @@ -31,8 +32,18 @@ void populateFuncToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter, /// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions /// by reference meaning the references have to remain alive during the entire /// pattern lifetime. -void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns); +/// +/// The `symbolTable` parameter can be used to speed up function lookups in the +/// module. It's good to provide it, but only if we know that the patterns will +/// be applied to a single module and the symbols referenced by the symbol table +/// will not be removed and new symbols will not be added during the usage of +/// the patterns. If provided, the lookups will have O(calls) cumulative +/// runtime, otherwise O(calls * functions). The symbol table is currently not +/// needed if `converter.getOptions().useBarePtrCallConv` is `true`, but it's +/// not an error to provide it anyway. +void populateFuncToLLVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns, + const SymbolTable *symbolTable = nullptr); void registerConvertFuncToLLVMInterface(DialectRegistry ®istry); diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 7de7f3cb9e36b..d52f01880282e 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -33,6 +33,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/MathExtras.h" @@ -48,6 +49,7 @@ #include "llvm/Support/FormatVariadic.h" #include #include +#include namespace mlir { #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS @@ -601,19 +603,38 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { } }; -struct CallOpLowering : public CallOpInterfaceLowering { - using Super::Super; +class CallOpLowering : public CallOpInterfaceLowering { +public: + CallOpLowering(const LLVMTypeConverter &typeConverter, + // Can be nullptr. + const SymbolTable *symbolTable, PatternBenefit benefit = 1) + : CallOpInterfaceLowering(typeConverter, benefit), + symbolTable(symbolTable) {} LogicalResult matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { bool useBarePtrCallConv = false; - if (Operation *callee = SymbolTable::lookupNearestSymbolFrom( - callOp, callOp.getCalleeAttr())) { - useBarePtrCallConv = shouldUseBarePtrCallConv(callee, getTypeConverter()); + if (getTypeConverter()->getOptions().useBarePtrCallConv) { + useBarePtrCallConv = true; + } else if (symbolTable != nullptr) { + // Fast lookup. + Operation *callee = + symbolTable->lookup(callOp.getCalleeAttr().getValue()); + useBarePtrCallConv = + callee != nullptr && callee->hasAttr(barePtrAttrName); + } else { + // Warning: This is a linear lookup. + Operation *callee = + SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr()); + useBarePtrCallConv = + callee != nullptr && callee->hasAttr(barePtrAttrName); } return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv); } + +private: + const SymbolTable *symbolTable = nullptr; }; struct CallIndirectOpLowering @@ -728,16 +749,14 @@ void mlir::populateFuncToLLVMFuncOpConversionPattern( patterns.add(converter); } -void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns) { +void mlir::populateFuncToLLVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns, + const SymbolTable *symbolTable) { populateFuncToLLVMFuncOpConversionPattern(converter, patterns); - // clang-format off - patterns.add< - CallIndirectOpLowering, - CallOpLowering, - ConstantOpLowering, - ReturnOpLowering>(converter); - // clang-format on + patterns.add(converter); + patterns.add(converter, symbolTable); + patterns.add(converter); + patterns.add(converter); } namespace { @@ -776,8 +795,15 @@ struct ConvertFuncToLLVMPass LLVMTypeConverter typeConverter(&getContext(), options, &dataLayoutAnalysis); + std::optional optSymbolTable = std::nullopt; + const SymbolTable *symbolTable = nullptr; + if (!options.useBarePtrCallConv) { + optSymbolTable.emplace(m); + symbolTable = &optSymbolTable.value(); + } + RewritePatternSet patterns(&getContext()); - populateFuncToLLVMConversionPatterns(typeConverter, patterns); + populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable); // TODO: Remove these in favor of their dedicated conversion passes. arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);