Skip to content

Commit 9ddbcee

Browse files
authored
[mlir][vector] Extend vector.{insert|extract}_strided_slice (#79052)
Extends `vector.insert_strided_slice` and `vector.extract_strided_slice` to allow scalable input and output vectors. For scalable sizes, the corresponding slice size has to match the corresponding dimension in the output/input vector (insert/extract, respectively). This is supported: ```mlir vector.extract_strided_slice %1 { offsets = [0, 3, 0], sizes = [1, 1, 4], strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[4]xi32> ``` This is not supported: ```mlir vector.extract_strided_slice %1 { offsets = [0, 3, 0], sizes = [1, 1, 2], strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[2]xi32> ```
1 parent 28a2b85 commit 9ddbcee

File tree

4 files changed

+117
-1
lines changed

4 files changed

+117
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2857,6 +2857,26 @@ LogicalResult InsertStridedSliceOp::verify() {
28572857
/*halfOpen=*/false, /*min=*/1)))
28582858
return failure();
28592859

2860+
unsigned rankDiff = destShape.size() - sourceShape.size();
2861+
for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
2862+
if (sourceVectorType.getScalableDims()[idx] !=
2863+
destVectorType.getScalableDims()[idx + rankDiff]) {
2864+
return emitOpError("mismatching scalable flags (at source vector idx=")
2865+
<< idx << ")";
2866+
}
2867+
if (sourceVectorType.getScalableDims()[idx]) {
2868+
auto sourceSize = sourceShape[idx];
2869+
auto destSize = destShape[idx + rankDiff];
2870+
if (sourceSize != destSize) {
2871+
return emitOpError("expected size at idx=")
2872+
<< idx
2873+
<< (" to match the corresponding base size from the input "
2874+
"vector (")
2875+
<< sourceSize << (" vs ") << destSize << (")");
2876+
}
2877+
}
2878+
}
2879+
28602880
return success();
28612881
}
28622882

@@ -3194,6 +3214,7 @@ void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
31943214
// Inference works as follows:
31953215
// 1. Add 'sizes' from prefix of dims in 'offsets'.
31963216
// 2. Add sizes from 'vectorType' for remaining dims.
3217+
// Scalable flags are inherited from 'vectorType'.
31973218
static Type inferStridedSliceOpResultType(VectorType vectorType,
31983219
ArrayAttr offsets, ArrayAttr sizes,
31993220
ArrayAttr strides) {
@@ -3206,7 +3227,8 @@ static Type inferStridedSliceOpResultType(VectorType vectorType,
32063227
for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
32073228
shape.push_back(vectorType.getShape()[idx]);
32083229

3209-
return VectorType::get(shape, vectorType.getElementType());
3230+
return VectorType::get(shape, vectorType.getElementType(),
3231+
vectorType.getScalableDims());
32103232
}
32113233

32123234
void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
@@ -3265,6 +3287,19 @@ LogicalResult ExtractStridedSliceOp::verify() {
32653287
if (getResult().getType() != resultType)
32663288
return emitOpError("expected result type to be ") << resultType;
32673289

3290+
for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3291+
if (type.getScalableDims()[idx]) {
3292+
auto inputDim = type.getShape()[idx];
3293+
auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3294+
if (inputDim != inputSize)
3295+
return emitOpError("expected size at idx=")
3296+
<< idx
3297+
<< (" to match the corresponding base size from the input "
3298+
"vector (")
3299+
<< inputSize << (" vs ") << inputDim << (")");
3300+
}
3301+
}
3302+
32683303
return success();
32693304
}
32703305

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,28 @@ func.func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
11421142

11431143
// -----
11441144

