Skip to content

Commit 6cd7979

Browse files
[WIP] 1:N conversion pattern
1 parent 5621929 commit 6cd7979

File tree

4 files changed

+122
-15
lines changed

4 files changed

+122
-15
lines changed

mlir/artifacts/jq-linux64

3.77 MB
Binary file not shown.

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ template <typename SourceOp>
143143
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
144144
public:
145145
using OpAdaptor = typename SourceOp::Adaptor;
146+
using OneToNOpAdaptor =
147+
typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>;
146148

147149
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
148150
PatternBenefit benefit = 1)
@@ -153,17 +155,29 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
153155
/// Wrappers around the RewritePattern methods that pass the derived op type.
154156
void rewrite(Operation *op, ArrayRef<Value> operands,
155157
ConversionPatternRewriter &rewriter) const final {
156-
rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
157-
rewriter);
158+
auto sourceOp = cast<SourceOp>(op);
159+
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
160+
}
161+
void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
162+
ConversionPatternRewriter &rewriter) const final {
163+
auto sourceOp = cast<SourceOp>(op);
164+
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
158165
}
159166
LogicalResult match(Operation *op) const final {
160167
return match(cast<SourceOp>(op));
161168
}
162169
LogicalResult
163170
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
164171
ConversionPatternRewriter &rewriter) const final {
165-
return matchAndRewrite(cast<SourceOp>(op),
166-
OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
172+
auto sourceOp = cast<SourceOp>(op);
173+
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
174+
}
175+
LogicalResult
176+
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
177+
ConversionPatternRewriter &rewriter) const final {
178+
auto sourceOp = cast<SourceOp>(op);
179+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
180+
rewriter);
167181
}
168182

169183
/// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
175189
ConversionPatternRewriter &rewriter) const {
176190
llvm_unreachable("must override rewrite or matchAndRewrite");
177191
}
192+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
193+
ConversionPatternRewriter &rewriter) const {
194+
SmallVector<Value> oneToOneOperands =
195+
getOneToOneAdaptorOperands(adaptor.getOperands());
196+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
197+
}
178198
virtual LogicalResult
179199
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
180200
ConversionPatternRewriter &rewriter) const {
@@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
183203
rewrite(op, adaptor, rewriter);
184204
return success();
185205
}
206+
virtual LogicalResult
207+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
208+
ConversionPatternRewriter &rewriter) const {
209+
SmallVector<Value> oneToOneOperands =
210+
getOneToOneAdaptorOperands(adaptor.getOperands());
211+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
212+
}
186213

187214
private:
188215
using ConvertToLLVMPattern::match;

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,10 @@ class ConversionPattern : public RewritePattern {
537537
ConversionPatternRewriter &rewriter) const {
538538
llvm_unreachable("unimplemented rewrite");
539539
}
540+
virtual void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
541+
ConversionPatternRewriter &rewriter) const {
542+
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
543+
}
540544

541545
/// Hook for derived classes to implement combined matching and rewriting.
542546
virtual LogicalResult
@@ -547,6 +551,11 @@ class ConversionPattern : public RewritePattern {
547551
rewrite(op, operands, rewriter);
548552
return success();
549553
}
554+
virtual LogicalResult
555+
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
556+
ConversionPatternRewriter &rewriter) const {
557+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
558+
}
550559

