-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][ArmSME] Add arm_sme.streaming_vl operation #77321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This operation provides a convenient way to query the streaming vector length regardless of the streaming mode. This most useful for functions that call/pass data to streaming functions, but are not streaming themselves. Example: ```mlir %svl_w = arm_sme.streaming_vl <words> ```
d7ebafe
to
58a5ca6
Compare
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) ChangesThis operation provides a convenient way to query the streaming vector Example: %svl_w = arm_sme.streaming_vl <words> Created based on discussion here: #76086 (comment) Full diff: https://github.com/llvm/llvm-project/pull/77321.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index f7cc1d3fe7517f..4060407d81c0fa 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -223,6 +223,21 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
let defaultValue = "CombiningKind::Add";
}
+def TypeSize : I32EnumAttr<"TypeSize", "Size of vector type", [
+ I32EnumAttrCase<"Bytes" , 0, "bytes">,
+ I32EnumAttrCase<"HalfWords" , 1, "half_words">,
+ I32EnumAttrCase<"Words" , 2, "words">,
+ I32EnumAttrCase<"DoubleWords", 3, "double_words">,
+]> {
+ let cppNamespace = "::mlir::arm_sme";
+ let genSpecializedAttr = 0;
+}
+
+def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
+ "type_size"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
//===----------------------------------------------------------------------===//
// ArmSME op definitions
//===----------------------------------------------------------------------===//
@@ -768,4 +783,32 @@ let arguments = (ins
}];
}
+def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
+{
+ let summary = "Query the streaming vector length";
+
+ let description = [{
+ This operation returns the streaming vector length for a type size. Unlike
+ `vector.vscale` the value returned is invariant to the streaming mode.
+
+ Example:
+ ```mlir
+ // Streaming vector length in:
+ // - bytes (8-bit)
+ %svl_b = arm_sme.streaming_vl <bytes>
+ // - half words (16-bit)
+ %svl_h = arm_sme.streaming_vl <half_words>
+ // - words (32-bit)
+ %svl_w = arm_sme.streaming_vl <words>
+ // - double words (64-bit)
+ %svl_d = arm_sme.streaming_vl <double_words>
+ ```
+ }];
+
+ let arguments = (ins ArmSME_TypeSizeAttr: $type_size);
+ let results = (outs Index);
+
+ let assemblyFormat = "$type_size attr-dict";
+}
+
#endif // ARMSME_OPS
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 0c6e2e80b88a3b..c5fdba6d00cc0f 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -518,6 +518,45 @@ struct OuterProductOpConversion
}
};
+/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
+///
+/// Example:
+///
+/// %0 = arm_sme.streaming_vl <half_words>
+///
+/// is converted to:
+///
+/// %cnt = "arm_sme.intr.cntsh"() : () -> i64
+/// %0 = arith.index_cast %cnt : i64 to index
+///
+struct StreamingVLOpConversion
+ : public ConvertOpToLLVMPattern<arm_sme::StreamingVLOp> {
+ using ConvertOpToLLVMPattern<arm_sme::StreamingVLOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
+ arm_sme::StreamingVLOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = streamingVlOp.getLoc();
+ auto i64Type = rewriter.getI64Type();
+ auto *intrOp = [&]() -> Operation * {
+ switch (streamingVlOp.getTypeSize()) {
+ case arm_sme::TypeSize::Bytes:
+ return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
+ case arm_sme::TypeSize::HalfWords:
+ return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
+ case arm_sme::TypeSize::Words:
+ return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
+ case arm_sme::TypeSize::DoubleWords:
+ return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
+ }
+ }();
+ rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
+ streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
+ return success();
+ }
+};
+
} // namespace
namespace {
@@ -555,7 +594,9 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
- arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
+ arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
+ arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
+ arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
}
@@ -572,8 +613,8 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
- OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
- converter);
+ OuterProductOpConversion, ZeroOpConversion, GetTileConversion,
+ StreamingVLOpConversion>(converter);
}
std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index bd88da37bdf966..bab4fd60518c54 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -559,3 +559,45 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile_slice_index : index)
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
return %slice : vector<[1]xi128>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.streaming_vl
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_bytes
+// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsb"() : () -> i64
+// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
+// CHECK: return %[[INDEX_COUNT]] : index
+func.func @arm_sme_streaming_vl_bytes() -> index {
+ %svl_b = arm_sme.streaming_vl <bytes>
+ return %svl_b : index
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_half_words
+// CHECK: "arm_sme.intr.cntsh"() : () -> i64
+func.func @arm_sme_streaming_vl_half_words() -> index {
+ %svl_h = arm_sme.streaming_vl <half_words>
+ return %svl_h : index
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_words
+// CHECK: "arm_sme.intr.cntsw"() : () -> i64
+func.func @arm_sme_streaming_vl_words() -> index {
+ %svl_w = arm_sme.streaming_vl <words>
+ return %svl_w : index
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_double_words
+// CHECK: "arm_sme.intr.cntsd"() : () -> i64
+func.func @arm_sme_streaming_vl_double_words() -> index {
+ %svl_d = arm_sme.streaming_vl <double_words>
+ return %svl_d : index
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 58ff7ef4d8340e..eb4c7149b61f1d 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1095,3 +1095,39 @@ func.func @arm_sme_outerproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: v
%result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB) : vector<[16]xi8>, vector<[16]xi8>
return %result : vector<[16]x[16]xi8>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.streaming_vl
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_streaming_vl_bytes() -> index {
+ // CHECK: arm_sme.streaming_vl <bytes>
+ %svl_b = arm_sme.streaming_vl <bytes>
+ return %svl_b : index
+}
+
+// -----
+
+func.func @arm_sme_streaming_vl_half_words() -> index {
+ // CHECK: arm_sme.streaming_vl <half_words>
+ %svl_h = arm_sme.streaming_vl <half_words>
+ return %svl_h : index
+}
+
+// -----
+
+func.func @arm_sme_streaming_vl_words() -> index {
+ // CHECK: arm_sme.streaming_vl <words>
+ %svl_w = arm_sme.streaming_vl <words>
+ return %svl_w : index
+}
+
+// -----
+
+func.func @arm_sme_streaming_vl_double_words() -> index {
+ // CHECK: arm_sme.streaming_vl <double_words>
+ %svl_d = arm_sme.streaming_vl <double_words>
+ return %svl_d : index
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks Ben, couple of nits but LGTM regardless, cheers
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just a small suggestion.
This operation provides a convenient way to query the streaming vector length regardless of the streaming mode. This most useful for functions that call/pass data to streaming functions, but are not streaming themselves. Example: ```mlir %svl_w = arm_sme.streaming_vl <word> ``` Created based on discussion here: llvm#76086 (comment)
This operation provides a convenient way to query the streaming vector
length regardless of the streaming mode. This most useful for functions
that call/pass data to streaming functions, but are not streaming
themselves.
Example:
Created based on discussion here: #76086 (comment)