1145+
func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
1146+
%0 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>
1147+
return %0 : vector<1x1x[4]xi32>
1148+
}
1149+
1150+
// CHECK-LABEL: func.func @extract_strided_slice_scalable(
1151+
// CHECK-SAME: %[[ARG_0:.*]]: vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
1152+
1153+
// CHECK: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
1154+
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1x1x[4]xi32>
1155+
// CHECK: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
1156+
// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<1x[4]xi32>
1157+
// CHECK: %[[CAST_3:.*]] = builtin.unrealized_conversion_cast %[[CST_1]] : vector<1x[4]xi32> to !llvm.array<1 x vector<[4]xi32>>
1158+
1159+
// CHECK: %[[EXT:.*]] = llvm.extractvalue %[[CAST_1]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
1160+
// CHECK: %[[INS_1:.*]] = llvm.insertvalue %[[EXT]], %[[CAST_3]][0] : !llvm.array<1 x vector<[4]xi32>>
1161+
// CHECK: %[[INS_2:.*]] = llvm.insertvalue %[[INS_1]], %[[CAST_2]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
1162+
1163+
// CHECK: builtin.unrealized_conversion_cast %[[INS_2]] : !llvm.array<1 x array<1 x vector<[4]xi32>>> to vector<1x1x[4]xi32>
1164+
1165+
// -----
1166+
11451167
func.func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> {
11461168
%0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
11471169
return %0 : vector<4x4x4xf32>
@@ -1207,6 +1229,27 @@ func.func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf3
12071229

12081230
// -----
12091231

1232+
func.func @insert_strided_slice_scalable(%arg0 : vector<1x1x[4]xi32>, %arg1: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
1233+
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x[4]xi32> into vector<1x4x[4]xi32>
1234+
return %0 : vector<1x4x[4]xi32>
1235+
}
1236+
// CHECK-LABEL: func.func @insert_strided_slice_scalable(
1237+
// CHECK-SAME: %[[ARG_0:.*]]: vector<1x1x[4]xi32>,
1238+
// CHECK-SAME: %[[ARG_1:.*]]: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
1239+
1240+
// CHECK: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
1241+
// CHECK: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
1242+
1243+
// CHECK: %[[EXT_1:.*]] = llvm.extractvalue %[[CAST_2]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
1244+
// CHECK: %[[EXT_2:.*]] = llvm.extractvalue %[[CAST_1]][0, 0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
1245+
1246+
// CHECK: %[[INS_1:.*]] = llvm.insertvalue %[[EXT_2]], %[[EXT_1]][3] : !llvm.array<4 x vector<[4]xi32>>
1247+
// CHECK: %[[INS_2:.*]] = llvm.insertvalue %[[INS_1]], %[[CAST_2]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
1248+
1249+
// CHECK: builtin.unrealized_conversion_cast %[[INS_2]] : !llvm.array<1 x array<4 x vector<[4]xi32>>> to vector<1x4x[4]xi32>
1250+
1251+
// -----
1252+
12101253
func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>, %d: vector<f32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>) {
12111254
// CHECK-LABEL: @vector_fma
12121255
// CHECK-SAME: %[[A:.*]]: vector<8xf32>

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,22 @@ func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
652652

653653
// -----
654654

655+
func.func @insert_strided_slice_scalable(%a : vector<1x1x[2]xi32>, %b: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
656+
// expected-error@+1 {{op expected size at idx=2 to match the corresponding base size from the input vector (2 vs 4)}}
657+
%0 = vector.insert_strided_slice %a, %b {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x[2]xi32> into vector<1x4x[4]xi32>
658+
return %0 : vector<1x4x[4]xi32>
659+
}
660+
661+
// -----
662+
663+
func.func @insert_strided_slice_scalable(%a : vector<1x1x4xi32>, %b: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
664+
// expected-error@+1 {{op mismatching scalable flags (at source vector idx=2)}}
665+
%0 = vector.insert_strided_slice %a, %b {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x4xi32> into vector<1x4x[4]xi32>
666+
return %0 : vector<1x4x[4]xi32>
667+
}
668+
669+
// -----
670+
655671
func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
656672
// expected-error@+1 {{expected offsets, sizes and strides attributes of same size}}
657673
%1 = vector.extract_strided_slice %arg0 {offsets = [100], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
@@ -687,6 +703,14 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
687703

688704
// -----
689705

706+
func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[2]xi32> {
707+
// expected-error@+1 {{op expected size at idx=2 to match the corresponding base size from the input vector (2 vs 4)}}
708+
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 2], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[2]xi32>
709+
return %1 : vector<1x1x[2]xi32>
710+
}
711+
712+
// -----
713+
690714
func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
691715
// expected-error@+1 {{op expected strides to be confined to [1, 2)}}
692716
%1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,13 +319,27 @@ func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
319319
return
320320
}
321321

322+
// CHECK-LABEL: @insert_strided_slice_scalable
323+
func.func @insert_strided_slice_scalable(%a: vector<4x[16]xf32>, %b: vector<4x8x[16]xf32>) {
324+
// CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 0], strides = [1, 1]} : vector<4x[16]xf32> into vector<4x8x[16]xf32>
325+
%1 = vector.insert_strided_slice %a, %b {offsets = [2, 2, 0], strides = [1, 1]} : vector<4x[16]xf32> into vector<4x8x[16]xf32>
326+
return
327+
}
328+
322329
// CHECK-LABEL: @extract_strided_slice
323330
func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
324331
// CHECK: vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32>
325332
%1 = vector.extract_strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
326333
return %1: vector<2x2x16xf32>
327334
}
328335

336+
// CHECK-LABEL: @extract_strided_slice_scalable
337+
func.func @extract_strided_slice_scalable(%arg0: vector<4x[8]x16xf32>) -> vector<2x[8]x16xf32> {
338+
// CHECK: vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32>
339+
%1 = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32> to vector<2x[8]x16xf32>
340+
return %1: vector<2x[8]x16xf32>
341+
}
342+
329343
#contraction_to_scalar_accesses = [
330344
affine_map<(i) -> (i)>,
331345
affine_map<(i) -> (i)>,

0 commit comments

Comments
 (0)