Skip to content

[mlir][Transforms] Dialect conversion: Unify materialization of value replacements #108381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 41 additions & 92 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2338,17 +2338,6 @@ struct OperationConverter {
/// remaining artifacts and complete the conversion.
LogicalResult finalize(ConversionPatternRewriter &rewriter);

/// Legalize the types of converted block arguments.
LogicalResult
legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl);

/// Legalize the types of converted op results.
LogicalResult legalizeConvertedOpResultTypes(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping);

/// Dialect conversion configuration.
ConversionConfig config;

Expand Down Expand Up @@ -2512,19 +2501,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
return success();
}

LogicalResult
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
return failure();
DenseMap<Value, SmallVector<Value>> inverseMapping =
rewriterImpl.mapping.getInverse();
if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
inverseMapping)))
return failure();
return success();
}

/// Finds a user of the given value, or of any other value that the given value
/// replaced, that was not replaced in the conversion process.
static Operation *findLiveUserOfReplaced(
Expand All @@ -2548,87 +2524,60 @@ static Operation *findLiveUserOfReplaced(
return nullptr;
}

LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping) {
// Process requested operation replacements.
for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
auto *opReplacement =
dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
if (!opReplacement)
continue;
Operation *op = opReplacement->getOperation();
for (OpResult result : op->getResults()) {
// If the type of this op result changed and the result is still live,
// we need to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
/// Helper function that returns the replaced values and the type converter if
/// the given rewrite object is an "operation replacement" or a "block type
/// conversion" (which corresponds to a "block replacement"). Otherwise, return
/// an empty ValueRange and a null type converter pointer.
static std::pair<ValueRange, const TypeConverter *>
getReplacedValues(IRRewrite *rewrite) {
if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
return {blockRewrite->getOrigBlock()->getArguments(),
blockRewrite->getConverter()};
return {};
}

LogicalResult
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
DenseMap<Value, SmallVector<Value>> inverseMapping =
rewriterImpl.mapping.getInverse();

// Process requested value replacements.
for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
ValueRange replacedValues;
const TypeConverter *converter;
std::tie(replacedValues, converter) =
getReplacedValues(rewriterImpl.rewrites[i].get());
for (Value originalValue : replacedValues) {
// If the type of this value changed and the value is still live, we need
// to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(originalValue,
originalValue.getType()))
continue;
Operation *liveUser =
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
if (!liveUser)
continue;

// Legalize this result.
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
// Legalize this value replacement.
Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
assert(newValue && "replacement value not found");
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
/*inputs=*/newValue, /*outputType=*/result.getType(),
opReplacement->getConverter());
rewriterImpl.mapping.map(result, castValue);
inverseMapping[castValue].push_back(result);
llvm::erase(inverseMapping[newValue], result);
MaterializationKind::Source, computeInsertPoint(newValue),
originalValue.getLoc(),
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
converter);
rewriterImpl.mapping.map(originalValue, castValue);
inverseMapping[castValue].push_back(originalValue);
llvm::erase(inverseMapping[newValue], originalValue);
}
}

return success();
}

LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl) {
// Functor used to check if all users of a value will be dead after
// conversion.
// TODO: This should probably query the inverse mapping, same as in
// `legalizeConvertedOpResultTypes`.
auto findLiveUser = [&](Value val) {
auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
return rewriterImpl.isOpIgnored(user);
});
return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
};
// Note: `rewrites` may be reallocated as the loop is running.
for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
++i) {
auto &rewrite = rewriterImpl.rewrites[i];
if (auto *blockTypeConversionRewrite =
dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
// Process the remapping for each of the original arguments.
for (Value origArg :
blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
// If the type of this argument changed and the argument is still live,
// we need to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
continue;
Operation *liveUser = findLiveUser(origArg);
if (!liveUser)
continue;

Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
assert(replacementValue && "replacement value not found");
Value repl = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(replacementValue),
origArg.getLoc(), /*inputs=*/replacementValue,
/*outputType=*/origArg.getType(),
blockTypeConversionRewrite->getConverter());
rewriterImpl.mapping.map(origArg, repl);
}
}
}
return success();
}

//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//
Expand Down
Loading