diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index 9982d4278b603..c507cea5357a7 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -25,8 +25,9 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 namespace mlir::arm_sme {
+static constexpr unsigned kInMemoryTileIdBase = 16;
 #include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
-}
+} // namespace mlir::arm_sme
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index bb0db59add009..8a34ad7e52012 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -97,6 +97,11 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
         // This operation does not allocate a tile.
         return std::nullopt;
       }]
+    >,
+    InterfaceMethod<
+      "Returns the VectorType of the tile used by this operation.",
+      /*returnType=*/"VectorType",
+      /*methodName=*/"getTileType"
     >
   ];
@@ -117,6 +122,11 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
       rewriter.replaceOp($_op, newOp);
       return newOp;
     }
+
+    bool isInMemoryTile() {
+      auto tileId = getTileId();
+      return tileId && tileId.getInt() >= kInMemoryTileIdBase;
+    }
   }];
 
   let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
@@ -331,6 +341,9 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
     std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
      return arm_sme::getSMETileType(getVectorType());
    }
+    VectorType getTileType() {
+      return getVectorType();
+    }
  }];
  let assemblyFormat = "attr-dict `:` type($res)";
}
@@ -407,6 +420,9 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
      return arm_sme::getSMETileType(getVectorType());
    }
+    VectorType getTileType() {
+      return getVectorType();
+    }
  }];
 
  let builders = [
@@ -475,6 +491,9 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
    VectorType getVectorType() {
      return ::llvm::cast<VectorType>(getValueToStore().getType());
    }
+    VectorType getTileType() {
+      return getVectorType();
+    }
  }];
 
  let builders = [
@@ -539,6 +558,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
    VectorType getVectorType() {
      return ::llvm::cast<VectorType>(getResult().getType());
    }
+    VectorType getTileType() {
+      return getVectorType();
+    }
  }];
 
  let assemblyFormat = [{
@@ -596,6 +618,9 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
    VectorType getVectorType() {
      return ::llvm::cast<VectorType>(getTile().getType());
    }
+    VectorType getTileType() {
+      return getVectorType();
+    }
  }];
 
  let assemblyFormat = [{
@@ -688,6 +713,9 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
 
  let extraClassDeclaration = [{
    VectorType getSliceType() { return getResult().getType(); }
+    VectorType getTileType() {
+      return ::llvm::cast<VectorType>(getTile().getType());
+    }
  }];
 
  let assemblyFormat = [{
@@ -780,6 +808,9 @@ let arguments = (ins
        return arm_sme::getSMETileType(getResultType());
      return std::nullopt;
    }
+    VectorType getTileType() {
+      return getResultType();
+    }
  }];
}
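Note on how the two new interface hooks fit together: `getTileId` returns the I32Attr written by tile allocation, and `isInMemoryTile` simply compares it against `kInMemoryTileIdBase`, so the in-memory/hardware distinction is directly visible in the IR. A hand-written sketch (the `tile_id` values below are illustrative, not the output of any single pass):

```mlir
// tile_id 0-15 name real hardware tiles (here ZA0.S);
// isInMemoryTile() is false for this op.
%real = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
// tile_id >= 16 (kInMemoryTileIdBase) marks an in-memory tile;
// isInMemoryTile() is true, and lowering will spill around this op.
%virtual = arm_sme.zero {tile_id = 16 : i32} : vector<[4]x[4]xi32>
```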
"mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -32,6 +33,8 @@ using namespace mlir; namespace { +static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id"); + /// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic. static Operation *createLoadTileSliceIntrinsic( RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type, @@ -129,8 +132,267 @@ IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) { return tileId; } -struct GetTileConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +/// Creates an alloca matching the size of tile used by `tileOp`. The alloca is +/// placed in the first block of the function. +static memref::AllocaOp +createAllocaForTile(RewriterBase &rewriter, Location loc, + FunctionOpInterface func, + arm_sme::ArmSMETileOpInterface tileOp) { + RewriterBase::InsertionGuard g(rewriter); + // Move to the first operation in the function. + rewriter.setInsertionPointToStart(&func.getBlocks().front()); + // Create an alloca matching the tile size of the `tileOp`. + auto vscale = rewriter.create(loc); + auto tileElementType = tileOp.getTileType().getElementType(); + auto memrefType = MemRefType::get( + {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType); + unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType); + auto minElementsOp = + rewriter.create(loc, minElements); + auto vectorLen = rewriter.create(loc, vscale, minElementsOp); + auto alloca = rewriter.create( + loc, memrefType, ValueRange{vectorLen, vectorLen}); + return alloca; +} + +/// Finds or creates an alloca for a spill of a tile. +static memref::AllocaOp getOrCreateAllocaForTile( + RewriterBase &rewriter, Location loc, FunctionOpInterface func, + arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) { + // Find an alloca at the top of the function tagged with a + // 'arm_sme.in_memory_tile_id' that matches `tileId`. + for (auto &op : func.getBlocks().front()) { + auto alloca = llvm::dyn_cast(op); + if (!alloca) + continue; + auto inMemoryTileId = llvm::dyn_cast_or_null( + alloca->getDiscardableAttr(kInMemoryTileIdAttr)); + if (!inMemoryTileId) + continue; + if (inMemoryTileId.getInt() == tileId) + return alloca; + } + // Otherwise, create a new alloca: + auto alloca = createAllocaForTile(rewriter, loc, func, tileOp); + alloca->setDiscardableAttr(kInMemoryTileIdAttr, + rewriter.getI32IntegerAttr(tileId)); + return alloca; +} + +/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a +/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning +/// the op to tile 0, then emitting a full tile swap between ZA and memory +/// before + after the tile op. +/// +/// Example: +/// +/// // Note: = tile ID >= 16. +/// arm_sme.tile_op { tile_id = } +/// +/// is converted to: +/// // At function entry: +/// %spill = memref.alloca ... : memref +/// +/// // Around op: +/// scf.for %slice_idx { +/// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}> +/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}> +/// vector.store %slice_to_save, %spill[%slice_idx, %c0] +/// } +/// arm_sme.tile_op { tile_id = 0 } +/// scf.for %slice_idx { +/// %slice_to_save = "arm_sme.intr.read.horiz" ... 
<{tile_id = 0 : i32}> +/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}> +/// vector.store %slice_to_save, %spill[%slice_idx, %c0] +/// } +/// +/// Note that these spills/fills are not inserted earlier as concept of a +/// register, and the need to swap the contents, can't really be represented +/// correctly at a high level in MLIR. +/// +/// TODO: Reduce the spills/reloads to single slices where possible (and omit +/// redundant reloads). This could be done via a method on the +/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.: +/// +/// `tileOp.getZaUsage()` could return: +/// +/// struct ArmSMEOpZAUsage { +/// enum class Kind { +/// TileRead, // Omit store after tile operation. +/// TileWrite, // Omit load before tile operation. +/// TileReadWrite, // Needs both tile load and store. +/// SliceRead, // Spill single slice and omit store after operation. +/// SliceWrite, // Spill single slice and omit load before operation. +/// SliceReadWrite // Spill single slice. +/// }; +/// Value sliceIndex {}; +/// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal }; +/// }; +/// +struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { + + ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName, + const LLVMTypeConverter &typeConverter, + PatternBenefit benefit) + : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(), + typeConverter, benefit) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto tileOp = cast(op); + // Tile has a real (hardware) tile. No spills/reloads required. + if (!tileOp.isInMemoryTile()) + return failure(); + + // Step 1. Create an alloca for the tile at the top of the function (if one + // does not already exist). + auto loc = tileOp.getLoc(); + auto func = tileOp->getParentOfType(); + auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp, + tileOp.getTileId().getInt()); + + // Step 2. Assign the op a real tile ID. + // For simplicity, we always use tile 0 (which always exists). + auto zeroTileId = rewriter.getI32IntegerAttr(0); + rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); }); + + VectorType tileVectorType = tileOp.getTileType(); + auto sliceType = VectorType::Builder(tileVectorType).dropDim(0); + auto swapInMemoryTileWithSMETileZero = [&] { + emitFullTileSwap(rewriter, loc, tileAlloca, + *arm_sme::getSMETileType(tileVectorType), sliceType, + zeroTileId); + }; + + // Step 3. Emit tile swaps before and after the op. + // TODO: Reduce the amount spilled to the amount of data the `tileOp` + // touches (i.e. a single tile slice). + { + rewriter.setInsertionPoint(op); + // Swap the contents of ZA and the in-memory tile before the op. + swapInMemoryTileWithSMETileZero(); + rewriter.setInsertionPointAfter(op); + // Swap the tile back out to memory again after the op. + swapInMemoryTileWithSMETileZero(); + } + + return success(); + } + + /// Extracts a pointer to a slice of an in-memory tile. 
+  Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
+                                Value tileMemory, Value sliceIndex) const {
+    auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
+    auto descriptor =
+        rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
+    auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
+    auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI64Type(), sliceIndex);
+    return getStridedElementPtr(
+        loc, llvm::cast<MemRefType>(tileMemory.getType()),
+        descriptor.getResult(0), {sliceIndexI64, zero},
+        static_cast<ConversionPatternRewriter &>(rewriter));
+  }
+
+  /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
+  /// tile-sized memref (`tileAlloca`).
+  void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
+                     arm_sme::ArmSMETileType tileType, VectorType sliceType,
+                     IntegerAttr tileId, Value sliceIndex) const {
+    // Cast the slice index to an i32.
+    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI32Type(), sliceIndex);
+    // Create an all-true predicate for the slice.
+    auto predicateType = sliceType.clone(rewriter.getI1Type());
+    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(predicateType, true));
+    // Create padding vector (never used due to all-true predicate).
+    auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
+    // Get a pointer to the current slice.
+    auto slicePtr =
+        getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
+    // Read the value of the current slice from ZA.
+    auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
+        loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
+    // Load the new tile slice back from memory into ZA.
+    createLoadTileSliceIntrinsic(
+        rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
+        allTruePredicate, slicePtr, tileId, sliceIndexI32);
+    // Store the current tile slice to memory.
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
+                                     ValueRange{sliceIndex, zero});
+  }
+
+  /// Emits a full in-place swap of the contents of a tile in ZA and a
+  /// tile-sized memref (`tileAlloca`).
+  void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
+                        arm_sme::ArmSMETileType tileType, VectorType sliceType,
+                        IntegerAttr tileId) const {
+    RewriterBase::InsertionGuard guard(rewriter);
+    // Create an scf.for over all tile slices.
+    auto minNumElts =
+        rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = rewriter.create<arith::MulIOp>(
+        loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    // Emit a swap for each tile slice.
+    rewriter.setInsertionPointToStart(forOp.getBody());
+    auto sliceIndex = forOp.getInductionVar();
+    emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
+                  sliceIndex);
+  }
+};
+
+enum class RequiresSpillsAndFills { Yes, No };
+
+/// Base class for ArmSME to LLVM conversion patterns. By default, this adds
+/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
+/// disabled by setting the `requiresSpillsAndFills` template parameter to
+/// `RequiresSpillsAndFills::No`.
+template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
+                                 RequiresSpillsAndFills::Yes>
+struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
+  using ArmSMEOp = SourceOp;
+  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
+
+  static constexpr bool requiresSpillsAndFillsConversion() {
+    return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
+  }
+};
+
+/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
+template <typename... Pattern>
+static void
+addArmSMEConversionPatterns(RewritePatternSet &patterns,
+                            LLVMTypeConverter const &typeConverter) {
+  (
+      [&] {
+        // Register spills/fills for ops that implement the
+        // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
+        // `RequiresSpillsAndFills::Yes`.
+        if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
+                      std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
+                                            typename Pattern::ArmSMEOp>,
+                                        typename Pattern::ArmSMEOp>) {
+          // Add spill/fill conversions with a very high benefit to ensure
+          // they are lowered first.
+          patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
+              Pattern::ArmSMEOp::getOperationName(), typeConverter,
+              /*benefit=*/1337);
+        }
+        patterns.add<Pattern>(typeConverter);
+      }(),
+      ...);
+}
+
+struct GetTileConversion
+    : public ConvertArmSMEOpToLLVMPattern<arm_sme::GetTileOp,
+                                          RequiresSpillsAndFills::No> {
+  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor,
@@ -156,8 +418,8 @@ struct GetTileConversion : public ConvertOpToLLVMPattern<arm_sme::GetTileOp> {
 ///
 /// The 'arm_sme.materialize_ssa_tile' (which models the return) will fold away
 /// once all ArmSME ops have been converted to LLVM intrinsics.
-struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
+  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
@@ -233,9 +495,8 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
 
 /// Lower `arm_sme.load_tile_slice` to SME intrinsics.
 struct LoadTileSliceConversion
-    : public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
-  using ConvertOpToLLVMPattern<
-      arm_sme::LoadTileSliceOp>::ConvertOpToLLVMPattern;
+    : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
+  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
@@ -277,9 +538,8 @@
 
 /// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
 struct StoreTileSliceConversion
-    : public ConvertOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
-  using ConvertOpToLLVMPattern<
-      arm_sme::StoreTileSliceOp>::ConvertOpToLLVMPattern;
+    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
+  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
@@ -319,9 +579,8 @@
 
 /// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
 struct MoveVectorToTileSliceConversion
-    : public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
-  using ConvertOpToLLVMPattern<
-      arm_sme::MoveVectorToTileSliceOp>::ConvertOpToLLVMPattern;
+    : public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
+  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(arm_sme::MoveVectorToTileSliceOp moveVectorToTileSliceOp,
@@ -373,9 +632,8 @@
 
 /// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
 struct MoveTileSliceToVectorConversion
-    : public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
-  using ConvertOpToLLVMPattern<
-      arm_sme::MoveTileSliceToVectorOp>::ConvertOpToLLVMPattern;
+    : public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
+  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(arm_sme::MoveTileSliceToVectorOp moveTileSliceToVector,
@@ -435,8 +693,8 @@ struct MoveTileSliceToVectorConversion
 ///
 /// Currently only supports FMOPA and BFMOPA (non-widening).
 struct OuterProductOpConversion
-    : public ConvertOpToLLVMPattern<arm_sme::OuterProductOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+    : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
+  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
@@ -530,8 +788,9 @@ struct OuterProductOpConversion
 ///   %0 = arith.index_cast %cnt : i64 to index
 ///
 struct StreamingVLOpConversion
-    : public ConvertOpToLLVMPattern<arm_sme::StreamingVLOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
+                                          RequiresSpillsAndFills::No> {
+  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
@@ -597,7 +856,10 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
       arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
       arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
       arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
-  target.addLegalDialect<arith::ArithDialect, func::FuncDialect,
-                         LLVM::LLVMDialect, mlir::scf::SCFDialect,
-                         vector::VectorDialect>();
+  target.addLegalDialect<arith::ArithDialect, func::FuncDialect,
+                         LLVM::LLVMDialect, mlir::scf::SCFDialect,
+                         vector::VectorDialect, memref::MemRefDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
 }
 
@@ -611,10 +873,11 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
     return std::nullopt;
   });
 
-  patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
-               MoveVectorToTileSliceConversion, StoreTileSliceConversion,
-               OuterProductOpConversion, ZeroOpConversion, GetTileConversion,
-               StreamingVLOpConversion>(converter);
+  addArmSMEConversionPatterns<
+      LoadTileSliceConversion, MoveTileSliceToVectorConversion,
+      MoveVectorToTileSliceConversion, StoreTileSliceConversion,
+      OuterProductOpConversion, ZeroOpConversion, GetTileConversion,
+      StreamingVLOpConversion>(patterns, converter);
 }
 
 std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
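Concretely, for a `vector<[4]x[4]xf32>` in-memory tile, the loop emitted by `emitFullTileSwap`/`emitSliceSwap` corresponds to IR along these lines. This is a hand-written sketch, assuming tile 0, horizontal slices, and an f32 tile; the function name and the inline descriptor arithmetic are illustrative (the pattern itself computes the slice pointer via the `getStridedElementPtr` helper, and ignores the descriptor offset here for brevity):

```mlir
func.func @swap_tile_zero_with_spill(%spill: memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %vscale = vector.vscale
  // SVL in 32-bit elements: the number (and length) of f32 tile slices.
  %svl_s = arith.muli %c4, %vscale : index
  %pad = llvm.mlir.undef : vector<[4]xf32>
  %ptrue = arith.constant dense<true> : vector<[4]xi1>
  scf.for %slice_idx = %c0 to %svl_s step %c1 {
    %slice_idx_i32 = arith.index_cast %slice_idx : index to i32
    %slice_idx_i64 = arith.index_cast %slice_idx : index to i64
    // Compute a pointer to the start of the current spill slice.
    %desc = builtin.unrealized_conversion_cast %spill : memref<?x?xf32>
        to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    %base = llvm.extractvalue %desc[1]
        : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    %stride = llvm.extractvalue %desc[4, 0]
        : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    %offset = llvm.mul %slice_idx_i64, %stride : i64
    %slice_ptr = llvm.getelementptr %base[%offset]
        : (!llvm.ptr, i64) -> !llvm.ptr, f32
    // 1. Read the current slice of ZA tile 0.
    %za_slice = "arm_sme.intr.read.horiz"(%pad, %ptrue, %slice_idx_i32)
        <{tile_id = 0 : i32}>
        : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
    // 2. Fill the same ZA slice from the in-memory tile.
    "arm_sme.intr.ld1w.horiz"(%ptrue, %slice_ptr, %slice_idx_i32)
        <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
    // 3. Spill the old ZA slice to the in-memory tile.
    vector.store %za_slice, %spill[%slice_idx, %c0]
        : memref<?x?xf32>, vector<[4]xf32>
  }
  return
}
```

Because each iteration reads a slice, fills it, and only then stores the old value, the loop is a true in-place swap and needs no temporary tile.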
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 6105cd6225283..1fa060cafc0bc 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -69,7 +69,6 @@ LogicalResult verifyOperationHasValidTileId(Operation *op) {
     return success(); // Not having a tile ID (yet) is okay.
   if (!tileId.getType().isSignlessInteger(32))
     return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
-  // TODO: Verify value of tile ID is in range.
   return success();
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 8aa51f352f822..49ea6bb5f8614 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -61,7 +61,9 @@ using namespace mlir::arm_sme;
 
 namespace {
 
-static constexpr char kTilesInUseAttr[] = "arm_sme.tiles_in_use";
+static constexpr StringLiteral kTilesInUseAttr("arm_sme.tiles_in_use");
+static constexpr StringLiteral
+    kNextInMemoryTileIdAttr("arm_sme.next_in_memory_tile_id");
 
 enum class TileMask : unsigned {
   // clang-format off
@@ -200,7 +202,6 @@ static void findDependantOps(Value rootValue,
     });
   }
 }
-
 struct AssignTileIDsPattern
     : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
@@ -209,28 +210,40 @@ struct AssignTileIDsPattern
     if (tileOp.getTileId())
      return failure();
 
+    auto func = tileOp->getParentOfType<FunctionOpInterface>();
+    auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) {
+      if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
+              func->getDiscardableAttr(name)))
+        return unsigned(attr.getInt());
+      return defaultVal;
+    };
+    auto setDiscardableIntAttr = [&](StringRef name, auto value) {
+      rewriter.updateRootInPlace(tileOp, [&] {
+        func->setDiscardableAttr(name,
+                                 rewriter.getI32IntegerAttr((unsigned)value));
+      });
+    };
+
    std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
    if (!tileType)
      return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
 
-    auto func = tileOp->getParentOfType<FunctionOpInterface>();
-    TileMask tilesInUse = TileMask::kNone;
-    if (auto tilesInUseAttr = llvm::dyn_cast_or_null<IntegerAttr>(
-            func->getDiscardableAttr(kTilesInUseAttr)))
-      tilesInUse = static_cast<TileMask>(tilesInUseAttr.getInt());
-
+    TileMask tilesInUse =
+        static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr));
    auto tileId = allocateTileId(*tileType, tilesInUse);
-    if (failed(tileId))
-      return tileOp.emitError("ran out of SME virtual tiles!");
-
-    rewriter.updateRootInPlace(func, [&]() {
-      func->setDiscardableAttr(
-          kTilesInUseAttr, rewriter.getI32IntegerAttr((unsigned)tilesInUse));
-    });
-
-    // Find all the ops that (transitively) depend on this tile.
-    SetVector<Operation *> dependantOps;
-    findDependantOps(tileOp->getResult(0), dependantOps);
+    bool tileIsInMemory = failed(tileId);
+    if (!tileIsInMemory)
+      setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
+    else {
+      // If we could not find a real tile ID, use an in-memory tile ID (ID >=
+      // 16). A later pass will insert the necessary spills and reloads.
+      tileId =
+          getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase);
+      setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
+      tileOp->emitWarning(
+          "failed to allocate SME virtual tile to operation, all tile "
+          "operations will go through memory, expect degraded performance");
+    }
 
    // Set all operations dependent on `tileOp` to use the same tile ID.
    // This is a naive tile allocation scheme, but works for common cases. For
@@ -246,16 +259,18 @@ struct AssignTileIDsPattern
    // This case would require allocating a new tile for the result of the
    // scf.if, and moving the contents of %tileA or %tileB to result tile (based
    // on the %some_cond).
+    // Find all the ops that (transitively) depend on this tile.
+    SetVector<Operation *> dependantOps;
+    findDependantOps(tileOp->getResult(0), dependantOps);
    auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
-    rewriter.updateRootInPlace(tileOp, [&]() { tileOp.setTileId(tileIDAttr); });
+    rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
    for (auto *op : dependantOps) {
-      if (auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
-        auto currentTileId = tileOp.getTileId();
+      if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
+        auto currentTileId = dependantTileOp.getTileId();
        if (currentTileId && unsigned(currentTileId.getInt()) != tileId)
-          return tileOp.emitOpError(
+          return dependantTileOp.emitOpError(
              "already assigned different SME virtual tile!");
-        rewriter.updateRootInPlace(tileOp,
-                                   [&]() { tileOp.setTileId(tileIDAttr); });
+        dependantTileOp.setTileId(tileIDAttr);
      }
    }
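The allocator change is easiest to see on IR. With a single ZA.B tile available, a second i8 tile op cannot get a hardware tile, so it receives the next free in-memory ID plus a warning. A rough sketch of the attributes written by `-allocate-arm-sme-tiles` (exact textual output and attribute ordering may differ):

```mlir
// Before -allocate-arm-sme-tiles:
func.func @za_b_tiles() {
  %t0 = arm_sme.get_tile : vector<[16]x[16]xi8>
  %t1 = arm_sme.get_tile : vector<[16]x[16]xi8>  // No ZA.B tile left.
  return
}

// After: %t0 takes the only ZA.B tile; %t1 falls back to the in-memory
// range starting at kInMemoryTileIdBase. The bookkeeping lives in
// discardable function attributes.
func.func @za_b_tiles() attributes {
    arm_sme.tiles_in_use = 65535 : i32,          // ZA0.B masks all sub-tiles.
    arm_sme.next_in_memory_tile_id = 17 : i32} {
  %t0 = arm_sme.get_tile {tile_id = 0 : i32} : vector<[16]x[16]xi8>
  // warning: failed to allocate SME virtual tile to operation, ...
  %t1 = arm_sme.get_tile {tile_id = 16 : i32} : vector<[16]x[16]xi8>
  return
}
```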
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
new file mode 100644
index 0000000000000..7a9e6b4215754
--- /dev/null
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -0,0 +1,154 @@
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | \
+// RUN:   FileCheck %s --check-prefix=AFTER-TILE-ALLOC
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize -cse \
+// RUN:   -split-input-file -verify-diagnostics | \
+// RUN:   FileCheck %s --check-prefix=AFTER-LLVM-LOWERING
+
+/// Checks that tile spills/reloads are inserted around in-memory tiles (i.e.
+/// tiles that were not assigned a physical SME tile).
+///
+/// These spills are currently very naive and will spill/reload entire tiles
+/// around ArmSME ops.
+///
+/// The general pattern is:
+///
+/// During tile allocation, if there's no physical tile ID available, an op
+/// will be assigned an in-memory tile ID (which is a tile ID >= 16).
+///
+/// Example:
+///
+///   arm_sme.zero : vector<[8]x[8]xi16>
+///
+/// Becomes:
+///
+///   arm_sme.zero { tile_id = 16 } : vector<[8]x[8]xi16>
+///
+/// This works like normal until the final lowering to LLVM, where spills and
+/// reloads will be inserted around uses of in-memory tiles.
+///
+/// So the above example becomes:
+///
+/// // Placed at the top of the function:
+/// %tileAlloca = memref.alloca(%svl_h, %svl_h) : memref<?x?xi16>
+///
+/// Then around the op:
+///
+/// // Swap contents of %tileAlloca and tile 0
+/// scf.for %sliceIdx ...
+///   %currentSlice = arm_sme.intr.read.horiz {tile_id = 0}
+///   arm_sme.intr.ld1h.horiz %tileAlloca[%sliceIdx, %c0] {tile_id = 0}
+///   vector.store %currentSlice, %tileAlloca[%sliceIdx, %c0]
+/// // Execute the op using tile 0
+/// arm_sme.intr.zero
+/// // Swap contents of %tileAlloca and tile 0
+/// scf.for %sliceIdx ...
+///   %currentSlice = arm_sme.intr.read.horiz {tile_id = 0}
+///   arm_sme.intr.ld1h.horiz %tileAlloca[%sliceIdx, %c0] {tile_id = 0}
+///   vector.store %currentSlice, %tileAlloca[%sliceIdx, %c0]
+
+// -----
+
+/// Note: In this example loads into ZA are inserted before the zero
+/// instruction. These are obviously redundant, but there are no checks to
+/// avoid this.
+func.func @use_too_many_tiles() {
+  %0 = arm_sme.zero : vector<[4]x[4]xi32>
+  %1 = arm_sme.zero : vector<[4]x[4]xi32>
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  %2 = arm_sme.zero : vector<[8]x[8]xi16>
+  return
+}
+// AFTER-TILE-ALLOC-LABEL: @use_too_many_tiles
+// AFTER-TILE-ALLOC: arm_sme.zero
+// AFTER-TILE-ALLOC-SAME: tile_id = 0
+// AFTER-TILE-ALLOC: arm_sme.zero
+// AFTER-TILE-ALLOC-SAME: tile_id = 1
+// AFTER-TILE-ALLOC: arm_sme.zero
+// AFTER-TILE-ALLOC-SAME: tile_id = 16
+
+// AFTER-LLVM-LOWERING-LABEL: @use_too_many_tiles
+// AFTER-LLVM-LOWERING-DAG: %[[C0:.*]] = arith.constant 0 : index
+// AFTER-LLVM-LOWERING-DAG: %[[C1:.*]] = arith.constant 1 : index
+// AFTER-LLVM-LOWERING-DAG: %[[C8:.*]] = arith.constant 8 : index
+// AFTER-LLVM-LOWERING-DAG: %[[VSCALE:.*]] = vector.vscale
+// AFTER-LLVM-LOWERING-DAG: %[[SVL_H:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_H]], %[[SVL_H]])
+// AFTER-LLVM-LOWERING-SAME:  {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xi16>
+//
+// AFTER-LLVM-LOWERING-NOT: scf.for
+// Note: 17 is the mask for the 32-bit tile 0.
+// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 17 : i32}>
+//
+// AFTER-LLVM-LOWERING-NOT: scf.for
+// Note: 34 is the mask for the 32-bit tile 1.
+// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 34 : i32}>
+//
+// AFTER-LLVM-LOWERING: scf.for
+// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
+// AFTER-LLVM-LOWERING:   %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING:   %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
+// AFTER-LLVM-LOWERING:   %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
+// AFTER-LLVM-LOWERING:   %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING-NEXT: }
+// Note: 85 is the mask for the 16-bit tile 0.
+// AFTER-LLVM-LOWERING: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}>
+// AFTER-LLVM-LOWERING: scf.for
+// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
+// AFTER-LLVM-LOWERING:   %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING:   %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
+// AFTER-LLVM-LOWERING:   %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
+// AFTER-LLVM-LOWERING:   %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING-NEXT: }
+
+// -----
+
+/// Note: In this example an entire tile swap is inserted before/after the
+/// `arm_sme.load_tile_slice` operation. Really, this only needs to spill a
+/// single tile slice (and can omit the initial load, like in the previous example).
+func.func @very_excessive_spills(%memref : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
+  %useAllTiles = arm_sme.get_tile : vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %mask = vector.constant_mask [4] : vector<[4]xi1>
+  %loadSlice = arm_sme.load_tile_slice %memref[%c0, %c0], %mask, %tile, %c0 : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+  return %loadSlice : vector<[4]x[4]xf32>
+}
+// AFTER-TILE-ALLOC-LABEL: @very_excessive_spills
+// AFTER-TILE-ALLOC: arm_sme.get_tile
+// AFTER-TILE-ALLOC-SAME: tile_id = 0
+// AFTER-TILE-ALLOC: arm_sme.load_tile_slice
+// AFTER-TILE-ALLOC-SAME: tile_id = 16
+
+// AFTER-LLVM-LOWERING-LABEL: @very_excessive_spills
+// AFTER-LLVM-LOWERING-DAG: %[[C0:.*]] = arith.constant 0 : index
+// AFTER-LLVM-LOWERING-DAG: %[[C1:.*]] = arith.constant 1 : index
+// AFTER-LLVM-LOWERING-DAG: %[[C4:.*]] = arith.constant 4 : index
+// AFTER-LLVM-LOWERING-DAG: %[[VSCALE:.*]] = vector.vscale
+// AFTER-LLVM-LOWERING-DAG: %[[SVL_S:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_S]], %[[SVL_S]])
+// AFTER-LLVM-LOWERING-SAME:  {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xf32>
+//
+// AFTER-LLVM-LOWERING: scf.for
+// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
+// AFTER-LLVM-LOWERING:   %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING:   %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
+// AFTER-LLVM-LOWERING:   %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
+// AFTER-LLVM-LOWERING:   %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1w.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING-NEXT: }
+// AFTER-LLVM-LOWERING: "arm_sme.intr.ld1w.horiz"{{.*}} <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING: scf.for
+// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
+// AFTER-LLVM-LOWERING:   %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING:   %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
+// AFTER-LLVM-LOWERING:   %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
+// AFTER-LLVM-LOWERING:   %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1w.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING-NEXT: }
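A note on the `tile_mask` values checked above: `arm_sme.intr.zero` takes a mask over the eight 64-bit tiles ZA0.D-ZA7.D, and a wider tile is zeroed by setting the bit of every 64-bit tile it overlaps. A worked sketch of the three masks from the first test case (the overlap rules come from the Arm SME ZERO instruction; the ops below are just the generic-form intrinsic calls):

```mlir
// ZA0.S overlaps ZA0.D and ZA4.D: mask = 0b0001_0001 = 17.
"arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
// ZA1.S overlaps ZA1.D and ZA5.D: mask = 0b0010_0010 = 34.
"arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
// ZA0.H overlaps ZA0.D, ZA2.D, ZA4.D, and ZA6.D: mask = 0b0101_0101 = 85.
"arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
```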
vector<[16]x[16]xi8> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128> return } @@ -79,7 +79,7 @@ func.func @za_h__out_of_tiles() { %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16> // CHECK-NEXT: tile_id = 1 %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[8]x[8]xi16> return } @@ -136,7 +136,7 @@ func.func @za_h_overlapping_za_q() { %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128> return } @@ -174,7 +174,7 @@ func.func @za_s__out_of_tiles() { %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32> %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32> %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[4]x[4]xi32> return } @@ -218,7 +218,7 @@ func.func @za_s_overlapping_za_q() { %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128> return } @@ -268,7 +268,7 @@ func.func @za_d__out_of_tiles() { %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64> %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64> %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[2]x[2]xi64> return } @@ -291,7 +291,7 @@ func.func @za_d_overlapping_za_q() { %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128> return } @@ -365,7 +365,7 @@ func.func @za_q__out_of_tiles() { %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128> %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128> - // expected-error@+1 {{ran out of SME virtual tiles!}} + // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %next_tile 
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
new file mode 100644
index 0000000000000..dd9f280cb7509
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s \
+// RUN:   -convert-vector-to-arm-sme -allocate-arm-sme-tiles \
+// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
+// RUN:   -convert-arm-sme-to-llvm -convert-vector-to-llvm=enable-arm-sve -cse \
+// RUN:   -canonicalize -test-lower-to-llvm -verify-diagnostics | \
+// RUN: %mcr_aarch64_cmd \
+// RUN:   -e=main -entry-point-result=void \
+// RUN:   -march=aarch64 -mattr="+sve,+sme" \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \
+// RUN: FileCheck %s

+/// This function uses too many tiles! There are only two i16 tiles (ZA0.H and
+/// ZA1.H), but this function uses five i16 tiles! Very expensive spills/reloads
+/// will be inserted to emulate the extra three tiles. Note: This is only done
+/// to avoid the compiler erroring out, but it is expected to have very poor
+/// performance (hence the warning).
+func.func @use_too_many_tiles(%a: memref<?x?xi16>, %b: memref<?x?xi16>, %c: memref<?x?xi16>) {
+  %c0 = arith.constant 0 : index
+  %tile_a = arith.constant dense<0> : vector<[8]x[8]xi16>
+  %tile_b = arith.constant dense<1> : vector<[8]x[8]xi16>
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  %tile_c = arm_sme.tile_load %a[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  %tile_d = arm_sme.tile_load %b[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  %tile_e = arm_sme.tile_load %c[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+
+  // CHECK-LABEL: tile_a:
+  // CHECK-COUNT-8: ( 0, 0, 0, 0, 0, 0, 0, 0
+  vector.print str "tile_a:"
+  vector.print %tile_a : vector<[8]x[8]xi16>
+  // CHECK-LABEL: tile_b:
+  // CHECK-COUNT-8: ( 1, 1, 1, 1, 1, 1, 1, 1
+  vector.print str "tile_b:"
+  vector.print %tile_b : vector<[8]x[8]xi16>
+  // CHECK-LABEL: tile_c:
+  // CHECK-COUNT-8: ( 2, 2, 2, 2, 2, 2, 2, 2
+  vector.print str "tile_c:"
+  vector.print %tile_c : vector<[8]x[8]xi16>
+  // CHECK-LABEL: tile_d:
+  // CHECK-COUNT-8: ( 3, 3, 3, 3, 3, 3, 3, 3
+  vector.print str "tile_d:"
+  vector.print %tile_d : vector<[8]x[8]xi16>
+  // CHECK-LABEL: tile_e:
+  // CHECK-COUNT-8: ( 4, 4, 4, 4, 4, 4, 4, 4
+  vector.print str "tile_e:"
+  vector.print %tile_e : vector<[8]x[8]xi16>
+  return
+}
+
+func.func @main() {
+  %c16 = arith.constant 16 : index
+  %svl_h = arm_sme.streaming_vl <half>
+
+  %c2 = arith.constant 2 : i16
+  %c3 = arith.constant 3 : i16
+  %c4 = arith.constant 4 : i16
+
+  %memA = memref.alloca(%svl_h, %svl_h) : memref<?x?xi16>
+  %memB = memref.alloca(%svl_h, %svl_h) : memref<?x?xi16>
+  %memC = memref.alloca(%svl_h, %svl_h) : memref<?x?xi16>
+
+  linalg.fill ins(%c2 : i16) outs(%memA : memref<?x?xi16>)
+  linalg.fill ins(%c3 : i16) outs(%memB : memref<?x?xi16>)
+  linalg.fill ins(%c4 : i16) outs(%memC : memref<?x?xi16>)
+
+  func.call @use_too_many_tiles(%memA, %memB, %memC) : (memref<?x?xi16>, memref<?x?xi16>, memref<?x?xi16>) -> ()
+  return
+}
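For the integration test above, the three in-memory tiles (IDs 16, 17, 18 for %tile_c, %tile_d, %tile_e) each end up with their own tagged spill slot at function entry after lowering, roughly as below (a sketch, assuming an i16 tile; SSA names are illustrative):

```mlir
func.func @entry_allocas_sketch() {
  %c8 = arith.constant 8 : index
  %vscale = vector.vscale
  // SVL in 16-bit elements: each spill is an svl_h x svl_h buffer.
  %svl_h = arith.muli %c8, %vscale : index
  // One tagged spill slot per in-memory tile ID.
  %spill_16 = memref.alloca(%svl_h, %svl_h) {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xi16>
  %spill_17 = memref.alloca(%svl_h, %svl_h) {arm_sme.in_memory_tile_id = 17 : i32} : memref<?x?xi16>
  %spill_18 = memref.alloca(%svl_h, %svl_h) {arm_sme.in_memory_tile_id = 18 : i32} : memref<?x?xi16>
  return
}
```

The `arm_sme.in_memory_tile_id` tag is what lets `getOrCreateAllocaForTile` reuse one alloca across every spill/fill of the same virtual tile.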