Skip to content

Commit 16388fd

Browse files
[WIP] 1:N conversion pattern
1 parent 9f24c14 commit 16388fd

File tree

5 files changed

+133
-15
lines changed

5 files changed

+133
-15
lines changed

mlir/artifacts/jq-linux64

3.77 MB
Binary file not shown.

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ template <typename SourceOp>
143143
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
144144
public:
145145
using OpAdaptor = typename SourceOp::Adaptor;
146+
using OneToNOpAdaptor =
147+
typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>;
146148

147149
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
148150
PatternBenefit benefit = 1)
@@ -153,17 +155,29 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
153155
/// Wrappers around the RewritePattern methods that pass the derived op type.
154156
void rewrite(Operation *op, ArrayRef<Value> operands,
155157
ConversionPatternRewriter &rewriter) const final {
156-
rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
157-
rewriter);
158+
auto sourceOp = cast<SourceOp>(op);
159+
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
160+
}
161+
void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
162+
ConversionPatternRewriter &rewriter) const final {
163+
auto sourceOp = cast<SourceOp>(op);
164+
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
158165
}
159166
LogicalResult match(Operation *op) const final {
160167
return match(cast<SourceOp>(op));
161168
}
162169
LogicalResult
163170
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
164171
ConversionPatternRewriter &rewriter) const final {
165-
return matchAndRewrite(cast<SourceOp>(op),
166-
OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
172+
auto sourceOp = cast<SourceOp>(op);
173+
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
174+
}
175+
LogicalResult
176+
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
177+
ConversionPatternRewriter &rewriter) const final {
178+
auto sourceOp = cast<SourceOp>(op);
179+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
180+
rewriter);
167181
}
168182

169183
/// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
175189
ConversionPatternRewriter &rewriter) const {
176190
llvm_unreachable("must override rewrite or matchAndRewrite");
177191
}
192+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
193+
ConversionPatternRewriter &rewriter) const {
194+
SmallVector<Value> oneToOneOperands =
195+
getOneToOneAdaptorOperands(adaptor.getOperands());
196+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
197+
}
178198
virtual LogicalResult
179199
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
180200
ConversionPatternRewriter &rewriter) const {
@@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
183203
rewrite(op, adaptor, rewriter);
184204
return success();
185205
}
206+
virtual LogicalResult
207+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
208+
ConversionPatternRewriter &rewriter) const {
209+
SmallVector<Value> oneToOneOperands =
210+
getOneToOneAdaptorOperands(adaptor.getOperands());
211+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
212+
}
186213

187214
private:
188215
using ConvertToLLVMPattern::match;

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,10 @@ class ConversionPattern : public RewritePattern {
467467
ConversionPatternRewriter &rewriter) const {
468468
llvm_unreachable("unimplemented rewrite");
469469
}
470+
virtual void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
471+
ConversionPatternRewriter &rewriter) const {
472+
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
473+
}
470474

471475
/// Hook for derived classes to implement combined matching and rewriting.
472476
virtual LogicalResult
@@ -477,6 +481,11 @@ class ConversionPattern : public RewritePattern {
477481
rewrite(op, operands, rewriter);
478482
return success();
479483
}
484+
virtual LogicalResult
485+
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
486+
ConversionPatternRewriter &rewriter) const {
487+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
488+
}
480489

481490
/// Attempt to match and rewrite the IR root at the specified operation.
482491
LogicalResult matchAndRewrite(Operation *op,
@@ -504,6 +513,9 @@ class ConversionPattern : public RewritePattern {
504513
: RewritePattern(std::forward<Args>(args)...),
505514
typeConverter(&typeConverter) {}
506515

516+
static SmallVector<Value>
517+
getOneToOneAdaptorOperands(ArrayRef<ArrayRef<Value>> operands);
518+
507519
protected:
508520
/// An optional type converter for use by this pattern.
509521
const TypeConverter *typeConverter = nullptr;
@@ -519,6 +531,8 @@ template <typename SourceOp>
519531
class OpConversionPattern : public ConversionPattern {
520532
public:
521533
using OpAdaptor = typename SourceOp::Adaptor;
534+
using OneToNOpAdaptor =
535+
typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>;
522536

523537
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
524538
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -537,12 +551,24 @@ class OpConversionPattern : public ConversionPattern {
537551
auto sourceOp = cast<SourceOp>(op);
538552
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
539553
}
554+
void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
555+
ConversionPatternRewriter &rewriter) const final {
556+
auto sourceOp = cast<SourceOp>(op);
557+
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
558+
}
540559
LogicalResult
541560
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
542561
ConversionPatternRewriter &rewriter) const final {
543562
auto sourceOp = cast<SourceOp>(op);
544563
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
545564
}
565+
LogicalResult
566+
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
567+
ConversionPatternRewriter &rewriter) const final {
568+
auto sourceOp = cast<SourceOp>(op);
569+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
570+
rewriter);
571+
}
546572

