@@ -57,44 +57,43 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
57
57
return true ;
58
58
}
59
59
60
- // / Return true if the given `insertionPoint` dominates all uses of
61
- // / `emptyTensorOp`.
62
- static bool insertionPointDominatesUses (const DominanceInfo &domInfo,
63
- Operation *insertionPoint,
64
- Operation *emptyTensorOp) {
65
- return llvm::all_of (emptyTensorOp->getUsers (), [&](Operation *user) {
66
- return domInfo.dominates (insertionPoint, user);
67
- });
68
- }
69
-
70
- // / Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
71
- // / that the replacement may use any value from `neededValues`.
60
+ // / Find a valid insertion point for a replacement of `useToBeEliminated`,
61
+ // / assuming that the replacement may use any value from `neededValues`.
72
62
static Operation *
73
- findValidInsertionPoint (Operation *emptyTensorOp ,
63
+ findValidInsertionPoint (OpOperand *useToBeEliminated ,
74
64
const SmallVector<Value> &neededValues) {
75
65
DominanceInfo domInfo;
76
66
67
+ Operation *candidateInsertionPoint = useToBeEliminated->getOwner ();
68
+ assert (isa<OpResult>(useToBeEliminated->get ()) && " expected a result value" );
69
+ // Both `tensor.empty` and its user are within different blocks.
70
+ if (useToBeEliminated->getOwner ()->getBlock () !=
71
+ useToBeEliminated->get ().getDefiningOp ()->getBlock ())
72
+ candidateInsertionPoint = useToBeEliminated->get ().getDefiningOp ();
73
+
77
74
// Trying to move the needed values before the `emptyTensorOp`.
78
75
for (Value val : neededValues) {
79
- if (valueDominateInsertionPoint (domInfo, emptyTensorOp , val))
76
+ if (valueDominateInsertionPoint (domInfo, candidateInsertionPoint , val))
80
77
continue ;
81
78
Operation *definingOp = val.getDefiningOp ();
82
79
if (!definingOp)
83
80
continue ;
84
81
85
82
bool isItSafeToMoveOp =
86
83
llvm::all_of (definingOp->getOperands (), [&](Value operand) {
87
- return valueDominateInsertionPoint (domInfo, emptyTensorOp, operand);
84
+ return valueDominateInsertionPoint (domInfo, candidateInsertionPoint,
85
+ operand);
88
86
});
89
87
90
88
if (isItSafeToMoveOp)
91
- definingOp->moveBefore (emptyTensorOp );
89
+ definingOp->moveBefore (candidateInsertionPoint );
92
90
}
93
91
94
- // Gather all possible insertion points: the location of `emptyTensorOp` and
95
- // right after the definition of each value in `neededValues`.
92
+ // Gather all possible insertion points: the location of
93
+ // `candidateInsertionPoint` and right after the definition of each value in
94
+ // `neededValues`.
96
95
SmallVector<Operation *> insertionPointCandidates;
97
- insertionPointCandidates.push_back (emptyTensorOp );
96
+ insertionPointCandidates.push_back (candidateInsertionPoint );
98
97
for (Value val : neededValues) {
99
98
// Note: The anchor op is using all of `neededValues`, so:
100
99
// * in case of a block argument: There must be at least one op in the block
@@ -116,8 +115,8 @@ findValidInsertionPoint(Operation *emptyTensorOp,
116
115
if (!neededValuesDominateInsertionPoint (domInfo, insertionPoint,
117
116
neededValues))
118
117
continue ;
119
- // Check if the insertion point is before all uses .
120
- if (!insertionPointDominatesUses ( domInfo, insertionPoint, emptyTensorOp ))
118
+ // Check if the insertion point is before the use to be replaced .
119
+ if (!domInfo. dominates ( insertionPoint, useToBeEliminated-> getOwner () ))
121
120
continue ;
122
121
return insertionPoint;
123
122
}
@@ -129,8 +128,9 @@ findValidInsertionPoint(Operation *emptyTensorOp,
129
128
LogicalResult mlir::bufferization::eliminateEmptyTensors (
130
129
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
131
130
OpBuilder::InsertionGuard g (rewriter);
132
-
131
+ llvm::DenseSet<OpOperand *> visitedOpOperands;
133
132
op->walk ([&](SubsetInsertionOpInterface op) {
133
+ visitedOpOperands.clear ();
134
134
OpOperand &source = op.getSourceOperand ();
135
135
// Skip operands that do not bufferize inplace. "tensor.empty" could still
136
136
// be replaced, but the transformation may not be beneficial.
@@ -157,16 +157,25 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
157
157
config.followSameTypeOrCastsOnly = true ;
158
158
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain (
159
159
source.get (), /* condition=*/
160
- [&](Value val) { return val.getDefiningOp <tensor::EmptyOp>(); },
161
- config );
160
+ [&](Value val) { return val.getDefiningOp <tensor::EmptyOp>(); }, config,
161
+ &visitedOpOperands );
162
162
163
163
for (Value v : emptyTensors) {
164
164
Operation *emptyTensorOp = v.getDefiningOp ();
165
165
166
+ // Find the use to be replaced from the use-def chain
167
+ auto iter = llvm::find_if (
168
+ visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
169
+ return llvm::count (emptyTensorOp->getUses (), *opOperand);
170
+ });
171
+ if (iter == visitedOpOperands.end ())
172
+ continue ;
173
+ OpOperand *useToBeReplaced = *iter;
174
+
166
175
// Find a suitable insertion point. If no suitable insertion point for
167
176
// the replacement can be found, skip this replacement.
168
177
Operation *insertionPoint =
169
- findValidInsertionPoint (emptyTensorOp , neededValues);
178
+ findValidInsertionPoint (useToBeReplaced , neededValues);
170
179
if (!insertionPoint)
171
180
continue ;
172
181
@@ -185,8 +194,9 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
185
194
replacement = rewriter.create <tensor::CastOp>(v.getLoc (), v.getType (),
186
195
replacement);
187
196
}
188
- // Replace the tensor::EmptyOp.
189
- rewriter.replaceOp (emptyTensorOp, replacement);
197
+ // Replace the specific use of the tensor::EmptyOp.
198
+ useToBeReplaced->getOwner ()->setOperand (
199
+ useToBeReplaced->getOperandNumber (), replacement);
190
200
state.resetCache ();
191
201
}
192
202
0 commit comments