From 056d5e1d8a7a7b89ca2fad67d633f1a88bd7fc7e Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 20 Dec 2023 16:24:27 +0000 Subject: [PATCH 01/13] [mlir][ArmSME] Add rudimentary support for tile spills to the stack This adds very basic and inelegant support for something like spilling and reloading tiles if you use more SME tiles than physically exist. This is purely implemented to prevent the compiler from aborting if a function uses too many tiles (i.e. due to bad unrolling), but is expected to perform very poorly. Currenly, this works in two stages: During tile allocation, if we run out of tiles instead of giving up, we switch to allocating 'in-memory' tile IDs. These are tile IDs that start at 16 (which is higher than any real tile ID). A warning will also be emitted for each (root) tile op assigned an in-memory tile ID: ``` warning: failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation ``` Everything after this works like normal until `-convert-arm-sme-to-llvm` Here the in-memory tile op: ```mlir arm_sme.tile_op { tile_id = } ``` Is lowered to: ```mlir // At function entry: %alloca = memref.alloca ... : memref // Around the op: // Swap the contents of %alloca and tile 0. scf.for %slice_idx { %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}> "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}> vector.store %current_slice, %alloca[%slice_idx, %c0] } // Execute op using tile 0. arm_sme.tile_op { tile_id = 0 } // Swap the contents of %alloca and tile 0. // This restores tile 0 to its original state. scf.for %slice_idx { %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}> "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}> vector.store %current_slice, %alloca[%slice_idx, %c0] } ``` This is inserted during the lowering to LLVM as spilling/reloading registers is a very low-level concept, that can't really be modeled correctly at a high level in MLIR. Note: This is always doing the worst case full-tile swap. This could be optimized to only spill/load data the tile op will use, which could be just a slice. It's also not making any use of liveness, which could allow reusing tiles. But these is not seen as important as correct code should only use the available number of tiles. --- mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h | 3 +- .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 33 +++ .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 228 +++++++++++++++++- mlir/lib/Dialect/ArmSME/IR/Utils.cpp | 1 - .../ArmSME/Transforms/TileAllocation.cpp | 66 +++-- .../ArmSMEToLLVM/tile-spills-and-fills.mlir | 96 ++++++++ mlir/test/Dialect/ArmSME/tile-allocation.mlir | 18 +- .../Linalg/CPU/ArmSME/use-too-many-tiles.mlir | 78 ++++++ 8 files changed, 486 insertions(+), 37 deletions(-) create mode 100644 mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir create mode 100644 mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h index 9982d4278b603..c507cea5357a7 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h @@ -25,8 +25,9 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir::arm_sme { +static constexpr unsigned kInMemoryTileIdBase = 16; #include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc" -} +} // namespace mlir::arm_sme #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc" diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index bb0db59add009..238e6e4a079c3 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -97,6 +97,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> { // This operation does not allocate a tile. return std::nullopt; }] + >, + InterfaceMethod< + "Returns the VectorType of the tile used by this operation.", + /*returnType=*/"VectorType", + /*methodName=*/"getTileType", + /*arguments=*/(ins), + /*methodBody=*/[{}] > ]; @@ -117,6 +124,11 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> { rewriter.replaceOp($_op, newOp); return newOp; } + + bool isInMemoryTile() { + auto tileId = getTileId(); + return tileId && tileId.getInt() >= kInMemoryTileIdBase; + } }]; let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }]; @@ -331,6 +343,9 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> { std::optional getAllocatedTileType() { return arm_sme::getSMETileType(getVectorType()); } + VectorType getTileType() { + return getVectorType(); + } }]; let assemblyFormat = "attr-dict `:` type($res)"; } @@ -407,6 +422,9 @@ def TileLoadOp : ArmSME_Op<"tile_load", [ std::optional getAllocatedTileType() { return arm_sme::getSMETileType(getVectorType()); } + VectorType getTileType() { + return getVectorType(); + } }]; let builders = [ @@ -475,6 +493,9 @@ def TileStoreOp : ArmSME_Op<"tile_store", [ VectorType getVectorType() { return ::llvm::cast(getValueToStore().getType()); } + VectorType getTileType() { + return getVectorType(); + } }]; let builders = [ @@ -539,6 +560,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ VectorType getVectorType() { return ::llvm::cast(getResult().getType()); } + VectorType getTileType() { + return getVectorType(); + } }]; let assemblyFormat = [{ @@ -596,6 +620,9 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [ VectorType getVectorType() { return ::llvm::cast(getTile().getType()); } + VectorType getTileType() { + return getVectorType(); + } }]; let assemblyFormat = [{ @@ -688,6 +715,9 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [ let extraClassDeclaration = [{ VectorType getSliceType() { return getResult().getType(); } + VectorType getTileType() { + return ::llvm::cast(getTile().getType()); + } }]; let assemblyFormat = [{ @@ -780,6 +810,9 @@ let arguments = (ins return arm_sme::getSMETileType(getResultType()); return std::nullopt; } + VectorType getTileType() { + return getResultType(); + } }]; } diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 0bb7ccb463e48..097b27042f17d 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/ArmSME/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -32,6 +33,8 @@ using namespace mlir; namespace { +static constexpr StringLiteral kInMemoryTileId("arm_sme.in_memory_tile_id"); + /// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic. static Operation *createLoadTileSliceIntrinsic( RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type, @@ -129,6 +132,209 @@ IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) { return tileId; } +/// Creates a alloca matching the size of tile used by `tileOp`. The alloca is +/// placed in the first block of the function. +static memref::AllocaOp +createAllocaForTile(RewriterBase &rewriter, Location loc, + FunctionOpInterface func, + arm_sme::ArmSMETileOpInterface tileOp) { + RewriterBase::InsertionGuard g(rewriter); + // Move to the first operation in the function. + rewriter.setInsertionPoint(&func.getBlocks().front().front()); + // Create an alloca matching the tile size of the `tileOp`. + auto vscale = rewriter.create(loc); + auto tileElementType = + llvm::cast(tileOp.getTileType()).getElementType(); + auto memrefType = MemRefType::get( + {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType); + auto minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType); + auto minElementsOp = + rewriter.create(loc, minElements); + auto vectorLen = rewriter.create(loc, vscale, minElementsOp); + auto alloca = rewriter.create( + loc, memrefType, ValueRange{vectorLen, vectorLen}); + return alloca; +} + +/// Finds or creates an alloca for a spill of a tile. +static memref::AllocaOp +getOrCreateTileMemory(RewriterBase &rewriter, Location loc, + FunctionOpInterface func, + arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) { + // Find an alloca at the top of the function tagged with a + // 'arm_sme.in_memory_tile_id' that matches `tileId`. + for (auto &op : func.getBlocks().front()) { + auto alloca = llvm::dyn_cast(op); + if (!alloca) + continue; + auto inMemoryTileId = llvm::dyn_cast_or_null( + alloca->getDiscardableAttr(kInMemoryTileId)); + if (!inMemoryTileId) + continue; + if (inMemoryTileId.getInt() == tileId) + return alloca; + } + // Otherwise, create a new alloca: + auto alloca = createAllocaForTile(rewriter, loc, func, tileOp); + alloca->setDiscardableAttr(kInMemoryTileId, + rewriter.getI32IntegerAttr(tileId)); + return alloca; +} + +/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a +/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning +/// the op to tile 0, then emitting a full tile swap between ZA and memory +/// before + after the tile op. +/// +/// Example: +/// +/// arm_sme.tile_op { tile_id = } +/// +/// is converted to: +/// // At function entry: +/// %alloca = memref.alloca ... : memref +/// +/// // Around op: +/// scf.for %slice_idx { +/// %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}> +/// "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}> +/// vector.store %current_slice, %alloca[%slice_idx, %c0] +/// } +/// arm_sme.tile_op { tile_id = 0 } +/// scf.for %slice_idx { +/// %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}> +/// "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}> +/// vector.store %current_slice, %alloca[%slice_idx, %c0] +/// } +/// +/// Note that these spills/fills are not inserted earlier as concept of a +/// register, and the need to swap the contents, can't really be represented +/// correctly at a high level in MLIR. +/// +/// TODO: Reduce the spills/reloads to single slices where possible. +struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { + + ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName, + const LLVMTypeConverter &typeConverter, + PatternBenefit benefit) + : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(), + typeConverter, benefit) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto tileOp = cast(op); + // Tile has a real (hardware) tile. No spills/reloads required. + if (!tileOp.isInMemoryTile()) + return failure(); + + // Step 1. Create an alloca for the tile at the top of the function (if one + // does not already exist). + auto loc = tileOp.getLoc(); + auto func = tileOp->getParentOfType(); + auto tileAlloca = getOrCreateTileMemory(rewriter, loc, func, tileOp, + tileOp.getTileId().getInt()); + + // Step 2. Assign the op a real tile ID. + // For simplicity, we always use tile 0. + auto zeroTileId = rewriter.getI32IntegerAttr(0); + { + rewriter.startRootUpdate(tileOp); + tileOp.setTileId(zeroTileId); + rewriter.finalizeRootUpdate(tileOp); + } + + VectorType tileVectorType = tileOp.getTileType(); + auto sliceType = VectorType::Builder(tileOp.getTileType()).dropDim(0); + auto emitTileSwap = [&] { + emitFullTileSwap(rewriter, loc, tileAlloca, + *arm_sme::getSMETileType(tileVectorType), sliceType, + zeroTileId); + }; + + // Step 3. Emit tile swaps before and after the op. + // TODO: Reduce the amount spilled to the amount of data the `tileOp` + // touches (i.e. a single tile slice). + { + rewriter.setInsertionPoint(op); + // Swap the in-memory tile's contents into ZA before the op. + emitTileSwap(); + rewriter.setInsertionPointAfter(op); + // Swap the tile back out to memory again after the op. + emitTileSwap(); + } + + return success(); + } + + /// Extracts a pointer to a slice of an in-memory tile. + Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc, + Value tileMemory, Value sliceIndex) const { + auto llvmType = getTypeConverter()->convertType(tileMemory.getType()); + auto descriptor = + rewriter.create(loc, llvmType, tileMemory); + auto zero = rewriter.create(loc, 0, /*width=*/64); + auto sliceIndexI64 = rewriter.create( + loc, rewriter.getI64Type(), sliceIndex); + return getStridedElementPtr( + loc, llvm::cast(tileMemory.getType()), + descriptor.getResult(0), {sliceIndexI64, zero}, + static_cast(rewriter)); + } + + /// Emits an in-place swap of a slice of a tile in ZA and a slice of a + /// tile-sized memref (`tileAlloca`). + void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca, + arm_sme::ArmSMETileType tileType, VectorType sliceType, + IntegerAttr tileId, Value sliceIndex) const { + // Cast the slice index to an i32. + auto sliceIndexI32 = rewriter.create( + loc, rewriter.getI32Type(), sliceIndex); + // Create an all-true predicate for the slice. + auto predicateType = sliceType.clone(rewriter.getI1Type()); + auto allTruePredicate = rewriter.create( + loc, DenseElementsAttr::get(predicateType, true)); + // Create zero padding vector (never used due to all-true predicate). + auto zeroVector = rewriter.create( + loc, sliceType, rewriter.getZeroAttr(sliceType)); + // Get a pointer to the current slice. + auto slicePtr = + getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex); + // Read the value of the current slice from ZA. + auto currentTileSlice = rewriter.create( + loc, sliceType, zeroVector, allTruePredicate, tileId, sliceIndexI32); + // Load the new tile slice back from memory into ZA. + createLoadTileSliceIntrinsic( + rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal, + allTruePredicate, slicePtr, tileId, sliceIndexI32); + // Store the current tile slice to memory. + auto zero = rewriter.create(loc, 0); + rewriter.create(loc, currentTileSlice, tileAlloca, + ValueRange{sliceIndex, zero}); + } + + /// Emits a full in-place swap of the contents of a tile in ZA and a + /// tile-sized memref (`tileAlloca`). + void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca, + arm_sme::ArmSMETileType tileType, VectorType sliceType, + IntegerAttr tileId) const { + RewriterBase::InsertionGuard guard(rewriter); + // Create an scf.for over all tile slices. + auto minNumElts = + rewriter.create(loc, sliceType.getDimSize(0)); + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = rewriter.create( + loc, minNumElts, rewriter.create(loc)); + auto step = rewriter.create(loc, 1); + auto forOp = rewriter.create(loc, lowerBound, upperBound, step); + // Emit a swap for each tile slice. + rewriter.setInsertionPointToStart(forOp.getBody()); + auto sliceIndex = forOp.getInductionVar(); + emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId, + sliceIndex); + } +}; + struct GetTileConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -576,6 +782,17 @@ struct ConvertArmSMEToLLVMPass } }; +template +static void addSpillAndFillsForTileOp(RewritePatternSet &patterns, + LLVMTypeConverter const &typeConverter) { + // Add spill/fill conversions with a very high benefit to ensure they are + // lowered first. + (patterns.add(TileOp::getOperationName(), + typeConverter, + /*benefit=*/1337), + ...); +} + } // namespace void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) { @@ -597,7 +814,10 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) { arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>(); - target.addLegalDialect(); + target.addLegalDialect(); target.addLegalOp(); } @@ -611,6 +831,12 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, return std::nullopt; }); + // Register ops that need spills/fills. + addSpillAndFillsForTileOp< + arm_sme::LoadTileSliceOp, arm_sme::MoveTileSliceToVectorOp, + arm_sme::MoveVectorToTileSliceOp, arm_sme::StoreTileSliceOp, + arm_sme::OuterProductOp, arm_sme::ZeroOp>(patterns, converter); + patterns.add { using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp, PatternRewriter &rewriter) const override { + auto func = tileOp->getParentOfType(); if (tileOp.getTileId()) return failure(); + auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) { + if (auto attr = llvm::dyn_cast_or_null( + func->getDiscardableAttr(name))) + return unsigned(attr.getInt()); + return defaultVal; + }; + + auto setDiscardableIntAttr = [&](StringRef name, auto value) { + rewriter.updateRootInPlace(tileOp, [&] { + func->setDiscardableAttr(name, + rewriter.getI32IntegerAttr((unsigned)value)); + }); + }; + std::optional tileType = tileOp.getAllocatedTileType(); if (!tileType) return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile"); - auto func = tileOp->getParentOfType(); - TileMask tilesInUse = TileMask::kNone; - if (auto tilesInUseAttr = llvm::dyn_cast_or_null( - func->getDiscardableAttr(kTilesInUseAttr))) - tilesInUse = static_cast(tilesInUseAttr.getInt()); - + TileMask tilesInUse = + static_cast(getDiscardableIntAttr(kTilesInUseAttr)); auto tileId = allocateTileId(*tileType, tilesInUse); - if (failed(tileId)) - return tileOp.emitError("ran out of SME virtual tiles!"); - - rewriter.updateRootInPlace(func, [&]() { - func->setDiscardableAttr( - kTilesInUseAttr, rewriter.getI32IntegerAttr((unsigned)tilesInUse)); - }); - - // Find all the ops that (transitively) depend on this tile. - SetVector dependantOps; - findDependantOps(tileOp->getResult(0), dependantOps); + bool tileIsInMemory = failed(tileId); + if (!tileIsInMemory) + setDiscardableIntAttr(kTilesInUseAttr, tilesInUse); + else { + // If we could not find a real tile, set use a virtual tile ID (ID >= 16). + // A later pass will insert the necessary spills and reloads. + tileId = getDiscardableIntAttr(kNextTileMemoryIndex, kInMemoryTileIdBase); + setDiscardableIntAttr(kNextTileMemoryIndex, *tileId + 1); + tileOp->emitWarning( + "failed to allocate physical tile to operation, all tile " + "operations will go through memory, expect " + "performance degradation"); + } // Set all operations dependent on `tileOp` to use the same tile ID. // This is a naive tile allocation scheme, but works for common cases. For @@ -246,16 +260,18 @@ struct AssignTileIDsPattern // This case would require allocating a new tile for the result of the // scf.if, and moving the contents of %tileA or %tileB to result tile (based // on the %some_cond). + // Find all the ops that (transitively) depend on this tile. + SetVector dependantOps; + findDependantOps(tileOp->getResult(0), dependantOps); auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId); - rewriter.updateRootInPlace(tileOp, [&]() { tileOp.setTileId(tileIDAttr); }); + rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); }); for (auto *op : dependantOps) { - if (auto tileOp = llvm::dyn_cast(op)) { - auto currentTileId = tileOp.getTileId(); + if (auto dependantTileOp = llvm::dyn_cast(op)) { + auto currentTileId = dependantTileOp.getTileId(); if (currentTileId && unsigned(currentTileId.getInt()) != tileId) - return tileOp.emitOpError( + return dependantTileOp.emitOpError( "already assigned different SME virtual tile!"); - rewriter.updateRootInPlace(tileOp, - [&]() { tileOp.setTileId(tileIDAttr); }); + dependantTileOp.setTileId(tileIDAttr); } } diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir new file mode 100644 index 0000000000000..9908f04b7c855 --- /dev/null +++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir @@ -0,0 +1,96 @@ +// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | \ +// RUN: FileCheck %s --check-prefix=AFTER-TILE-ALLOC +// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize -cse \ +// RUN: -split-input-file -verify-diagnostics | \ +// RUN: FileCheck %s --check-prefix=AFTER-LLVM-LOWERING + +// ----- + +/// Checks tile spill/reloads are inserted around in-memory tiles (i.e. tiles +/// that were not assigned a physical SME tile). +/// +/// These spills are currently very naive and paranoid and will spill/reload +/// entire tiles around ArmSME ops. +/// +/// The general pattern is: +/// +/// During tile allocation if there's not a physical tile ID available an op +/// will be assigned an in-memory tile ID (which is a tile ID >= 16). +/// +/// Example: +/// +/// arm_sme.zero : vector<[8]x[8]xi16> +/// +/// Becomes: +/// +/// arm_sme.zero { tile_id = 16 } : vector<[8]x[8]xi16> +/// +/// This works like normal till the final lowering to LLVM, where spills and +/// reloads will be inserted around uses of in-memory tiles. +/// +/// So the above example becomes: +/// +/// // Placed at the top of the function: +/// %tileAlloca = memref.alloca(%svl_h, %svl_h) : memref +/// +/// Then around the op: +/// +/// // Swap contents of %tileAlloca and tile 0 +/// scf.for %sliceIdx ... { +/// %currentSlice = arm_sme.intr.read.horiz {tile_id = 0} +/// arm_sme.intr.ld1h.horiz %tileAlloca[%sliceIdx, %c0] {tile_id = 0} +/// vector.store %currentSlice, %tileAlloca[%sliceIdx, %c0] +/// } +/// // Execute the op using tile 0 +/// arm_sme.intr.zero +/// // Swap contents of %tileAlloca and tile 0 +/// scf.for %sliceIdx ... { +/// %currentSlice = arm_sme.intr.read.horiz {tile_id = 0} +/// arm_sme.intr.ld1h.horiz %tileAlloca[%sliceIdx, %c0] {tile_id = 0} +/// vector.store %currentSlice, %tileAlloca[%sliceIdx, %c0] +/// } +/// + +func.func @use_too_many_tiles() { + %0 = arm_sme.zero : vector<[4]x[4]xi32> + %1 = arm_sme.zero : vector<[4]x[4]xi32> + // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + %2 = arm_sme.zero : vector<[8]x[8]xi16> + return +} +// AFTER-TILE-ALLOC-LABEL: @use_too_many_tiles +// AFTER-TILE-ALLOC: arm_sme.zero +// AFTER-TILE-ALLOC-SAME: tile_id = 0 +// AFTER-TILE-ALLOC: arm_sme.zero +// AFTER-TILE-ALLOC-SAME: tile_id = 1 +// AFTER-TILE-ALLOC: arm_sme.zero +// AFTER-TILE-ALLOC-SAME: tile_id = 16 + +// AFTER-LLVM-LOWERING-LABEL: @use_too_many_tiles +// AFTER-LLVM-LOWERING-DAG: %[[C0:.*]] = arith.constant 0 : index +// AFTER-LLVM-LOWERING-DAG: %[[C1:.*]] = arith.constant 1 : index +// AFTER-LLVM-LOWERING-DAG: %[[C8:.*]] = arith.constant 8 : index +// AFTER-LLVM-LOWERING-DAG: %[[VSCALE:.*]] = vector.vscale +// AFTER-LLVM-LOWERING-DAG: %[[SVL_H:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index +// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_H]], %[[SVL_H]]) +// AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref +// +// AFTER-LLVM-LOWERING-NOT: scf.for +// AFTER-LLVM-LOWERING: arm_sme.intr.zero +// +// AFTER-LLVM-LOWERING-NOT: scf.for +// AFTER-LLVM-LOWERING: arm_sme.intr.zero +// +// AFTER-LLVM-LOWERING: scf.for +// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] { +// AFTER-LLVM-LOWERING: arm_sme.intr.read.horiz +// AFTER-LLVM-LOWERING-NEXT: arm_sme.intr.ld1h.horiz +// AFTER-LLVM-LOWERING-NEXT: vector.store +// AFTER-LLVM-LOWERING-NEXT: } +// AFTER-LLVM-LOWERING: arm_sme.intr.zero +// AFTER-LLVM-LOWERING: scf.for +// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] { +// AFTER-LLVM-LOWERING: arm_sme.intr.read.horiz +// AFTER-LLVM-LOWERING-NEXT: arm_sme.intr.ld1h.horiz +// AFTER-LLVM-LOWERING-NEXT: vector.store +// AFTER-LLVM-LOWERING-NEXT: } diff --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/tile-allocation.mlir index 1f895e4984ba8..7c887ced160b1 100644 --- a/mlir/test/Dialect/ArmSME/tile-allocation.mlir +++ b/mlir/test/Dialect/ArmSME/tile-allocation.mlir @@ -35,7 +35,7 @@ func.func @za_b() { func.func @za_b__out_of_tiles() { %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} %next_tile = arm_sme.get_tile : vector<[16]x[16]xi8> return } @@ -44,7 +44,7 @@ func.func @za_b__out_of_tiles() { func.func @za_b_overlapping_za_q() { %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128> return } @@ -79,7 +79,7 @@ func.func @za_h__out_of_tiles() { %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16> // CHECK-NEXT: tile_id = 1 %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} %next_tile = arm_sme.get_tile : vector<[8]x[8]xi16> return } @@ -136,7 +136,7 @@ func.func @za_h_overlapping_za_q() { %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128> return } @@ -174,7 +174,7 @@ func.func @za_s__out_of_tiles() { %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32> %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32> %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} %next_tile = arm_sme.get_tile : vector<[4]x[4]xi32> return } @@ -218,7 +218,7 @@ func.func @za_s_overlapping_za_q() { %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128> return } @@ -268,7 +268,7 @@ func.func @za_d__out_of_tiles() { %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64> %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64> %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} %next_tile = arm_sme.get_tile : vector<[2]x[2]xi64> return } @@ -291,7 +291,7 @@ func.func @za_d_overlapping_za_q() { %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128> return } @@ -365,7 +365,7 @@ func.func @za_q__out_of_tiles() { %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128> return } diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir new file mode 100644 index 0000000000000..ea48fa77861cf --- /dev/null +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir @@ -0,0 +1,78 @@ + +// RUN: mlir-opt %s \ +// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles \ +// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ +// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \ +// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \ +// RUN: -convert-arm-sme-to-llvm -convert-vector-to-llvm=enable-arm-sve -cse \ +// RUN: -canonicalize -test-lower-to-llvm -verify-diagnostics | \ +// RUN: %mcr_aarch64_cmd \ +// RUN: -e=main -entry-point-result=void \ +// RUN: -march=aarch64 -mattr="+sve,+sme" \ +// RUN: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \ +// RUN: FileCheck %s + +/// This function uses too many tiles! There's only two i16 tiles (ZA0.H and +/// ZA1.H), but this function uses five i16 tiles! Very expensive spills/reloads +/// will be inserted to emulate the extra three tiles. Note: This is only done +/// to avoid the compiler erroring out but is expected to have very poor +/// performance (hence the warning). +func.func @use_too_many_tiles(%a: memref, %b: memref, %c: memref) { + %c0 = arith.constant 0 : index + %tile_a = arith.constant dense<0> : vector<[8]x[8]xi16> + %tile_b = arith.constant dense<1> : vector<[8]x[8]xi16> + // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + %tile_c = arm_sme.tile_load %a[%c0, %c0] : memref, vector<[8]x[8]xi16> + // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + %tile_d = arm_sme.tile_load %b[%c0, %c0] : memref, vector<[8]x[8]xi16> + // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + %tile_e = arm_sme.tile_load %c[%c0, %c0] : memref, vector<[8]x[8]xi16> + + // CHECK-LABEL: tile_a: + // CHECK-COUNT-8: ( 0, 0, 0, 0, 0, 0, 0, 0 + vector.print str "tile_a:" + vector.print %tile_a : vector<[8]x[8]xi16> + // CHECK-LABEL: tile_b: + // CHECK-COUNT-8: ( 1, 1, 1, 1, 1, 1, 1, 1 + vector.print str "tile_b:" + vector.print %tile_b : vector<[8]x[8]xi16> + // CHECK-LABEL: tile_c: + // CHECK-COUNT-8: ( 2, 2, 2, 2, 2, 2, 2, 2 + vector.print str "tile_c:" + vector.print %tile_c : vector<[8]x[8]xi16> + // CHECK-LABEL: tile_d: + // CHECK-COUNT-8: ( 3, 3, 3, 3, 3, 3, 3, 3 + vector.print str "tile_d:" + vector.print %tile_d : vector<[8]x[8]xi16> + // CHECK-LABEL: tile_e: + // CHECK-COUNT-8: ( 4, 4, 4, 4, 4, 4, 4, 4 + vector.print str "tile_e:" + vector.print %tile_e : vector<[8]x[8]xi16> + return +} + +func.func @get_svl() -> index attributes { enable_arm_streaming_ignore, arm_locally_streaming }{ + %vscale = vector.vscale + return %vscale : index +} + +func.func @main() { + %c16 = arith.constant 16 : index + %svl = call @get_svl() : () -> index + %svl_h = arith.muli %c16, %svl : index + + %two = arith.constant 2 : i16 + %three = arith.constant 3 : i16 + %four = arith.constant 4 : i16 + + %memA = memref.alloca(%svl_h, %svl_h) : memref + %memB = memref.alloca(%svl_h, %svl_h) : memref + %memC = memref.alloca(%svl_h, %svl_h) : memref + + linalg.fill ins(%two : i16) outs(%memA : memref) + linalg.fill ins(%three : i16) outs(%memB : memref) + linalg.fill ins(%four : i16) outs(%memC : memref) + + func.call @use_too_many_tiles(%memA, %memB, %memC) : (memref, memref, memref) -> () + return +} From 08634f2b33f9fae27ef814486f22f9b095b67fba Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Thu, 21 Dec 2023 11:59:43 +0000 Subject: [PATCH 02/13] fixups --- .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 3 +- .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 47 +++++++++---------- .../ArmSME/Transforms/TileAllocation.cpp | 5 +- .../ArmSMEToLLVM/tile-spills-and-fills.mlir | 2 +- mlir/test/Dialect/ArmSME/tile-allocation.mlir | 18 +++---- .../Linalg/CPU/ArmSME/use-too-many-tiles.mlir | 18 +++---- 6 files changed, 43 insertions(+), 50 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 238e6e4a079c3..fa090362aad59 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -102,8 +102,7 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> { "Returns the VectorType of the tile used by this operation.", /*returnType=*/"VectorType", /*methodName=*/"getTileType", - /*arguments=*/(ins), - /*methodBody=*/[{}] + /*arguments=*/(ins) > ]; diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 097b27042f17d..d40e52e2d5cd1 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -33,7 +33,7 @@ using namespace mlir; namespace { -static constexpr StringLiteral kInMemoryTileId("arm_sme.in_memory_tile_id"); +static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id"); /// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic. static Operation *createLoadTileSliceIntrinsic( @@ -132,7 +132,7 @@ IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) { return tileId; } -/// Creates a alloca matching the size of tile used by `tileOp`. The alloca is +/// Creates an alloca matching the size of tile used by `tileOp`. The alloca is /// placed in the first block of the function. static memref::AllocaOp createAllocaForTile(RewriterBase &rewriter, Location loc, @@ -140,14 +140,13 @@ createAllocaForTile(RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileOpInterface tileOp) { RewriterBase::InsertionGuard g(rewriter); // Move to the first operation in the function. - rewriter.setInsertionPoint(&func.getBlocks().front().front()); + rewriter.setInsertionPointToStart(&func.getBlocks().front()); // Create an alloca matching the tile size of the `tileOp`. auto vscale = rewriter.create(loc); - auto tileElementType = - llvm::cast(tileOp.getTileType()).getElementType(); + auto tileElementType = tileOp.getTileType().getElementType(); auto memrefType = MemRefType::get( {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType); - auto minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType); + unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType); auto minElementsOp = rewriter.create(loc, minElements); auto vectorLen = rewriter.create(loc, vscale, minElementsOp); @@ -157,10 +156,9 @@ createAllocaForTile(RewriterBase &rewriter, Location loc, } /// Finds or creates an alloca for a spill of a tile. -static memref::AllocaOp -getOrCreateTileMemory(RewriterBase &rewriter, Location loc, - FunctionOpInterface func, - arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) { +static memref::AllocaOp getOrCreateAllocaForTile( + RewriterBase &rewriter, Location loc, FunctionOpInterface func, + arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) { // Find an alloca at the top of the function tagged with a // 'arm_sme.in_memory_tile_id' that matches `tileId`. for (auto &op : func.getBlocks().front()) { @@ -168,7 +166,7 @@ getOrCreateTileMemory(RewriterBase &rewriter, Location loc, if (!alloca) continue; auto inMemoryTileId = llvm::dyn_cast_or_null( - alloca->getDiscardableAttr(kInMemoryTileId)); + alloca->getDiscardableAttr(kInMemoryTileIdAttr)); if (!inMemoryTileId) continue; if (inMemoryTileId.getInt() == tileId) @@ -176,7 +174,7 @@ getOrCreateTileMemory(RewriterBase &rewriter, Location loc, } // Otherwise, create a new alloca: auto alloca = createAllocaForTile(rewriter, loc, func, tileOp); - alloca->setDiscardableAttr(kInMemoryTileId, + alloca->setDiscardableAttr(kInMemoryTileIdAttr, rewriter.getI32IntegerAttr(tileId)); return alloca; } @@ -188,23 +186,24 @@ getOrCreateTileMemory(RewriterBase &rewriter, Location loc, /// /// Example: /// +/// // Note: = tile ID >= 16. /// arm_sme.tile_op { tile_id = } /// /// is converted to: /// // At function entry: -/// %alloca = memref.alloca ... : memref +/// %spill = memref.alloca ... : memref /// /// // Around op: /// scf.for %slice_idx { /// %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}> -/// "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}> -/// vector.store %current_slice, %alloca[%slice_idx, %c0] +/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}> +/// vector.store %current_slice, %spill[%slice_idx, %c0] /// } /// arm_sme.tile_op { tile_id = 0 } /// scf.for %slice_idx { /// %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}> -/// "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}> -/// vector.store %current_slice, %alloca[%slice_idx, %c0] +/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}> +/// vector.store %current_slice, %spill[%slice_idx, %c0] /// } /// /// Note that these spills/fills are not inserted earlier as concept of a @@ -232,20 +231,16 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { // does not already exist). auto loc = tileOp.getLoc(); auto func = tileOp->getParentOfType(); - auto tileAlloca = getOrCreateTileMemory(rewriter, loc, func, tileOp, - tileOp.getTileId().getInt()); + auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp, + tileOp.getTileId().getInt()); // Step 2. Assign the op a real tile ID. - // For simplicity, we always use tile 0. + // For simplicity, we always use tile 0 (which always exists). auto zeroTileId = rewriter.getI32IntegerAttr(0); - { - rewriter.startRootUpdate(tileOp); - tileOp.setTileId(zeroTileId); - rewriter.finalizeRootUpdate(tileOp); - } + rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); }); VectorType tileVectorType = tileOp.getTileType(); - auto sliceType = VectorType::Builder(tileOp.getTileType()).dropDim(0); + auto sliceType = VectorType::Builder(tileVectorType).dropDim(0); auto emitTileSwap = [&] { emitFullTileSwap(rewriter, loc, tileAlloca, *arm_sme::getSMETileType(tileVectorType), sliceType, diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp index a77b218bc1a60..3c089d47d2860 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp @@ -241,9 +241,8 @@ struct AssignTileIDsPattern tileId = getDiscardableIntAttr(kNextTileMemoryIndex, kInMemoryTileIdBase); setDiscardableIntAttr(kNextTileMemoryIndex, *tileId + 1); tileOp->emitWarning( - "failed to allocate physical tile to operation, all tile " - "operations will go through memory, expect " - "performance degradation"); + "failed to allocate SME virtual tile to operation, all tile " + "operations will go through memory, expect degraded performance"); } // Set all operations dependent on `tileOp` to use the same tile ID. diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir index 9908f04b7c855..999acbfc66bef 100644 --- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir +++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir @@ -54,7 +54,7 @@ func.func @use_too_many_tiles() { %0 = arm_sme.zero : vector<[4]x[4]xi32> %1 = arm_sme.zero : vector<[4]x[4]xi32> - // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %2 = arm_sme.zero : vector<[8]x[8]xi16> return } diff --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/tile-allocation.mlir index 7c887ced160b1..9c368dd4fa23f 100644 --- a/mlir/test/Dialect/ArmSME/tile-allocation.mlir +++ b/mlir/test/Dialect/ArmSME/tile-allocation.mlir @@ -35,7 +35,7 @@ func.func @za_b() { func.func @za_b__out_of_tiles() { %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8> - // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[16]x[16]xi8> return } @@ -44,7 +44,7 @@ func.func @za_b__out_of_tiles() { func.func @za_b_overlapping_za_q() { %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8> - // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128> return } @@ -79,7 +79,7 @@ func.func @za_h__out_of_tiles() { %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16> // CHECK-NEXT: tile_id = 1 %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16> - // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[8]x[8]xi16> return } @@ -136,7 +136,7 @@ func.func @za_h_overlapping_za_q() { %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128> - // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128> return } @@ -174,7 +174,7 @@ func.func @za_s__out_of_tiles() { %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32> %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32> %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32> - // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[4]x[4]xi32> return } @@ -218,7 +218,7 @@ func.func @za_s_overlapping_za_q() { %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128> - // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128> return } @@ -268,7 +268,7 @@ func.func @za_d__out_of_tiles() { %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64> %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64> %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64> - // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[2]x[2]xi64> return } @@ -291,7 +291,7 @@ func.func @za_d_overlapping_za_q() { %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128> - // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128> return } @@ -365,7 +365,7 @@ func.func @za_q__out_of_tiles() { %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128> - // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128> return } diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir index ea48fa77861cf..ef5a874268751 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir @@ -21,11 +21,11 @@ func.func @use_too_many_tiles(%a: memref, %b: memref, %c: mem %c0 = arith.constant 0 : index %tile_a = arith.constant dense<0> : vector<[8]x[8]xi16> %tile_b = arith.constant dense<1> : vector<[8]x[8]xi16> - // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %tile_c = arm_sme.tile_load %a[%c0, %c0] : memref, vector<[8]x[8]xi16> - // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %tile_d = arm_sme.tile_load %b[%c0, %c0] : memref, vector<[8]x[8]xi16> - // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %tile_e = arm_sme.tile_load %c[%c0, %c0] : memref, vector<[8]x[8]xi16> // CHECK-LABEL: tile_a: @@ -61,17 +61,17 @@ func.func @main() { %svl = call @get_svl() : () -> index %svl_h = arith.muli %c16, %svl : index - %two = arith.constant 2 : i16 - %three = arith.constant 3 : i16 - %four = arith.constant 4 : i16 + %c2 = arith.constant 2 : i16 + %c3 = arith.constant 3 : i16 + %c4 = arith.constant 4 : i16 %memA = memref.alloca(%svl_h, %svl_h) : memref %memB = memref.alloca(%svl_h, %svl_h) : memref %memC = memref.alloca(%svl_h, %svl_h) : memref - linalg.fill ins(%two : i16) outs(%memA : memref) - linalg.fill ins(%three : i16) outs(%memB : memref) - linalg.fill ins(%four : i16) outs(%memC : memref) + linalg.fill ins(%c2 : i16) outs(%memA : memref) + linalg.fill ins(%c3 : i16) outs(%memB : memref) + linalg.fill ins(%c4 : i16) outs(%memC : memref) func.call @use_too_many_tiles(%memA, %memB, %memC) : (memref, memref, memref) -> () return From ab41d905208df872e0fc007aed717728b25dd91b Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Thu, 21 Dec 2023 14:21:50 +0000 Subject: [PATCH 03/13] fixups - Show alloca usage in tests - Add test showing some very excessive spills - Document a possible API to reduce spills --- .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 21 +++++- .../ArmSMEToLLVM/tile-spills-and-fills.mlir | 73 +++++++++++++++++-- 2 files changed, 85 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index d40e52e2d5cd1..861f2d2a26201 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -210,7 +210,26 @@ static memref::AllocaOp getOrCreateAllocaForTile( /// register, and the need to swap the contents, can't really be represented /// correctly at a high level in MLIR. /// -/// TODO: Reduce the spills/reloads to single slices where possible. +/// TODO: Reduce the spills/reloads to single slices where possible (and omit +/// redundant reloads). This could be done via a method on the +/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.: +/// +/// `tileOp.getZaUsage()` could return: +/// +/// struct ArmSMEOpZAUsage { +/// enum class Kind { +/// TileRead, // Omit store after tile operation. +/// TileWrite, // Omit load before tile operation. +/// TileReadWrite, // Needs both tile load and store. +/// SliceRead, // Spill single slice and omit store after operation. +/// SliceWrite, // Spill single slice and omit load before operation. +/// SliceReadWrite // Spill single slice. +/// }; +/// Value sliceIndex {}; +/// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal }; +/// }; +/// +} struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName, diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir index 999acbfc66bef..ffa249f998601 100644 --- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir +++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir @@ -4,8 +4,6 @@ // RUN: -split-input-file -verify-diagnostics | \ // RUN: FileCheck %s --check-prefix=AFTER-LLVM-LOWERING -// ----- - /// Checks tile spill/reloads are inserted around in-memory tiles (i.e. tiles /// that were not assigned a physical SME tile). /// @@ -51,6 +49,10 @@ /// } /// +// ----- + +/// Note: In this example loads into ZA are inserted before the zero instruction. +/// These are obviously redundant, but there's no checks to avoid this. func.func @use_too_many_tiles() { %0 = arm_sme.zero : vector<[4]x[4]xi32> %1 = arm_sme.zero : vector<[4]x[4]xi32> @@ -83,14 +85,69 @@ func.func @use_too_many_tiles() { // // AFTER-LLVM-LOWERING: scf.for // AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] { -// AFTER-LLVM-LOWERING: arm_sme.intr.read.horiz -// AFTER-LLVM-LOWERING-NEXT: arm_sme.intr.ld1h.horiz -// AFTER-LLVM-LOWERING-NEXT: vector.store +// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]] +// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1] +// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]] +// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}> +// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}> +// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]] // AFTER-LLVM-LOWERING-NEXT: } // AFTER-LLVM-LOWERING: arm_sme.intr.zero // AFTER-LLVM-LOWERING: scf.for // AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] { -// AFTER-LLVM-LOWERING: arm_sme.intr.read.horiz -// AFTER-LLVM-LOWERING-NEXT: arm_sme.intr.ld1h.horiz -// AFTER-LLVM-LOWERING-NEXT: vector.store +// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]] +// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1] +// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]] +// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}> +// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}> +// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]] +// AFTER-LLVM-LOWERING-NEXT: } + +// ----- + +/// Note: In this example an entire tile swap is inserted before/after the +/// `arm_sme.load_tile_slice` operation. Really, this only needs to spill a +/// single tile slice (and can omit the initial load, like in the previous example). +func.func @very_excessive_spills(%memref : memref) -> vector<[4]x[4]xf32> { + %useAllTiles = arm_sme.get_tile : vector<[16]x[16]xi8> + %c0 = arith.constant 0 : index + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} + %tile = arm_sme.get_tile : vector<[4]x[4]xf32> + %mask = vector.constant_mask [4] : vector<[4]xi1> + %loadSlice = arm_sme.load_tile_slice %memref[%c0, %c0], %mask, %tile, %c0 : memref, vector<[4]xi1>, vector<[4]x[4]xf32> + return %loadSlice : vector<[4]x[4]xf32> +} +// AFTER-TILE-ALLOC-LABEL: @very_excessive_spills +// AFTER-TILE-ALLOC: arm_sme.get_tile +// AFTER-TILE-ALLOC-SAME: tile_id = 0 +// AFTER-TILE-ALLOC: arm_sme.load_tile_slice +// AFTER-TILE-ALLOC-SAME: tile_id = 16 + +// AFTER-LLVM-LOWERING-LABEL: @very_excessive_spills +// AFTER-LLVM-LOWERING-DAG: %[[C0:.*]] = arith.constant 0 : index +// AFTER-LLVM-LOWERING-DAG: %[[C1:.*]] = arith.constant 1 : index +// AFTER-LLVM-LOWERING-DAG: %[[C4:.*]] = arith.constant 4 : index +// AFTER-LLVM-LOWERING-DAG: %[[VSCALE:.*]] = vector.vscale +// AFTER-LLVM-LOWERING-DAG: %[[SVL_S:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index +// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_S]], %[[SVL_S]]) +// AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref +// +// AFTER-LLVM-LOWERING: scf.for +// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] { +// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]] +// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1] +// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]] +// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}> +// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1w.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}> +// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]] +// AFTER-LLVM-LOWERING-NEXT: } +// AFTER-LLVM-LOWERING: "arm_sme.intr.ld1w.horiz"{{.*}} <{tile_id = 0 : i32}> +// AFTER-LLVM-LOWERING: scf.for +// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] { +// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]] +// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1] +// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]] +// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}> +// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1w.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}> +// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]] // AFTER-LLVM-LOWERING-NEXT: } From 5d2879b76ad307865130a5f3285c9b2e30391fd1 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Thu, 21 Dec 2023 14:25:13 +0000 Subject: [PATCH 04/13] Remove newline --- .../Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir index ef5a874268751..fe125c9f3cf16 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir @@ -1,4 +1,3 @@ - // RUN: mlir-opt %s \ // RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles \ // RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ From 35fb2768e69a1d45cc93ec9554d72723cb6fd617 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Thu, 21 Dec 2023 14:45:03 +0000 Subject: [PATCH 05/13] Fix build error --- mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 861f2d2a26201..9e6e074978707 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -229,7 +229,6 @@ static memref::AllocaOp getOrCreateAllocaForTile( /// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal }; /// }; /// -} struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName, From 2d1a9404aca1f9a0ec7e56738fcd0f881e54c44e Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Thu, 21 Dec 2023 16:08:31 +0000 Subject: [PATCH 06/13] fixups --- mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 3 +-- mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 15 +++++++-------- .../Dialect/ArmSME/Transforms/TileAllocation.cpp | 10 +++++----- .../ArmSMEToLLVM/tile-spills-and-fills.mlir | 2 +- 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index fa090362aad59..8a34ad7e52012 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -101,8 +101,7 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> { InterfaceMethod< "Returns the VectorType of the tile used by this operation.", /*returnType=*/"VectorType", - /*methodName=*/"getTileType", - /*arguments=*/(ins) + /*methodName=*/"getTileType" > ]; diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 9e6e074978707..920858af7b5ba 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -195,15 +195,15 @@ static memref::AllocaOp getOrCreateAllocaForTile( /// /// // Around op: /// scf.for %slice_idx { -/// %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}> +/// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}> /// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}> -/// vector.store %current_slice, %spill[%slice_idx, %c0] +/// vector.store %slice_to_save, %spill[%slice_idx, %c0] /// } /// arm_sme.tile_op { tile_id = 0 } /// scf.for %slice_idx { -/// %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}> +/// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}> /// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}> -/// vector.store %current_slice, %spill[%slice_idx, %c0] +/// vector.store %slice_to_save, %spill[%slice_idx, %c0] /// } /// /// Note that these spills/fills are not inserted earlier as concept of a @@ -307,15 +307,14 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { auto predicateType = sliceType.clone(rewriter.getI1Type()); auto allTruePredicate = rewriter.create( loc, DenseElementsAttr::get(predicateType, true)); - // Create zero padding vector (never used due to all-true predicate). - auto zeroVector = rewriter.create( - loc, sliceType, rewriter.getZeroAttr(sliceType)); + // Create padding vector (never used due to all-true predicate). + auto padVector = rewriter.create(loc, sliceType); // Get a pointer to the current slice. auto slicePtr = getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex); // Read the value of the current slice from ZA. auto currentTileSlice = rewriter.create( - loc, sliceType, zeroVector, allTruePredicate, tileId, sliceIndexI32); + loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32); // Load the new tile slice back from memory into ZA. createLoadTileSliceIntrinsic( rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal, diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp index 3c089d47d2860..51a85f516319f 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp @@ -63,7 +63,7 @@ namespace { static constexpr StringLiteral kTilesInUseAttr("arm_sme.tiles_in_use"); static constexpr StringLiteral - kNextTileMemoryIndex("arm_sme.next_in_memory_tile_id"); + kNextInMemoryTileIdAttr("arm_sme.next_in_memory_tile_id"); enum class TileMask : unsigned { // clang-format off @@ -207,17 +207,16 @@ struct AssignTileIDsPattern using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp, PatternRewriter &rewriter) const override { - auto func = tileOp->getParentOfType(); if (tileOp.getTileId()) return failure(); + auto func = tileOp->getParentOfType(); auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) { if (auto attr = llvm::dyn_cast_or_null( func->getDiscardableAttr(name))) return unsigned(attr.getInt()); return defaultVal; }; - auto setDiscardableIntAttr = [&](StringRef name, auto value) { rewriter.updateRootInPlace(tileOp, [&] { func->setDiscardableAttr(name, @@ -238,8 +237,9 @@ struct AssignTileIDsPattern else { // If we could not find a real tile, set use a virtual tile ID (ID >= 16). // A later pass will insert the necessary spills and reloads. - tileId = getDiscardableIntAttr(kNextTileMemoryIndex, kInMemoryTileIdBase); - setDiscardableIntAttr(kNextTileMemoryIndex, *tileId + 1); + tileId = + getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase); + setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1); tileOp->emitWarning( "failed to allocate SME virtual tile to operation, all tile " "operations will go through memory, expect degraded performance"); diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir index ffa249f998601..6a4ac2dfb05bb 100644 --- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir +++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir @@ -23,7 +23,7 @@ /// /// arm_sme.zero { tile_id = 16 } : vector<[8]x[8]xi16> /// -/// This works like normal till the final lowering to LLVM, where spills and +/// This works like normal until the final lowering to LLVM, where spills and /// reloads will be inserted around uses of in-memory tiles. /// /// So the above example becomes: From c0dce2aea70a2e8bbf21bdb5e483c30704265064 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 22 Dec 2023 10:52:28 +0000 Subject: [PATCH 07/13] Make adding spills/fills conversions implict This switches the ArmSME -> LLVM conversion patterns to use a new `ConvertArmSMEOpToLLVMPattern` base class. Using this implicitly adds the spills/fills conversion pattern, unless the `requiresSpillsAndFills` template parameteris set to `RequiresSpillsAndFills::No`. --- .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 108 +++++++++++------- 1 file changed, 67 insertions(+), 41 deletions(-) diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 920858af7b5ba..07bacffc4b838 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -240,9 +240,9 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto tileOp = cast(op); + auto tileOp = dyn_cast(op); // Tile has a real (hardware) tile. No spills/reloads required. - if (!tileOp.isInMemoryTile()) + if (!tileOp || !tileOp.isInMemoryTile()) return failure(); // Step 1. Create an alloca for the tile at the top of the function (if one @@ -347,8 +347,29 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { } }; -struct GetTileConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +enum class RequiresSpillsAndFills { Yes, No }; + +/// Base class for ArmSME to LLVM conversion patterns. By default, this adds +/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be +/// disabled by setting the `requiresSpillsAndFills` template parameter to +/// `RequiresSpillsAndFills::No`. +template +struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern { + using ArmSMEOp = SourceOp; + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + static constexpr bool requiresSpillsAndFillsConversion() { + return requiresSpillsAndFills == RequiresSpillsAndFills::Yes; + } +}; + +struct GetTileConversion + : public ConvertArmSMEOpToLLVMPattern { + using ConvertArmSMEOpToLLVMPattern< + arm_sme::GetTileOp, + RequiresSpillsAndFills::No>::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor, @@ -374,8 +395,9 @@ struct GetTileConversion : public ConvertOpToLLVMPattern { /// /// The 'arm_sme.materialize_ssa_tile' (which models the return) will fold away /// once all ArmSME ops have been converted to LLVM intrinsics. -struct ZeroOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern { + using ConvertArmSMEOpToLLVMPattern< + arm_sme::ZeroOp>::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor, @@ -451,9 +473,9 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern { /// Lower `arm_sme.load_tile_slice` to SME intrinsics. struct LoadTileSliceConversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern< - arm_sme::LoadTileSliceOp>::ConvertOpToLLVMPattern; + : public ConvertArmSMEOpToLLVMPattern { + using ConvertArmSMEOpToLLVMPattern< + arm_sme::LoadTileSliceOp>::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp, @@ -495,9 +517,9 @@ struct LoadTileSliceConversion /// Lower for `arm_sme.store_tile_slice` to SME intrinsics. struct StoreTileSliceConversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern< - arm_sme::StoreTileSliceOp>::ConvertOpToLLVMPattern; + : public ConvertArmSMEOpToLLVMPattern { + using ConvertArmSMEOpToLLVMPattern< + arm_sme::StoreTileSliceOp>::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp, @@ -537,9 +559,9 @@ struct StoreTileSliceConversion /// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. struct MoveVectorToTileSliceConversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern< - arm_sme::MoveVectorToTileSliceOp>::ConvertOpToLLVMPattern; + : public ConvertArmSMEOpToLLVMPattern { + using ConvertArmSMEOpToLLVMPattern< + arm_sme::MoveVectorToTileSliceOp>::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::MoveVectorToTileSliceOp moveVectorToTileSliceOp, @@ -591,9 +613,9 @@ struct MoveVectorToTileSliceConversion /// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. struct MoveTileSliceToVectorConversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern< - arm_sme::MoveTileSliceToVectorOp>::ConvertOpToLLVMPattern; + : public ConvertArmSMEOpToLLVMPattern { + using ConvertArmSMEOpToLLVMPattern< + arm_sme::MoveTileSliceToVectorOp>::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::MoveTileSliceToVectorOp moveTileSliceToVector, @@ -653,8 +675,9 @@ struct MoveTileSliceToVectorConversion /// /// Currently only supports FMOPA and BFMOPA (non-widening). struct OuterProductOpConversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + : public ConvertArmSMEOpToLLVMPattern { + using ConvertArmSMEOpToLLVMPattern< + arm_sme::OuterProductOp>::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::OuterProductOp outerProductOp, @@ -748,8 +771,9 @@ struct OuterProductOpConversion /// %0 = arith.index_cast %cnt : i64 to index /// struct StreamingVLOpConversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + : public ConvertArmSMEOpToLLVMPattern { + using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp, @@ -794,15 +818,22 @@ struct ConvertArmSMEToLLVMPass } }; -template -static void addSpillAndFillsForTileOp(RewritePatternSet &patterns, - LLVMTypeConverter const &typeConverter) { - // Add spill/fill conversions with a very high benefit to ensure they are - // lowered first. - (patterns.add(TileOp::getOperationName(), - typeConverter, - /*benefit=*/1337), - ...); +template +static void +addArmSMEConversionPatterns(RewritePatternSet &patterns, + LLVMTypeConverter const &typeConverter) { + ( + [&] { + if (Pattern::requiresSpillsAndFillsConversion()) { + // Add spill/fill conversions with a very high benefit to ensure they + // are lowered first. + patterns.add( + Pattern::ArmSMEOp::getOperationName(), typeConverter, + /*benefit=*/1337); + } + patterns.add(typeConverter); + }(), + ...); } } // namespace @@ -843,16 +874,11 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, return std::nullopt; }); - // Register ops that need spills/fills. - addSpillAndFillsForTileOp< - arm_sme::LoadTileSliceOp, arm_sme::MoveTileSliceToVectorOp, - arm_sme::MoveVectorToTileSliceOp, arm_sme::StoreTileSliceOp, - arm_sme::OuterProductOp, arm_sme::ZeroOp>(patterns, converter); - - patterns.add(converter); + addArmSMEConversionPatterns< + LoadTileSliceConversion, MoveTileSliceToVectorConversion, + MoveVectorToTileSliceConversion, StoreTileSliceConversion, + OuterProductOpConversion, ZeroOpConversion, GetTileConversion, + StreamingVLOpConversion>(patterns, converter); } std::unique_ptr mlir::createConvertArmSMEToLLVMPass() { From 7117bd525ea464f7d6687ad083fb4ee60202d171 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 22 Dec 2023 11:02:28 +0000 Subject: [PATCH 08/13] fixups --- mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 3 ++- mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 07bacffc4b838..eb65cec6854f5 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -270,7 +270,7 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { // touches (i.e. a single tile slice). { rewriter.setInsertionPoint(op); - // Swap the in-memory tile's contents into ZA before the op. + // Swap the contents of ZA and the in-memory tile before the op. emitTileSwap(); rewriter.setInsertionPointAfter(op); // Swap the tile back out to memory again after the op. @@ -818,6 +818,7 @@ struct ConvertArmSMEToLLVMPass } }; +/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns. template static void addArmSMEConversionPatterns(RewritePatternSet &patterns, diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp index 51a85f516319f..49ea6bb5f8614 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp @@ -235,8 +235,8 @@ struct AssignTileIDsPattern if (!tileIsInMemory) setDiscardableIntAttr(kTilesInUseAttr, tilesInUse); else { - // If we could not find a real tile, set use a virtual tile ID (ID >= 16). - // A later pass will insert the necessary spills and reloads. + // If we could not find a real tile ID, use an in-memory tile ID (ID >= + // 16). A later pass will insert the necessary spills and reloads. tileId = getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase); setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1); From 5ffc48e33cc4d843a0f57f779792bd4ab253d9b2 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 22 Dec 2023 11:09:21 +0000 Subject: [PATCH 09/13] Remove space --- mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index eb65cec6854f5..5d85423ab2406 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -349,7 +349,7 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { enum class RequiresSpillsAndFills { Yes, No }; -/// Base class for ArmSME to LLVM conversion patterns. By default, this adds +/// Base class for ArmSME to LLVM conversion patterns. By default, this adds /// spills and fills around ArmSME ops that use in-memory tile IDs. This can be /// disabled by setting the `requiresSpillsAndFills` template parameter to /// `RequiresSpillsAndFills::No`. From df4e2f716a8fd91af5b94ab51fb88b18c6d34e6e Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 22 Dec 2023 12:07:43 +0000 Subject: [PATCH 10/13] Move helper --- .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 48 +++++++++++-------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 5d85423ab2406..4d4b1e4faa238 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -240,9 +240,9 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto tileOp = dyn_cast(op); + auto tileOp = cast(op); // Tile has a real (hardware) tile. No spills/reloads required. - if (!tileOp || !tileOp.isInMemoryTile()) + if (!tileOp.isInMemoryTile()) return failure(); // Step 1. Create an alloca for the tile at the top of the function (if one @@ -364,6 +364,31 @@ struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern { } }; +/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns. +template +static void +addArmSMEConversionPatterns(RewritePatternSet &patterns, + LLVMTypeConverter const &typeConverter) { + ( + [&] { + // Register spills/fills for ops that implement the + // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to + // `RequiresSpillsAndFills::Yes`. + if constexpr (Pattern::requiresSpillsAndFillsConversion() && + std::is_base_of_v, + typename Pattern::ArmSMEOp>) { + // Add spill/fill conversions with a very high benefit to ensure + // they are lowered first. + patterns.add( + Pattern::ArmSMEOp::getOperationName(), typeConverter, + /*benefit=*/1337); + } + patterns.add(typeConverter); + }(), + ...); +} + struct GetTileConversion : public ConvertArmSMEOpToLLVMPattern { @@ -818,25 +843,6 @@ struct ConvertArmSMEToLLVMPass } }; -/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns. -template -static void -addArmSMEConversionPatterns(RewritePatternSet &patterns, - LLVMTypeConverter const &typeConverter) { - ( - [&] { - if (Pattern::requiresSpillsAndFillsConversion()) { - // Add spill/fill conversions with a very high benefit to ensure they - // are lowered first. - patterns.add( - Pattern::ArmSMEOp::getOperationName(), typeConverter, - /*benefit=*/1337); - } - patterns.add(typeConverter); - }(), - ...); -} - } // namespace void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) { From a543dec3da9aa9cb7ed8ea69e34812db8e3a1f9c Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 22 Dec 2023 12:11:30 +0000 Subject: [PATCH 11/13] Remove redundant template parameters --- .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 22 ++++++------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 4d4b1e4faa238..6687f4eae22c0 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -392,9 +392,7 @@ addArmSMEConversionPatterns(RewritePatternSet &patterns, struct GetTileConversion : public ConvertArmSMEOpToLLVMPattern { - using ConvertArmSMEOpToLLVMPattern< - arm_sme::GetTileOp, - RequiresSpillsAndFills::No>::ConvertArmSMEOpToLLVMPattern; + using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor, @@ -421,8 +419,7 @@ struct GetTileConversion /// The 'arm_sme.materialize_ssa_tile' (which models the return) will fold away /// once all ArmSME ops have been converted to LLVM intrinsics. struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern { - using ConvertArmSMEOpToLLVMPattern< - arm_sme::ZeroOp>::ConvertArmSMEOpToLLVMPattern; + using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor, @@ -499,8 +496,7 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern { /// Lower `arm_sme.load_tile_slice` to SME intrinsics. struct LoadTileSliceConversion : public ConvertArmSMEOpToLLVMPattern { - using ConvertArmSMEOpToLLVMPattern< - arm_sme::LoadTileSliceOp>::ConvertArmSMEOpToLLVMPattern; + using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp, @@ -543,8 +539,7 @@ struct LoadTileSliceConversion /// Lower for `arm_sme.store_tile_slice` to SME intrinsics. struct StoreTileSliceConversion : public ConvertArmSMEOpToLLVMPattern { - using ConvertArmSMEOpToLLVMPattern< - arm_sme::StoreTileSliceOp>::ConvertArmSMEOpToLLVMPattern; + using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp, @@ -585,8 +580,7 @@ struct StoreTileSliceConversion /// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. struct MoveVectorToTileSliceConversion : public ConvertArmSMEOpToLLVMPattern { - using ConvertArmSMEOpToLLVMPattern< - arm_sme::MoveVectorToTileSliceOp>::ConvertArmSMEOpToLLVMPattern; + using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::MoveVectorToTileSliceOp moveVectorToTileSliceOp, @@ -639,8 +633,7 @@ struct MoveVectorToTileSliceConversion /// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. struct MoveTileSliceToVectorConversion : public ConvertArmSMEOpToLLVMPattern { - using ConvertArmSMEOpToLLVMPattern< - arm_sme::MoveTileSliceToVectorOp>::ConvertArmSMEOpToLLVMPattern; + using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::MoveTileSliceToVectorOp moveTileSliceToVector, @@ -701,8 +694,7 @@ struct MoveTileSliceToVectorConversion /// Currently only supports FMOPA and BFMOPA (non-widening). struct OuterProductOpConversion : public ConvertArmSMEOpToLLVMPattern { - using ConvertArmSMEOpToLLVMPattern< - arm_sme::OuterProductOp>::ConvertArmSMEOpToLLVMPattern; + using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; LogicalResult matchAndRewrite(arm_sme::OuterProductOp outerProductOp, From 607170f3b2d2316451c6951fcf4b6b8a9988ca92 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 10 Jan 2024 10:49:51 +0000 Subject: [PATCH 12/13] Fixup --- .../Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir index fe125c9f3cf16..dd9f280cb7509 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir @@ -50,15 +50,9 @@ func.func @use_too_many_tiles(%a: memref, %b: memref, %c: mem return } -func.func @get_svl() -> index attributes { enable_arm_streaming_ignore, arm_locally_streaming }{ - %vscale = vector.vscale - return %vscale : index -} - func.func @main() { %c16 = arith.constant 16 : index - %svl = call @get_svl() : () -> index - %svl_h = arith.muli %c16, %svl : index + %svl_h = arm_sme.streaming_vl %c2 = arith.constant 2 : i16 %c3 = arith.constant 3 : i16 From 1373bbe89e07cbffbd546a049ddb67751fbdb67b Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 10 Jan 2024 14:12:11 +0000 Subject: [PATCH 13/13] Fixups --- .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 6 +++--- .../ArmSMEToLLVM/tile-spills-and-fills.mlir | 19 ++++++++++--------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 6687f4eae22c0..f78b06776606e 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -259,7 +259,7 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { VectorType tileVectorType = tileOp.getTileType(); auto sliceType = VectorType::Builder(tileVectorType).dropDim(0); - auto emitTileSwap = [&] { + auto swapInMemoryTileWithSMETileZero = [&] { emitFullTileSwap(rewriter, loc, tileAlloca, *arm_sme::getSMETileType(tileVectorType), sliceType, zeroTileId); @@ -271,10 +271,10 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { { rewriter.setInsertionPoint(op); // Swap the contents of ZA and the in-memory tile before the op. - emitTileSwap(); + swapInMemoryTileWithSMETileZero(); rewriter.setInsertionPointAfter(op); // Swap the tile back out to memory again after the op. - emitTileSwap(); + swapInMemoryTileWithSMETileZero(); } return success(); diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir index 6a4ac2dfb05bb..7a9e6b4215754 100644 --- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir +++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir @@ -7,8 +7,8 @@ /// Checks tile spill/reloads are inserted around in-memory tiles (i.e. tiles /// that were not assigned a physical SME tile). /// -/// These spills are currently very naive and paranoid and will spill/reload -/// entire tiles around ArmSME ops. +/// These spills are currently very naive and will spill/reload entire tiles +/// around ArmSME ops. /// /// The general pattern is: /// @@ -34,19 +34,17 @@ /// Then around the op: /// /// // Swap contents of %tileAlloca and tile 0 -/// scf.for %sliceIdx ... { +/// scf.for %sliceIdx ... /// %currentSlice = arm_sme.intr.read.horiz {tile_id = 0} /// arm_sme.intr.ld1h.horiz %tileAlloca[%sliceIdx, %c0] {tile_id = 0} /// vector.store %currentSlice, %tileAlloca[%sliceIdx, %c0] -/// } /// // Execute the op using tile 0 /// arm_sme.intr.zero /// // Swap contents of %tileAlloca and tile 0 -/// scf.for %sliceIdx ... { +/// scf.for %sliceIdx ... /// %currentSlice = arm_sme.intr.read.horiz {tile_id = 0} /// arm_sme.intr.ld1h.horiz %tileAlloca[%sliceIdx, %c0] {tile_id = 0} /// vector.store %currentSlice, %tileAlloca[%sliceIdx, %c0] -/// } /// // ----- @@ -78,10 +76,12 @@ func.func @use_too_many_tiles() { // AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref // // AFTER-LLVM-LOWERING-NOT: scf.for -// AFTER-LLVM-LOWERING: arm_sme.intr.zero +// Note: 17 is the mask for the 32-bit tile 0. +// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> // // AFTER-LLVM-LOWERING-NOT: scf.for -// AFTER-LLVM-LOWERING: arm_sme.intr.zero +// Note: 34 is the mask for the 32-bit tile 1. +// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 34 : i32}> // // AFTER-LLVM-LOWERING: scf.for // AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] { @@ -92,7 +92,8 @@ func.func @use_too_many_tiles() { // AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}> // AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]] // AFTER-LLVM-LOWERING-NEXT: } -// AFTER-LLVM-LOWERING: arm_sme.intr.zero +// Note: 85 is the mask for the 16-bit tile 0. +// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> // AFTER-LLVM-LOWERING: scf.for // AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] { // AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]