547573
/// Rewrite and Match methods that operate on the SourceOp type. These must be
548574
/// overridden by the derived pattern class.
@@ -553,6 +579,12 @@ class OpConversionPattern : public ConversionPattern {
553579
ConversionPatternRewriter &rewriter) const {
554580
llvm_unreachable("must override matchAndRewrite or a rewrite method");
555581
}
582+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
583+
ConversionPatternRewriter &rewriter) const {
584+
SmallVector<Value> oneToOneOperands =
585+
getOneToOneAdaptorOperands(adaptor.getOperands());
586+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
587+
}
556588
virtual LogicalResult
557589
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
558590
ConversionPatternRewriter &rewriter) const {
@@ -561,6 +593,13 @@ class OpConversionPattern : public ConversionPattern {
561593
rewrite(op, adaptor, rewriter);
562594
return success();
563595
}
596+
virtual LogicalResult
597+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
598+
ConversionPatternRewriter &rewriter) const {
599+
SmallVector<Value> oneToOneOperands =
600+
getOneToOneAdaptorOperands(adaptor.getOperands());
601+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
602+
}
564603

565604
private:
566605
using ConversionPattern::matchAndRewrite;
@@ -586,18 +625,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
586625
ConversionPatternRewriter &rewriter) const final {
587626
rewrite(cast<SourceOp>(op), operands, rewriter);
588627
}
628+
void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
629+
ConversionPatternRewriter &rewriter) const final {
630+
rewrite(cast<SourceOp>(op), operands, rewriter);
631+
}
589632
LogicalResult
590633
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
591634
ConversionPatternRewriter &rewriter) const final {
592635
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
593636
}
637+
LogicalResult
638+
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
639+
ConversionPatternRewriter &rewriter) const final {
640+
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
641+
}
594642

595643
/// Rewrite and Match methods that operate on the SourceOp type. These must be
596644
/// overridden by the derived pattern class.
597645
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
598646
ConversionPatternRewriter &rewriter) const {
599647
llvm_unreachable("must override matchAndRewrite or a rewrite method");
600648
}
649+
virtual void rewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands,
650+
ConversionPatternRewriter &rewriter) const {
651+
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
652+
}
601653
virtual LogicalResult
602654
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
603655
ConversionPatternRewriter &rewriter) const {
@@ -606,6 +658,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
606658
rewrite(op, operands, rewriter);
607659
return success();
608660
}
661+
virtual LogicalResult
662+
matchAndRewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands,
663+
ConversionPatternRewriter &rewriter) const {
664+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
665+
}
609666

610667
private:
611668
using ConversionPattern::matchAndRewrite;

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
769769
LogicalResult remapValues(StringRef valueDiagTag,
770770
std::optional<Location> inputLoc,
771771
PatternRewriter &rewriter, ValueRange values,
772-
SmallVectorImpl<Value> &remapped);
772+
SmallVector<SmallVector<Value, 1>> &remapped);
773773

774774
/// Return "true" if the given operation is ignored, and does not need to be
775775
/// converted.
@@ -1089,7 +1089,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
10891089
LogicalResult ConversionPatternRewriterImpl::remapValues(
10901090
StringRef valueDiagTag, std::optional<Location> inputLoc,
10911091
PatternRewriter &rewriter, ValueRange values,
1092-
SmallVectorImpl<Value> &remapped) {
1092+
SmallVector<SmallVector<Value, 1>> &remapped) {
10931093
remapped.reserve(llvm::size(values));
10941094

10951095
for (const auto &it : llvm::enumerate(values)) {
@@ -1101,7 +1101,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11011101
// The current pattern does not have a type converter. I.e., it does not
11021102
// distinguish between legal and illegal types. For each operand, simply
11031103
// pass through the most recently mapped value.
1104-
remapped.push_back(mapping.lookupOrDefault(operand));
1104+
remapped.push_back({mapping.lookupOrDefault(operand)});
11051105
continue;
11061106
}
11071107

@@ -1123,7 +1123,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11231123
// improvements to the `ConversionValueMapping` (to be able to store 1:N
11241124
// mappings) and to the `ConversionPattern` adaptor handling (to be able
11251125
// to pass multiple remapped values for a single operand to the adaptor).
1126-
remapped.push_back(mapping.lookupOrDefault(operand));
1126+
remapped.push_back({mapping.lookupOrDefault(operand)});
11271127
continue;
11281128
}
11291129