551560
/// Attempt to match and rewrite the IR root at the specified operation.
552561
LogicalResult matchAndRewrite(Operation *op,
@@ -574,6 +583,9 @@ class ConversionPattern : public RewritePattern {
574583
: RewritePattern(std::forward<Args>(args)...),
575584
typeConverter(&typeConverter) {}
576585

586+
static SmallVector<Value>
587+
getOneToOneAdaptorOperands(ArrayRef<ArrayRef<Value>> operands);
588+
577589
protected:
578590
/// An optional type converter for use by this pattern.
579591
const TypeConverter *typeConverter = nullptr;
@@ -589,6 +601,8 @@ template <typename SourceOp>
589601
class OpConversionPattern : public ConversionPattern {
590602
public:
591603
using OpAdaptor = typename SourceOp::Adaptor;
604+
using OneToNOpAdaptor =
605+
typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>;
592606

593607
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
594608
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +621,24 @@ class OpConversionPattern : public ConversionPattern {
607621
auto sourceOp = cast<SourceOp>(op);
608622
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
609623
}
624+
void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
625+
ConversionPatternRewriter &rewriter) const final {
626+
auto sourceOp = cast<SourceOp>(op);
627+
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
628+
}
610629
LogicalResult
611630
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
612631
ConversionPatternRewriter &rewriter) const final {
613632
auto sourceOp = cast<SourceOp>(op);
614633
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
615634
}
635+
LogicalResult
636+
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
637+
ConversionPatternRewriter &rewriter) const final {
638+
auto sourceOp = cast<SourceOp>(op);
639+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
640+
rewriter);
641+
}
616642

617643
/// Rewrite and Match methods that operate on the SourceOp type. These must be
618644
/// overridden by the derived pattern class.
@@ -623,6 +649,12 @@ class OpConversionPattern : public ConversionPattern {
623649
ConversionPatternRewriter &rewriter) const {
624650
llvm_unreachable("must override matchAndRewrite or a rewrite method");
625651
}
652+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
653+
ConversionPatternRewriter &rewriter) const {
654+
SmallVector<Value> oneToOneOperands =
655+
getOneToOneAdaptorOperands(adaptor.getOperands());
656+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
657+
}
626658
virtual LogicalResult
627659
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
628660
ConversionPatternRewriter &rewriter) const {
@@ -631,6 +663,13 @@ class OpConversionPattern : public ConversionPattern {
631663
rewrite(op, adaptor, rewriter);
632664
return success();
633665
}
666+
virtual LogicalResult
667+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
668+
ConversionPatternRewriter &rewriter) const {
669+
SmallVector<Value> oneToOneOperands =
670+
getOneToOneAdaptorOperands(adaptor.getOperands());
671+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
672+
}
634673

635674
private:
636675
using ConversionPattern::matchAndRewrite;
@@ -656,18 +695,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
656695
ConversionPatternRewriter &rewriter) const final {
657696
rewrite(cast<SourceOp>(op), operands, rewriter);
658697
}
698+
void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
699+
ConversionPatternRewriter &rewriter) const final {
700+
rewrite(cast<SourceOp>(op), operands, rewriter);
701+
}
659702
LogicalResult
660703
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
661704
ConversionPatternRewriter &rewriter) const final {
662705
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
663706
}
707+
LogicalResult
708+
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
709+
ConversionPatternRewriter &rewriter) const final {
710+
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
711+
}
664712

665713
/// Rewrite and Match methods that operate on the SourceOp type. These must be
666714
/// overridden by the derived pattern class.
667715
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
668716
ConversionPatternRewriter &rewriter) const {
669717
llvm_unreachable("must override matchAndRewrite or a rewrite method");
670718
}
719+
virtual void rewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands,
720+
ConversionPatternRewriter &rewriter) const {
721+
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
722+
}
671723
virtual LogicalResult
672724
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
673725
ConversionPatternRewriter &rewriter) const {
@@ -676,6 +728,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
676728
rewrite(op, operands, rewriter);
677729
return success();
678730
}
731+
virtual LogicalResult
732+
matchAndRewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands,
733+
ConversionPatternRewriter &rewriter) const {
734+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
735+
}
679736

680737
private:
681738
using ConversionPattern::matchAndRewrite;

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
776776
LogicalResult remapValues(StringRef valueDiagTag,
777777
std::optional<Location> inputLoc,
778778
PatternRewriter &rewriter, ValueRange values,
779-
SmallVectorImpl<Value> &remapped);
779+
SmallVector<SmallVector<Value, 1>> &remapped);
780780

