Skip to content

Commit 9519e3e

Browse files
authored
[mlir] support dialect attribute translation to LLVM IR (#75309)
Extend the `amendOperation` mechanism for translating dialect attributes attached to operations from another dialect when translating MLIR to LLVM IR. Previously, this mechanism would have no knowledge of the LLVM IR instructions created for the given operation, making it impossible for it to perform local modifications such as attaching operation-level metadata. Collect instructions inserted by the LLVM IR builder and pass them to `amendOperation`.
1 parent 133de6c commit 9519e3e

File tree

8 files changed

+194
-26
lines changed

8 files changed

+194
-26
lines changed

mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Support/LogicalResult.h"
1919

2020
namespace llvm {
21+
class Instruction;
2122
class IRBuilderBase;
2223
} // namespace llvm
2324

@@ -52,7 +53,8 @@ class LLVMTranslationDialectInterface
5253
/// translation results and amend the corresponding IR constructs. Does
5354
/// nothing and succeeds by default.
5455
virtual LogicalResult
55-
amendOperation(Operation *op, NamedAttribute attribute,
56+
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
57+
NamedAttribute attribute,
5658
LLVM::ModuleTranslation &moduleTranslation) const {
5759
return success();
5860
}
@@ -78,11 +80,13 @@ class LLVMTranslationInterface
7880
/// Acts on the given operation using the interface implemented by the dialect
7981
/// of one of the operation's dialect attributes.
8082
virtual LogicalResult
81-
amendOperation(Operation *op, NamedAttribute attribute,
83+
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
84+
NamedAttribute attribute,
8285
LLVM::ModuleTranslation &moduleTranslation) const {
8386
if (const LLVMTranslationDialectInterface *iface =
8487
getInterfaceFor(attribute.getNameDialect())) {
85-
return iface->amendOperation(op, attribute, moduleTranslation);
88+
return iface->amendOperation(op, instructions, attribute,
89+
moduleTranslation);
8690
}
8791
return success();
8892
}

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,10 @@ class ModuleTranslation {
209209
/// PHI nodes are constructed for block arguments but are _not_ connected to
210210
/// the predecessors that may not exist yet.
211211
LogicalResult convertBlock(Block &bb, bool ignoreArguments,
212-
llvm::IRBuilderBase &builder);
212+
llvm::IRBuilderBase &builder) {
213+
return convertBlockImpl(bb, ignoreArguments, builder,
214+
/*recordInsertions=*/false);
215+
}
213216

214217
/// Gets the named metadata in the LLVM IR module being constructed, creating
215218
/// it if it does not exist.
@@ -299,12 +302,16 @@ class ModuleTranslation {
299302
~ModuleTranslation();
300303

301304
/// Converts individual components.
302-
LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder);
305+
LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder,
306+
bool recordInsertions = false);
303307
LogicalResult convertFunctionSignatures();
304308
LogicalResult convertFunctions();
305309
LogicalResult convertComdats();
306310
LogicalResult convertGlobals();
307311
LogicalResult convertOneFunction(LLVMFuncOp func);
312+
LogicalResult convertBlockImpl(Block &bb, bool ignoreArguments,
313+
llvm::IRBuilderBase &builder,
314+
bool recordInsertions);
308315

309316
/// Returns the LLVM metadata corresponding to the given mlir LLVM dialect
310317
/// TBAATagAttr.
@@ -315,7 +322,9 @@ class ModuleTranslation {
315322
LogicalResult createTBAAMetadata();
316323

317324
/// Translates dialect attributes attached to the given operation.
318-
LogicalResult convertDialectAttributes(Operation *op);
325+
LogicalResult
326+
convertDialectAttributes(Operation *op,
327+
ArrayRef<llvm::Instruction *> instructions);
319328

320329
/// Translates parameter attributes and adds them to the returned AttrBuilder.
321330
llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ class NVVMDialectLLVMIRTranslationInterface
142142

143143
/// Attaches module-level metadata for functions marked as kernels.
144144
LogicalResult
145-
amendOperation(Operation *op, NamedAttribute attribute,
145+
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
146+
NamedAttribute attribute,
146147
LLVM::ModuleTranslation &moduleTranslation) const final {
147148
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
148149
if (!func)

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2572,14 +2572,16 @@ class OpenMPDialectLLVMIRTranslationInterface
25722572
/// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime
25732573
/// calls, or operation amendments
25742574
LogicalResult
2575-
amendOperation(Operation *op, NamedAttribute attribute,
2575+
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
2576+
NamedAttribute attribute,
25762577
LLVM::ModuleTranslation &moduleTranslation) const final;
25772578
};
25782579

25792580
} // namespace
25802581