@@ -1143,7 +1143,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11431143
mapping.map(newOperand, castValue);
11441144
newOperand = castValue;
11451145
}
1146-
remapped.push_back(newOperand);
1146+
remapped.push_back({newOperand});
11471147
}
11481148
return success();
11491149
}
@@ -1523,20 +1523,28 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
15231523
}
15241524

15251525
Value ConversionPatternRewriter::getRemappedValue(Value key) {
1526-
SmallVector<Value> remappedValues;
1526+
SmallVector<SmallVector<Value, 1>> remappedValues;
15271527
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
15281528
remappedValues)))
15291529
return nullptr;
1530-
return remappedValues.front();
1530+
assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
1531+
return remappedValues.front().front();
15311532
}
15321533

15331534
LogicalResult
15341535
ConversionPatternRewriter::getRemappedValues(ValueRange keys,
15351536
SmallVectorImpl<Value> &results) {
15361537
if (keys.empty())
15371538
return success();
1538-
return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1539-
results);
1539+
SmallVector<SmallVector<Value, 1>> remapped;
1540+
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1541+
remapped)))
1542+
return failure();
1543+
for (const auto &values : remapped) {
1544+
assert(values.size() == 1 && "1:N conversion not supported");
1545+
results.push_back(values.front());
1546+
}
1547+
return success();
15401548
}
15411549

15421550
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
@@ -1630,6 +1638,16 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
16301638
// ConversionPattern
16311639
//===----------------------------------------------------------------------===//
16321640

1641+
SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
1642+
ArrayRef<ArrayRef<Value>> operands) {
1643+
SmallVector<Value> oneToOneOperands;
1644+
oneToOneOperands.reserve(operands.size());
1645+
for (ArrayRef<Value> operand : operands) {
1646+
assert(operand.size() == 1 && "pattern does not support 1:N conversion");
1647+
oneToOneOperands.push_back(operand.front());
1648+
}
1649+
}
1650+
16331651
LogicalResult
16341652
ConversionPattern::matchAndRewrite(Operation *op,
16351653
PatternRewriter &rewriter) const {
@@ -1641,11 +1659,16 @@ ConversionPattern::matchAndRewrite(Operation *op,
16411659
getTypeConverter());
16421660

16431661
// Remap the operands of the operation.
1644-
SmallVector<Value, 4> operands;
1662+
SmallVector<SmallVector<Value, 1>> remapped;
16451663
if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1646-
op->getOperands(), operands))) {
1664+
op->getOperands(), remapped))) {
16471665
return failure();
16481666
}
1667+
SmallVector<Value, 4> operands;
1668+
for (const auto &values : remapped) {
1669+
assert(values.size() == 1 && "1:N conversion not supported");
1670+
operands.push_back(values.front());
1671+
}
16491672
return matchAndRewrite(op, operands, dialectRewriter);
16501673
}
16511674

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4282,6 +4282,17 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
42824282
}
42834283
}
42844284

4285+
{
4286+
SmallVector<MethodParameter> paramList;
4287+
paramList.emplace_back("RangeT", "values");
4288+
paramList.emplace_back("const " + op.getGenericAdaptorName() + "Base &",
4289+
"base");
4290+
auto *constructor =
4291+
genericAdaptor.addConstructor<Method::Inline>(paramList);
4292+
constructor->addMemberInitializer("Base", "base");
4293+
constructor->addMemberInitializer("odsOperands", "values");
4294+
}
4295+
42854296
// Create constructors constructing the adaptor from an instance of the op.
42864297
// This takes the attributes, properties and regions from the op instance
42874298
// and the value range from the parameter.

0 commit comments

Comments
 (0)