diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 791924f92e8ad..30b9217423539 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2857,6 +2857,26 @@ LogicalResult InsertStridedSliceOp::verify() {
                                 /*halfOpen=*/false, /*min=*/1)))
     return failure();
 
+  unsigned rankDiff = destShape.size() - sourceShape.size();
+  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
+    if (sourceVectorType.getScalableDims()[idx] !=
+        destVectorType.getScalableDims()[idx + rankDiff]) {
+      return emitOpError("mismatching scalable flags (at source vector idx=")
+             << idx << ")";
+    }
+    if (sourceVectorType.getScalableDims()[idx]) {
+      auto sourceSize = sourceShape[idx];
+      auto destSize = destShape[idx + rankDiff];
+      if (sourceSize != destSize) {
+        return emitOpError("expected size at idx=")
+               << idx
+               << (" to match the corresponding base size from the input "
+                   "vector (")
+               << sourceSize << (" vs ") << destSize << (")");
+      }
+    }
+  }
+
   return success();
 }
 
@@ -3194,6 +3214,7 @@ void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
 // Inference works as follows:
 //   1. Add 'sizes' from prefix of dims in 'offsets'.
 //   2. Add sizes from 'vectorType' for remaining dims.
+// Scalable flags are inherited from 'vectorType'.
 static Type inferStridedSliceOpResultType(VectorType vectorType,
                                           ArrayAttr offsets, ArrayAttr sizes,
                                           ArrayAttr strides) {
@@ -3206,7 +3227,8 @@ static Type inferStridedSliceOpResultType(VectorType vectorType,
   for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
     shape.push_back(vectorType.getShape()[idx]);
 
-  return VectorType::get(shape, vectorType.getElementType());
+  return VectorType::get(shape, vectorType.getElementType(),
+                         vectorType.getScalableDims());
 }
 
 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
@@ -3265,6 +3287,19 @@ LogicalResult ExtractStridedSliceOp::verify() {
   if (getResult().getType() != resultType)
     return emitOpError("expected result type to be ") << resultType;
 
+  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
+    if (type.getScalableDims()[idx]) {
+      auto inputDim = type.getShape()[idx];
+      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
+      if (inputDim != inputSize)
+        return emitOpError("expected size at idx=")
+               << idx
+               << (" to match the corresponding base size from the input "
+                   "vector (")
+               << inputSize << (" vs ") << inputDim << (")");
+    }
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 09108ab317999..1c13b16dfd9af 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1142,6 +1142,28 @@ func.func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
 
 // -----
 
+func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
+  %0 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>
+  return %0 : vector<1x1x[4]xi32>
+}
+
+// CHECK-LABEL: func.func @extract_strided_slice_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
+
+// CHECK: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1x1x[4]xi32>
+// CHECK: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
+// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<1x[4]xi32>
+// CHECK: %[[CAST_3:.*]] = builtin.unrealized_conversion_cast %[[CST_1]] : vector<1x[4]xi32> to !llvm.array<1 x vector<[4]xi32>>
+
+// CHECK: %[[EXT:.*]] = llvm.extractvalue %[[CAST_1]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[INS_1:.*]] = llvm.insertvalue %[[EXT]], %[[CAST_3]][0] : !llvm.array<1 x vector<[4]xi32>>
+// CHECK: %[[INS_2:.*]] = llvm.insertvalue %[[INS_1]], %[[CAST_2]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+
+// CHECK: builtin.unrealized_conversion_cast %[[INS_2]] : !llvm.array<1 x array<1 x vector<[4]xi32>>> to vector<1x1x[4]xi32>
+
+// -----
+
 func.func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> {
   %0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
   return %0 : vector<4x4x4xf32>
@@ -1207,6 +1229,27 @@ func.func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf3
 
 // -----
 
+func.func @insert_strided_slice_scalable(%arg0 : vector<1x1x[4]xi32>, %arg1: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x[4]xi32> into vector<1x4x[4]xi32>
+  return %0 : vector<1x4x[4]xi32>
+}
+// CHECK-LABEL: func.func @insert_strided_slice_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<1x1x[4]xi32>,
+// CHECK-SAME: %[[ARG_1:.*]]: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+
+// CHECK: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
+// CHECK: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
+
+// CHECK: %[[EXT_1:.*]] = llvm.extractvalue %[[CAST_2]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[EXT_2:.*]] = llvm.extractvalue %[[CAST_1]][0, 0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+
+// CHECK: %[[INS_1:.*]] = llvm.insertvalue %[[EXT_2]], %[[EXT_1]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK: %[[INS_2:.*]] = llvm.insertvalue %[[INS_1]], %[[CAST_2]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+
+// CHECK: builtin.unrealized_conversion_cast %[[INS_2]] : !llvm.array<1 x array<4 x vector<[4]xi32>>> to vector<1x4x[4]xi32>
+
+// -----
+
 func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>, %d: vector<f32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>) {
   // CHECK-LABEL: @vector_fma
   // CHECK-SAME: %[[A:.*]]: vector<8xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5fa8ac245ce97..c16f1cb2876db 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -652,6 +652,22 @@ func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
 
 // -----
 
+func.func @insert_strided_slice_scalable(%a : vector<1x1x[2]xi32>, %b: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+  // expected-error@+1 {{op expected size at idx=2 to match the corresponding base size from the input vector (2 vs 4)}}
+  %0 = vector.insert_strided_slice %a, %b {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x[2]xi32> into vector<1x4x[4]xi32>
+  return %0 : vector<1x4x[4]xi32>
+}
+
+// -----
+
+func.func @insert_strided_slice_scalable(%a : vector<1x1x4xi32>, %b: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+  // expected-error@+1 {{op mismatching scalable flags (at source vector idx=2)}}
+  %0 = vector.insert_strided_slice %a, %b {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x4xi32> into vector<1x4x[4]xi32>
+  return %0 : vector<1x4x[4]xi32>
+}
+
+// -----
+
 func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected offsets, sizes and strides attributes of same size}}
   %1 = vector.extract_strided_slice %arg0 {offsets = [100], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
@@ -687,6 +703,14 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
 
 // -----
 
+func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[2]xi32> {
+  // expected-error@+1 {{op expected size at idx=2 to match the corresponding base size from the input vector (2 vs 4)}}
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 2], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[2]xi32>
+  return %1 : vector<1x1x[2]xi32>
+}
+
+// -----
+
 func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{op expected strides to be confined to [1, 2)}}
   %1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 03532c5c1ceb1..2f8530e7c171a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -319,6 +319,13 @@ func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
   return
 }
 
+// CHECK-LABEL: @insert_strided_slice_scalable
+func.func @insert_strided_slice_scalable(%a: vector<4x[16]xf32>, %b: vector<4x8x[16]xf32>) {
+  // CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 0], strides = [1, 1]} : vector<4x[16]xf32> into vector<4x8x[16]xf32>
+  %1 = vector.insert_strided_slice %a, %b {offsets = [2, 2, 0], strides = [1, 1]} : vector<4x[16]xf32> into vector<4x8x[16]xf32>
+  return
+}
+
 // CHECK-LABEL: @extract_strided_slice
 func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
   // CHECK: vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32>
@@ -326,6 +333,13 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32
   return %1: vector<2x2x16xf32>
 }
 
+// CHECK-LABEL: @extract_strided_slice_scalable
+func.func @extract_strided_slice_scalable(%arg0: vector<4x[8]x16xf32>) -> vector<2x[8]x16xf32> {
+  // CHECK: vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32>
+  %1 = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32> to vector<2x[8]x16xf32>
+  return %1: vector<2x[8]x16xf32>
+}
+
 #contraction_to_scalar_accesses = [
   affine_map<(i) -> (i)>,
   affine_map<(i) -> (i)>,
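
Note (illustrative, not part of the patch): as the tests above show, the new verifier checks still allow sub-slicing of fixed-size dimensions, but a scalable dimension must be taken in full and keep its scalable flag. A minimal sketch of an op the updated verifier should accept, assuming the semantics exercised by these tests (the function name @slice_scalable_ok is hypothetical):

func.func @slice_scalable_ok(%v : vector<1x4x[4]xi32>) -> vector<1x2x[4]xi32> {
  // The fixed dim of size 4 is sliced down to 2; the scalable dim [4] is
  // copied whole, so its size and scalable flag match the source.
  %0 = vector.extract_strided_slice %v {offsets = [0, 0, 0], sizes = [1, 2, 4], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x2x[4]xi32>
  return %0 : vector<1x2x[4]xi32>
}

Requesting sizes = [1, 1, 2] from the same source would instead trigger the new "expected size at idx=2 to match the corresponding base size from the input vector (2 vs 4)" diagnostic, as exercised in invalid.mlir above.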