@@ -144,27 +144,17 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
144
144
b, loc, producer, getTiledOperands (producer), ivs, tileSizes, sizeBounds,
145
145
/* *omitPartialTileCheck=*/ false ));
146
146
147
- // Iterate over the results in order.
148
- // Extract the subtensor type from the linearized range.
149
- // Since we do not enforce any canonicalizations on the fly, this is always
150
- // fully dynamic at construction time.
147
+ // Take result types from the tiled init operands.
148
+ MutableOperandRange producerDpsInits = producer.getDpsInitsMutable ();
151
149
SmallVector<Type, 4 > resultTypes;
152
150
resultTypes.reserve (producer->getNumResults ());
153
- for (Value operand : producer.getDpsInits ()) {
154
- auto tensorType = dyn_cast<RankedTensorType>(operand.getType ());
155
- if (!tensorType)
156
- continue ;
157
- unsigned rank = tensorType.getRank ();
158
- SmallVector<int64_t , 4 > staticOffsetsVector (
159
- rank, ShapedType::kDynamic );
160
- SmallVector<int64_t , 4 > staticSizesVector (rank, ShapedType::kDynamic );
161
- SmallVector<int64_t , 4 > staticStridesVector (
162
- rank, ShapedType::kDynamic );
163
- resultTypes.push_back (tensor::ExtractSliceOp::inferResultType (
164
- tensorType, staticOffsetsVector, staticSizesVector,
165
- staticStridesVector));
151
+ int64_t firstInitOperandIdx =
152
+ static_cast <OperandRange>(producerDpsInits).getBeginOperandIndex ();
153
+ for (int64_t i = 0 , e = producer->getNumResults (); i < e; ++i) {
154
+ resultTypes.push_back (clonedShapes[firstInitOperandIdx + i].getType ());
166
155
}
167
156
157
+ // Clone the producer with new operands and result types.
168
158
LinalgOp clonedOp = clone (b, producer, resultTypes, clonedShapes);
169
159
170
160
// Shift all IndexOp results by the tile offset.
0 commit comments