@@ -48,27 +48,20 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
48
48
return true ;
49
49
}
50
50
51
- // / Return true if the given `insertionPoint` dominates all uses of
52
- // / `emptyTensorOp`.
53
- static bool insertionPointDominatesUses (const DominanceInfo &domInfo,
54
- Operation *insertionPoint,
55
- Operation *emptyTensorOp) {
56
- return llvm::all_of (emptyTensorOp->getUsers (), [&](Operation *user) {
57
- return domInfo.dominates (insertionPoint, user);
58
- });
59
- }
60
-
61
- // / Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
62
- // / that the replacement may use any value from `neededValues`.
51
+ // / Find a valid insertion point for a replacement of `useToBeEliminated`,
52
+ // / assuming that the replacement may use any value from `neededValues`.
63
53
static Operation *
64
- findValidInsertionPoint (Operation *emptyTensorOp ,
54
+ findValidInsertionPoint (OpOperand *useToBeEliminated ,
65
55
const SmallVector<Value> &neededValues) {
66
56
DominanceInfo domInfo;
57
+ assert (isa<OpResult>(useToBeEliminated->get ()) && " expected a result value" );
58
+ Operation *candidateInsertionPoint = useToBeEliminated->get ().getDefiningOp ();
67
59
68
- // Gather all possible insertion points: the location of `emptyTensorOp` and
69
- // right after the definition of each value in `neededValues`.
60
+ // Gather all possible insertion points: the location of
61
+ // `candidateInsertionPoint` and right after the definition of each value in
62
+ // `neededValues`.
70
63
SmallVector<Operation *> insertionPointCandidates;
71
- insertionPointCandidates.push_back (emptyTensorOp );
64
+ insertionPointCandidates.push_back (candidateInsertionPoint );
72
65
for (Value val : neededValues) {
73
66
// Note: The anchor op is using all of `neededValues`, so:
74
67
// * in case of a block argument: There must be at least one op in the block
@@ -90,8 +83,8 @@ findValidInsertionPoint(Operation *emptyTensorOp,
90
83
if (!neededValuesDominateInsertionPoint (domInfo, insertionPoint,
91
84
neededValues))
92
85
continue ;
93
- // Check if the insertion point is before all uses .
94
- if (!insertionPointDominatesUses ( domInfo, insertionPoint, emptyTensorOp ))
86
+ // Check if the insertion point is before the use to be replaced .
87
+ if (!domInfo. dominates ( insertionPoint, useToBeEliminated-> getOwner () ))
95
88
continue ;
96
89
return insertionPoint;
97
90
}
@@ -103,8 +96,9 @@ findValidInsertionPoint(Operation *emptyTensorOp,
103
96
LogicalResult mlir::bufferization::eliminateEmptyTensors (
104
97
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
105
98
OpBuilder::InsertionGuard g (rewriter);
106
-
99
+ llvm::DenseSet<OpOperand *> visitedOpOperands;
107
100
op->walk ([&](SubsetInsertionOpInterface op) {
101
+ visitedOpOperands.clear ();
108
102
OpOperand &source = op.getSourceOperand ();
109
103
// Skip operands that do not bufferize inplace. "tensor.empty" could still
110
104
// be replaced, but the transformation may not be beneficial.
@@ -131,16 +125,25 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
131
125
config.followSameTypeOrCastsOnly = true ;
132
126
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain (
133
127
source.get (), /* condition=*/
134
- [&](Value val) { return val.getDefiningOp <tensor::EmptyOp>(); },
135
- config );
128
+ [&](Value val) { return val.getDefiningOp <tensor::EmptyOp>(); }, config,
129
+ &visitedOpOperands );
136
130
137
131
for (Value v : emptyTensors) {
138
132
Operation *emptyTensorOp = v.getDefiningOp ();
139
133
134
+ // Find the use to be replaced from the use-def chain
135
+ auto iter = llvm::find_if (
136
+ visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
137
+ return llvm::count (emptyTensorOp->getUses (), *opOperand);
138
+ });
139
+ if (iter == visitedOpOperands.end ())
140
+ continue ;
141
+ OpOperand *useToBeReplaced = *iter;
142
+
140
143
// Find a suitable insertion point. If no suitable insertion point for
141
144
// the replacement can be found, skip this replacement.
142
145
Operation *insertionPoint =
143
- findValidInsertionPoint (emptyTensorOp , neededValues);
146
+ findValidInsertionPoint (useToBeReplaced , neededValues);
144
147
if (!insertionPoint)
145
148
continue ;
146
149
@@ -159,8 +162,9 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
159
162
replacement = rewriter.create <tensor::CastOp>(v.getLoc (), v.getType (),
160
163
replacement);
161
164
}
162
- // Replace the tensor::EmptyOp.
163
- rewriter.replaceOp (emptyTensorOp, replacement);
165
+ // Replace the specific use of the tensor::EmptyOp.
166
+ useToBeReplaced->getOwner ()->setOperand (
167
+ useToBeReplaced->getOperandNumber (), replacement);
164
168
state.resetCache ();
165
169
}
166
170
0 commit comments