25812582
LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
2582-
Operation *op, NamedAttribute attribute,
2583+
Operation *op, ArrayRef<llvm::Instruction *> instructions,
2584+
NamedAttribute attribute,
25832585
LLVM::ModuleTranslation &moduleTranslation) const {
25842586
return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
25852587
attribute.getName())

mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ class ROCDLDialectLLVMIRTranslationInterface
8181

8282
/// Attaches module-level metadata for functions marked as kernels.
8383
LogicalResult
84-
amendOperation(Operation *op, NamedAttribute attribute,
84+
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
85+
NamedAttribute attribute,
8586
LLVM::ModuleTranslation &moduleTranslation) const final {
8687
if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
8788
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 127 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,113 @@ using namespace mlir::LLVM::detail;
5959

6060
#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
6161

62+
namespace {
63+
/// A customized inserter for LLVM's IRBuilder that captures all LLVM IR
64+
/// instructions that are created for future reference.
65+
///
66+
/// This is intended to be used with the `CollectionScope` RAII object:
67+
///
68+
/// llvm::IRBuilder<..., InstructionCapturingInserter> builder;
69+
/// {
70+
/// InstructionCapturingInserter::CollectionScope scope(builder);
71+
/// // Call IRBuilder methods as usual.
72+
///
73+
/// // This will return a list of all instructions created by the builder,
74+
/// // in order of creation.
75+
/// builder.getInserter().getCapturedInstructions();
76+
/// }
77+
/// // This will return an empty list.
78+
/// builder.getInserter().getCapturedInstructions();
79+
///
80+
/// The capturing functionality is _disabled_ by default for performance
81+
/// consideration. It needs to be explicitly enabled, which is achieved by
82+
/// creating a `CollectionScope`.
83+
class InstructionCapturingInserter : public llvm::IRBuilderCallbackInserter {
84+
public:
85+
/// Constructs the inserter.
86+
InstructionCapturingInserter()
87+
: llvm::IRBuilderCallbackInserter([this](llvm::Instruction *instruction) {
88+
if (LLVM_LIKELY(enabled))
89+
capturedInstructions.push_back(instruction);
90+
}) {}
91+
92+
/// Returns the list of LLVM IR instructions captured since the last cleanup.
93+
ArrayRef<llvm::Instruction *> getCapturedInstructions() const {
94+
return capturedInstructions;
95+
}
96+
97+
/// Clears the list of captured LLVM IR instructions.
98+
void clearCapturedInstructions() { capturedInstructions.clear(); }
99+
100+
/// RAII object enabling the capture of created LLVM IR instructions.
101+
class CollectionScope {
102+
public:
103+
/// Creates the scope for the given inserter.
104+
CollectionScope(llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing);
105+
106+
/// Ends the scope.
107+
~CollectionScope();
108+
109+
ArrayRef<llvm::Instruction *> getCapturedInstructions() {
110+
if (!inserter)
111+
return {};
112+
return inserter->getCapturedInstructions();
113+
}
114+
115+
private:
116+
/// Back reference to the inserter.
117+
InstructionCapturingInserter *inserter = nullptr;
118+
119+
/// List of instructions in the inserter prior to this scope.
120+
SmallVector<llvm::Instruction *> previouslyCollectedInstructions;
121+
122+
/// Whether the inserter was enabled prior to this scope.
123+
bool wasEnabled;
124+
};
125+
126+
/// Enable or disable the capturing mechanism.
127+
void setEnabled(bool enabled = true) { this->enabled = enabled; }
128+
129+
private:
130+
/// List of captured instructions.
131+
SmallVector<llvm::Instruction *> capturedInstructions;
132+
133+
/// Whether the collection is enabled.
134+
bool enabled = false;
135+
};
136+
137+
using CapturingIRBuilder =
138+
llvm::IRBuilder<llvm::ConstantFolder, InstructionCapturingInserter>;
139+
} // namespace
140+
141+
InstructionCapturingInserter::CollectionScope::CollectionScope(
142+
llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing) {
143+
144+
if (!isBuilderCapturing)
145+
return;
146+
147+
auto &capturingIRBuilder = static_cast<CapturingIRBuilder &>(irBuilder);
148+
inserter = &capturingIRBuilder.getInserter();
149+
wasEnabled = inserter->enabled;
150+
if (wasEnabled)
151+
previouslyCollectedInstructions.swap(inserter->capturedInstructions);
152+
inserter->setEnabled(true);
153+
}
154+
155+
InstructionCapturingInserter::CollectionScope::~CollectionScope() {
156+
if (!inserter)
157+
return;
158+
159+
previouslyCollectedInstructions.swap(inserter->capturedInstructions);
160+
// If collection was enabled (likely in another, surrounding scope), keep
161+
// the instructions collected in this scope.
162+
if (wasEnabled) {
163+
llvm::append_range(inserter->capturedInstructions,
164+
previouslyCollectedInstructions);
165+
}
166+
inserter->setEnabled(wasEnabled);
167+
}
168+
62169
/// Translates the given data layout spec attribute to the LLVM IR data layout.
63170
/// Only integer, float, pointer and endianness entries are currently supported.
64171
static FailureOr<llvm::DataLayout>
@@ -631,21 +738,23 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
631738

632739
/// Given a single MLIR operation, create the corresponding LLVM IR operation
633740
/// using the `builder`.
634-
LogicalResult
635-
ModuleTranslation::convertOperation(Operation &op,
636-
llvm::IRBuilderBase &builder) {
741+
LogicalResult ModuleTranslation::convertOperation(Operation &op,
742+
llvm::IRBuilderBase &builder,
743+
bool recordInsertions) {
637744
const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(&op);
638745
if (!opIface)
639746
return op.emitError("cannot be converted to LLVM IR: missing "
640747
"`LLVMTranslationDialectInterface` registration for "
641748
"dialect for op: ")
642749
<< op.getName();
643750

751+
InstructionCapturingInserter::CollectionScope scope(builder,
752+
recordInsertions);
644753
if (failed(opIface->convertOperation(&op, builder, *this)))
645754
return op.emitError("LLVM Translation failed for operation: ")
646755
<< op.getName();
647756

648-
return convertDialectAttributes(&op);
757+
return convertDialectAttributes(&op, scope.getCapturedInstructions());
649758
}
650759

651760
/// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes
@@ -655,8 +764,10 @@ ModuleTranslation::convertOperation(Operation &op,
655764
/// been created for `bb` and included in the block mapping. Inserts new
656765
/// instructions at the end of the block and leaves `builder` in a state
657766
/// suitable for further insertion into the end of the block.
658-
LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
659-
llvm::IRBuilderBase &builder) {
767+
LogicalResult ModuleTranslation::convertBlockImpl(Block &bb,
768+
bool ignoreArguments,
769+
llvm::IRBuilderBase &builder,
770+
bool recordInsertions) {
660771
builder.SetInsertPoint(lookupBlock(&bb));
661772
auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
662773

@@ -687,7 +798,7 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
687798
builder.SetCurrentDebugLocation(
688799
debugTranslation->translateLoc(op.getLoc(), subprogram));
689800

690-
if (failed(convertOperation(op, builder)))
801+
if (failed(convertOperation(op, builder, recordInsertions)))
691802
return failure();
692803

693804
// Set the branch weight metadata on the translated instruction.
@@ -844,7 +955,7 @@ LogicalResult ModuleTranslation::convertGlobals() {
844955
}
845956

846957
for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>())
847-
if (failed(convertDialectAttributes(op)))
958+
if (failed(convertDialectAttributes(op, {})))
848959
return failure();
849960

850961
// Finally, update the compile units their respective sets of global variables
@@ -997,8 +1108,9 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
9971108
// converted before uses.
9981109
auto blocks = getTopologicallySortedBlocks(func.getBody());
9991110
for (Block *bb : blocks) {
1000-
llvm::IRBuilder<> builder(llvmContext);
1001-
if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))
1111+
CapturingIRBuilder builder(llvmContext);
1112+
if (failed(convertBlockImpl(*bb, bb->isEntryBlock(), builder,
1113+
/*recordInsertions=*/true)))
10021114
return failure();
10031115
}
10041116

@@ -1007,12 +1119,13 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
10071119
detail::connectPHINodes(func.getBody(), *this);
10081120

10091121
// Finally, convert dialect attributes attached to the function.
1010-
return convertDialectAttributes(func);
1122+
return convertDialectAttributes(func, {});
10111123
}
10121124

