@@ -537,6 +537,10 @@ class ConversionPattern : public RewritePattern {
537
537
ConversionPatternRewriter &rewriter) const {
538
538
llvm_unreachable (" unimplemented rewrite" );
539
539
}
540
+ virtual void rewrite (Operation *op, ArrayRef<ArrayRef<Value>> operands,
541
+ ConversionPatternRewriter &rewriter) const {
542
+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
543
+ }
540
544
541
545
// / Hook for derived classes to implement combined matching and rewriting.
542
546
virtual LogicalResult
@@ -547,6 +551,11 @@ class ConversionPattern : public RewritePattern {
547
551
rewrite (op, operands, rewriter);
548
552
return success ();
549
553
}
554
+ virtual LogicalResult
555
+ matchAndRewrite (Operation *op, ArrayRef<ArrayRef<Value>> operands,
556
+ ConversionPatternRewriter &rewriter) const {
557
+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
558
+ }
550
559
551
560
// / Attempt to match and rewrite the IR root at the specified operation.
552
561
LogicalResult matchAndRewrite (Operation *op,
@@ -574,6 +583,9 @@ class ConversionPattern : public RewritePattern {
574
583
: RewritePattern(std::forward<Args>(args)...),
575
584
typeConverter (&typeConverter) {}
576
585
586
+ static SmallVector<Value>
587
+ getOneToOneAdaptorOperands (ArrayRef<ArrayRef<Value>> operands);
588
+
577
589
protected:
578
590
// / An optional type converter for use by this pattern.
579
591
const TypeConverter *typeConverter = nullptr ;
@@ -589,6 +601,8 @@ template <typename SourceOp>
589
601
class OpConversionPattern : public ConversionPattern {
590
602
public:
591
603
using OpAdaptor = typename SourceOp::Adaptor;
604
+ using OneToNOpAdaptor =
605
+ typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>;
592
606
593
607
OpConversionPattern (MLIRContext *context, PatternBenefit benefit = 1 )
594
608
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +621,24 @@ class OpConversionPattern : public ConversionPattern {
607
621
auto sourceOp = cast<SourceOp>(op);
608
622
rewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
609
623
}
624
+ void rewrite (Operation *op, ArrayRef<ArrayRef<Value>> operands,
625
+ ConversionPatternRewriter &rewriter) const final {
626
+ auto sourceOp = cast<SourceOp>(op);
627
+ rewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp), rewriter);
628
+ }
610
629
LogicalResult
611
630
matchAndRewrite (Operation *op, ArrayRef<Value> operands,
612
631
ConversionPatternRewriter &rewriter) const final {
613
632
auto sourceOp = cast<SourceOp>(op);
614
633
return matchAndRewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
615
634
}
635
+ LogicalResult
636
+ matchAndRewrite (Operation *op, ArrayRef<ArrayRef<Value>> operands,
637
+ ConversionPatternRewriter &rewriter) const final {
638
+ auto sourceOp = cast<SourceOp>(op);
639
+ return matchAndRewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp),
640
+ rewriter);
641
+ }
616
642
617
643
// / Rewrite and Match methods that operate on the SourceOp type. These must be
618
644
// / overridden by the derived pattern class.
@@ -623,6 +649,12 @@ class OpConversionPattern : public ConversionPattern {
623
649
ConversionPatternRewriter &rewriter) const {
624
650
llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
625
651
}
652
+ virtual void rewrite (SourceOp op, OneToNOpAdaptor adaptor,
653
+ ConversionPatternRewriter &rewriter) const {
654
+ SmallVector<Value> oneToOneOperands =
655
+ getOneToOneAdaptorOperands (adaptor.getOperands ());
656
+ rewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
657
+ }
626
658
virtual LogicalResult
627
659
matchAndRewrite (SourceOp op, OpAdaptor adaptor,
628
660
ConversionPatternRewriter &rewriter) const {
@@ -631,6 +663,13 @@ class OpConversionPattern : public ConversionPattern {
631
663
rewrite (op, adaptor, rewriter);
632
664
return success ();
633
665
}
666
+ virtual LogicalResult
667
+ matchAndRewrite (SourceOp op, OneToNOpAdaptor adaptor,
668
+ ConversionPatternRewriter &rewriter) const {
669
+ SmallVector<Value> oneToOneOperands =
670
+ getOneToOneAdaptorOperands (adaptor.getOperands ());
671
+ return matchAndRewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
672
+ }
634
673
635
674
private:
636
675
using ConversionPattern::matchAndRewrite;
@@ -656,18 +695,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
656
695
ConversionPatternRewriter &rewriter) const final {
657
696
rewrite (cast<SourceOp>(op), operands, rewriter);
658
697
}
698
+ void rewrite (Operation *op, ArrayRef<ArrayRef<Value>> operands,
699
+ ConversionPatternRewriter &rewriter) const final {
700
+ rewrite (cast<SourceOp>(op), operands, rewriter);
701
+ }
659
702
LogicalResult
660
703
matchAndRewrite (Operation *op, ArrayRef<Value> operands,
661
704
ConversionPatternRewriter &rewriter) const final {
662
705
return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
663
706
}
707
+ LogicalResult
708
+ matchAndRewrite (Operation *op, ArrayRef<ArrayRef<Value>> operands,
709
+ ConversionPatternRewriter &rewriter) const final {
710
+ return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
711
+ }
664
712
665
713
// / Rewrite and Match methods that operate on the SourceOp type. These must be
666
714
// / overridden by the derived pattern class.
667
715
virtual void rewrite (SourceOp op, ArrayRef<Value> operands,
668
716
ConversionPatternRewriter &rewriter) const {
669
717
llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
670
718
}
719
+ virtual void rewrite (SourceOp op, ArrayRef<ArrayRef<Value>> operands,
720
+ ConversionPatternRewriter &rewriter) const {
721
+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
722
+ }
671
723
virtual LogicalResult
672
724
matchAndRewrite (SourceOp op, ArrayRef<Value> operands,
673
725
ConversionPatternRewriter &rewriter) const {
@@ -676,6 +728,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
676
728
rewrite (op, operands, rewriter);
677
729
return success ();
678
730
}
731
+ virtual LogicalResult
732
+ matchAndRewrite (SourceOp op, ArrayRef<ArrayRef<Value>> operands,
733
+ ConversionPatternRewriter &rewriter) const {
734
+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
735
+ }
679
736
680
737
private:
681
738
using ConversionPattern::matchAndRewrite;
0 commit comments