781781
/// Return "true" if the given operation is ignored, and does not need to be
782782
/// converted.
@@ -1099,7 +1099,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
10991099
LogicalResult ConversionPatternRewriterImpl::remapValues(
11001100
StringRef valueDiagTag, std::optional<Location> inputLoc,
11011101
PatternRewriter &rewriter, ValueRange values,
1102-
SmallVectorImpl<Value> &remapped) {
1102+
SmallVector<SmallVector<Value, 1>> &remapped) {
11031103
remapped.reserve(llvm::size(values));
11041104

11051105
for (const auto &it : llvm::enumerate(values)) {
@@ -1111,7 +1111,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11111111
// The current pattern does not have a type converter. I.e., it does not
11121112
// distinguish between legal and illegal types. For each operand, simply
11131113
// pass through the most recently mapped value.
1114-
remapped.push_back(mapping.lookupOrDefault(operand));
1114+
remapped.push_back({mapping.lookupOrDefault(operand)});
11151115
continue;
11161116
}
11171117

@@ -1133,7 +1133,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11331133
// improvements to the `ConversionValueMapping` (to be able to store 1:N
11341134
// mappings) and to the `ConversionPattern` adaptor handling (to be able
11351135
// to pass multiple remapped values for a single operand to the adaptor).
1136-
remapped.push_back(mapping.lookupOrDefault(operand));
1136+
remapped.push_back({mapping.lookupOrDefault(operand)});
11371137
continue;
11381138
}
11391139

@@ -1153,7 +1153,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11531153
mapping.map(newOperand, castValue);
11541154
newOperand = castValue;
11551155
}
1156-
remapped.push_back(newOperand);
1156+
remapped.push_back({newOperand});
11571157
}
11581158
return success();
11591159
}
@@ -1541,20 +1541,28 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
15411541
}
15421542

15431543
Value ConversionPatternRewriter::getRemappedValue(Value key) {
1544-
SmallVector<Value> remappedValues;
1544+
SmallVector<SmallVector<Value, 1>> remappedValues;
15451545
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
15461546
remappedValues)))
15471547
return nullptr;
1548-
return remappedValues.front();
1548+
assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
1549+
return remappedValues.front().front();
15491550
}
15501551

15511552
LogicalResult
15521553
ConversionPatternRewriter::getRemappedValues(ValueRange keys,
15531554
SmallVectorImpl<Value> &results) {
15541555
if (keys.empty())
15551556
return success();
1556-
return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1557-
results);
1557+
SmallVector<SmallVector<Value, 1>> remapped;
1558+
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1559+
remapped)))
1560+
return failure();
1561+
for (const auto &values : remapped) {
1562+
assert(values.size() == 1 && "1:N conversion not supported");
1563+
results.push_back(values.front());
1564+
}
1565+
return success();
15581566
}
15591567

15601568
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
@@ -1648,6 +1656,16 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
16481656
// ConversionPattern
16491657
//===----------------------------------------------------------------------===//
16501658

1659+
SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
1660+
ArrayRef<ArrayRef<Value>> operands) {
1661+
SmallVector<Value> oneToOneOperands;
1662+
oneToOneOperands.reserve(operands.size());
1663+
for (ArrayRef<Value> operand : operands) {
1664+
assert(operand.size() == 1 && "pattern does not support 1:N conversion");
1665+
oneToOneOperands.push_back(operand.front());
1666+
}
1667+
}
1668+
16511669
LogicalResult
16521670
ConversionPattern::matchAndRewrite(Operation *op,
16531671
PatternRewriter &rewriter) const {
@@ -1659,11 +1677,16 @@ ConversionPattern::matchAndRewrite(Operation *op,
16591677
getTypeConverter());
16601678

16611679
// Remap the operands of the operation.
1662-
SmallVector<Value, 4> operands;
1680+
SmallVector<SmallVector<Value, 1>> remapped;
16631681
if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1664-
op->getOperands(), operands))) {
1682+
op->getOperands(), remapped))) {
16651683
return failure();
16661684
}
1685+
SmallVector<Value, 4> operands;
1686+
for (const auto &values : remapped) {
1687+
assert(values.size() == 1 && "1:N conversion not supported");
1688+
operands.push_back(values.front());
1689+
}
16671690
return matchAndRewrite(op, operands, dialectRewriter);
16681691
}
16691692

0 commit comments

Comments
 (0)