Skip to content

Commit 3a087c1

Browse files
[mlir][linalg] Fix invalid IR in Linalg op fusion (#74425)
Linalg op fusion (`Linalg/Transforms/Fusion.cpp`) used to generate invalid fused producer ops: ``` error: 'linalg.conv_2d_nhwc_hwcf' op expected type of operand #2 ('tensor<1x8x16x4xf32>') to match type of corresponding result ('tensor<?x?x?x?xf32>') note: see current operation: %24 = "linalg.conv_2d_nhwc_hwcf"(%21, %22, %23) <{dilations = dense<1> : tensor<2xi64>, operandSegmentSizes = array<i32: 2, 1>, strides = dense<2> : tensor<2xi64>}> ({ ^bb0(%arg9: f32, %arg10: f32, %arg11: f32): %28 = "arith.mulf"(%arg9, %arg10) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32 %29 = "arith.addf"(%arg11, %28) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32 "linalg.yield"(%29) : (f32) -> () }) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>]} : (tensor<1x?x?x3xf32>, tensor<3x3x3x4xf32>, tensor<1x8x16x4xf32>) -> tensor<?x?x?x?xf32> ``` This is a problem because the input IR to greedy pattern rewriter during `-test-linalg-greedy-fusion` is invalid. This commit fixes tests such as `mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir` when verifying the IR after each pattern application (#74270).
1 parent 6a7bbf7 commit 3a087c1

File tree

1 file changed

+7
-17
lines changed

1 file changed

+7
-17
lines changed

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

+7-17
Original file line numberDiff line numberDiff line change
@@ -144,27 +144,17 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
144144
b, loc, producer, getTiledOperands(producer), ivs, tileSizes, sizeBounds,
145145
/**omitPartialTileCheck=*/false));
146146

147-
// Iterate over the results in order.
148-
// Extract the subtensor type from the linearized range.
149-
// Since we do not enforce any canonicalizations on the fly, this is always
150-
// fully dynamic at construction time.
147+
// Take result types from the tiled init operands.
148+
MutableOperandRange producerDpsInits = producer.getDpsInitsMutable();
151149
SmallVector<Type, 4> resultTypes;
152150
resultTypes.reserve(producer->getNumResults());
153-
for (Value operand : producer.getDpsInits()) {
154-
auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
155-
if (!tensorType)
156-
continue;
157-
unsigned rank = tensorType.getRank();
158-
SmallVector<int64_t, 4> staticOffsetsVector(
159-
rank, ShapedType::kDynamic);
160-
SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamic);
161-
SmallVector<int64_t, 4> staticStridesVector(
162-
rank, ShapedType::kDynamic);
163-
resultTypes.push_back(tensor::ExtractSliceOp::inferResultType(
164-
tensorType, staticOffsetsVector, staticSizesVector,
165-
staticStridesVector));
151+
int64_t firstInitOperandIdx =
152+
static_cast<OperandRange>(producerDpsInits).getBeginOperandIndex();
153+
for (int64_t i = 0, e = producer->getNumResults(); i < e; ++i) {
154+
resultTypes.push_back(clonedShapes[firstInitOperandIdx + i].getType());
166155
}
167156

157+
// Clone the producer with new operands and result types.
168158
LinalgOp clonedOp = clone(b, producer, resultTypes, clonedShapes);
169159

170160
// Shift all IndexOp results by the tile offset.

0 commit comments

Comments
 (0)