diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index f7cc1d3fe7517..bb0db59add009 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -223,6 +223,21 @@ def ArmSME_CombiningKindAttr : EnumAttr, + I32EnumAttrCase<"Half" , 1, "half">, + I32EnumAttrCase<"Word" , 2, "word">, + I32EnumAttrCase<"Double", 3, "double">, +]> { + let cppNamespace = "::mlir::arm_sme"; + let genSpecializedAttr = 0; +} + +def ArmSME_TypeSizeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + //===----------------------------------------------------------------------===// // ArmSME op definitions //===----------------------------------------------------------------------===// @@ -768,4 +783,33 @@ let arguments = (ins }]; } +def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]> +{ + let summary = "Query the streaming vector length"; + + let description = [{ + This operation returns the streaming vector length (SVL) for a given type + size. Unlike `vector.vscale` the value returned is invariant to the + streaming mode. 
+
+    Example:
+    ```mlir
+    // Streaming vector length in:
+    //   - bytes (8-bit, SVL.B)
+    %svl_b = arm_sme.streaming_vl <byte>
+    //   - half words (16-bit, SVL.H)
+    %svl_h = arm_sme.streaming_vl <half>
+    //   - words (32-bit, SVL.W)
+    %svl_w = arm_sme.streaming_vl <word>
+    //   - double words (64-bit, SVL.D)
+    %svl_d = arm_sme.streaming_vl <double>
+    ```
+  }];
+
+  let arguments = (ins ArmSME_TypeSizeAttr: $type_size);
+  let results = (outs Index);
+
+  let assemblyFormat = "$type_size attr-dict";
+}
+
 #endif // ARMSME_OPS
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 0c6e2e80b88a3..0bb7ccb463e48 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -518,6 +518,57 @@ struct OuterProductOpConversion
   }
 };
 
+/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
+///
+/// The `type_size` attribute selects the intrinsic (byte -> cntsb,
+/// half -> cntsh, word -> cntsw, double -> cntsd). The intrinsic result is an
+/// `i64`, which is cast to `index` to replace the original op.
+///
+/// Example:
+///
+///   %0 = arm_sme.streaming_vl <half>
+///
+/// is converted to:
+///
+///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
+///   %0 = arith.index_cast %cnt : i64 to index
+///
+struct StreamingVLOpConversion
+    : public ConvertOpToLLVMPattern<arm_sme::StreamingVLOp> {
+  using ConvertOpToLLVMPattern<arm_sme::StreamingVLOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
+                  arm_sme::StreamingVLOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = streamingVlOp.getLoc();
+    auto i64Type = rewriter.getI64Type();
+    auto *intrOp = [&]() -> Operation * {
+      switch (streamingVlOp.getTypeSize()) {
+      case arm_sme::TypeSize::Byte:
+        return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
+      case arm_sme::TypeSize::Half:
+        return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
+      case arm_sme::TypeSize::Word:
+        return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
+      case arm_sme::TypeSize::Double:
+        return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
+      }
+      // No `default` label, so the compiler still warns about unhandled
+      // enumerators. Return null instead of flowing off the end of a
+      // non-void lambda, which would be undefined behaviour.
+      return nullptr;
+    }();
+    // Nothing has been created or modified when no case matched, so it is
+    // safe to report a match failure here.
+    if (!intrOp)
+      return rewriter.notifyMatchFailure(streamingVlOp, "unknown type size");
+    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
+        streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -555,7 +606,9 @@ void
mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) { arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz, - arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>(); + arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa, + arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh, + arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>(); target.addLegalDialect(); target.addLegalOp(); } @@ -572,8 +613,8 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, patterns.add( - converter); + OuterProductOpConversion, ZeroOpConversion, GetTileConversion, + StreamingVLOpConversion>(converter); } std::unique_ptr mlir::createConvertArmSMEToLLVMPass() { diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir index bd88da37bdf96..f9cf77ca15ffb 100644 --- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir +++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir @@ -559,3 +559,45 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile_slice_index : index) %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout : vector<[1]xi128> from vector<[1]x[1]xi128> return %slice : vector<[1]xi128> } + +//===----------------------------------------------------------------------===// +// arm_sme.streaming_vl +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @arm_sme_streaming_vl_bytes +// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsb"() : () -> i64 +// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index +// CHECK: return %[[INDEX_COUNT]] : index +func.func @arm_sme_streaming_vl_bytes() -> index { + %svl_b = arm_sme.streaming_vl + return %svl_b : index +} + +// ----- + +// CHECK-LABEL: 
@arm_sme_streaming_vl_half_words +// CHECK: "arm_sme.intr.cntsh"() : () -> i64 +func.func @arm_sme_streaming_vl_half_words() -> index { + %svl_h = arm_sme.streaming_vl + return %svl_h : index +} + +// ----- + +// CHECK-LABEL: @arm_sme_streaming_vl_words +// CHECK: "arm_sme.intr.cntsw"() : () -> i64 +func.func @arm_sme_streaming_vl_words() -> index { + %svl_w = arm_sme.streaming_vl + return %svl_w : index +} + +// ----- + +// CHECK-LABEL: @arm_sme_streaming_vl_double_words +// CHECK: "arm_sme.intr.cntsd"() : () -> i64 +func.func @arm_sme_streaming_vl_double_words() -> index { + %svl_d = arm_sme.streaming_vl + return %svl_d : index +} diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir index 58ff7ef4d8340..2ad742493408b 100644 --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -1095,3 +1095,39 @@ func.func @arm_sme_outerproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: v %result = arm_sme.outerproduct %vecA, %vecB kind acc(%acc) masks(%maskA, %maskB) : vector<[16]xi8>, vector<[16]xi8> return %result : vector<[16]x[16]xi8> } + +//===----------------------------------------------------------------------===// +// arm_sme.streaming_vl +//===----------------------------------------------------------------------===// + +// ----- + +func.func @arm_sme_streaming_vl_bytes() -> index { + // CHECK: arm_sme.streaming_vl + %svl_b = arm_sme.streaming_vl + return %svl_b : index +} + +// ----- + +func.func @arm_sme_streaming_vl_half_words() -> index { + // CHECK: arm_sme.streaming_vl + %svl_h = arm_sme.streaming_vl + return %svl_h : index +} + +// ----- + +func.func @arm_sme_streaming_vl_words() -> index { + // CHECK: arm_sme.streaming_vl + %svl_w = arm_sme.streaming_vl + return %svl_w : index +} + +// ----- + +func.func @arm_sme_streaming_vl_double_words() -> index { + // CHECK: arm_sme.streaming_vl + %svl_d = arm_sme.streaming_vl + return %svl_d : index +}