1013-
LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
1125+
LogicalResult ModuleTranslation::convertDialectAttributes(
1126+
Operation *op, ArrayRef<llvm::Instruction *> instructions) {
10141127
for (NamedAttribute attribute : op->getDialectAttrs())
1015-
if (failed(iface.amendOperation(op, attribute, *this)))
1128+
if (failed(iface.amendOperation(op, instructions, attribute, *this)))
10161129
return failure();
10171130
return success();
10181131
}
@@ -1134,7 +1247,7 @@ LogicalResult ModuleTranslation::convertFunctions() {
11341247
// Do not convert external functions, but do process dialect attributes
11351248
// attached to them.
11361249
if (function.isExternal()) {
1137-
if (failed(convertDialectAttributes(function)))
1250+
if (failed(convertDialectAttributes(function, {})))
11381251
return failure();
11391252
continue;
11401253
}

mlir/test/Target/LLVMIR/test.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,27 @@ module {
1616
module attributes {test.discardable_mod_attr = true} {}
1717

1818
// CHECK: @sym_from_attr = external global i32
19+
20+
// -----
21+
22+
// CHECK-LABEL: @dialect_attr_translation
23+
llvm.func @dialect_attr_translation() {
24+
// CHECK: ret void, !annotation ![[MD_ID:.+]]
25+
llvm.return {test.add_annotation}
26+
}
27+
// CHECK: ![[MD_ID]] = !{!"annotation_from_test"}
28+
29+
// -----
30+
31+
// CHECK-LABEL: @dialect_attr_translation_multi
32+
llvm.func @dialect_attr_translation_multi(%a: i64, %b: i64, %c: i64) -> i64 {
33+
// CHECK: add {{.*}}, !annotation ![[MD_ID_ADD:.+]]
34+
// CHECK: mul {{.*}}, !annotation ![[MD_ID_MUL:.+]]
35+
// CHECK: ret {{.*}}, !annotation ![[MD_ID_RET:.+]]
36+
%ab = llvm.add %a, %b {test.add_annotation = "add"} : i64
37+
%r = llvm.mul %ab, %c {test.add_annotation = "mul"} : i64
38+
llvm.return {test.add_annotation = "ret"} %r : i64
39+
}
40+
// CHECK-DAG: ![[MD_ID_ADD]] = !{!"annotation_from_test: add"}
41+
// CHECK-DAG: ![[MD_ID_MUL]] = !{!"annotation_from_test: mul"}
42+
// CHECK-DAG: ![[MD_ID_RET]] = !{!"annotation_from_test: ret"}

mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ class TestDialectLLVMIRTranslationInterface
3232
using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
3333

3434
LogicalResult
35-
amendOperation(Operation *op, NamedAttribute attribute,
35+
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
36+
NamedAttribute attribute,
3637
LLVM::ModuleTranslation &moduleTranslation) const final;
3738

3839
LogicalResult
@@ -43,7 +44,8 @@ class TestDialectLLVMIRTranslationInterface
4344
} // namespace
4445

4546
LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
46-
Operation *op, NamedAttribute attribute,
47+
Operation *op, ArrayRef<llvm::Instruction *> instructions,
48+
NamedAttribute attribute,
4749
LLVM::ModuleTranslation &moduleTranslation) const {
4850
return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
4951
attribute.getName())
@@ -70,6 +72,18 @@ LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
7072
/*sym_visibility=*/nullptr);
7173
}
7274

75+
return success();
76+
})
77+
.Case("test.add_annotation",
78+
[&](Attribute attr) {
79+
for (llvm::Instruction *instruction : instructions) {
80+
if (auto strAttr = dyn_cast<StringAttr>(attr)) {
81+
instruction->addAnnotationMetadata("annotation_from_test: " +
82+
strAttr.getValue().str());
83+
} else {
84+
instruction->addAnnotationMetadata("annotation_from_test");
85+
}
86+
}
7387
return success();
7488
})
7589
.Default([](Attribute) {

0 commit comments

Comments
 (0)