Skip to content

[mlir][ArmSME] Add rudimentary support for tile spills to the stack #76086

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 12, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Dec 20, 2023

This adds very basic (and inelegant) support for something like spilling and reloading tiles, if you use more SME tiles than physically exist.

This is purely implemented to prevent the compiler from aborting if a function uses too many tiles (i.e. due to bad unrolling), but is expected to perform very poorly.

Currently, this works in two stages:

During tile allocation, if we run out of tiles instead of giving up, we switch to allocating 'in-memory' tile IDs. These are tile IDs that start at 16 (which is higher than any real tile ID). A warning will also be emitted for each (root) tile op assigned an in-memory tile ID:

warning: failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance

Everything after this works like normal until -convert-arm-sme-to-llvm

Here the in-memory tile op:

arm_sme.tile_op { tile_id = <IN MEMORY TILE> }

Is lowered to:

// At function entry:
%alloca = memref.alloca ... : memref<?x?xty>

// Around the op:
// Swap the contents of %alloca and tile 0.
scf.for %slice_idx {
  %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
  "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx)  <{tile_id = 0 : i32}>
  vector.store %current_slice, %alloca[%slice_idx, %c0]
}
// Execute op using tile 0.
arm_sme.tile_op { tile_id = 0 }
// Swap the contents of %alloca and tile 0.
// This restores tile 0 to its original state.
scf.for %slice_idx {
  %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
  "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx)  <{tile_id = 0 : i32}>
  vector.store %current_slice, %alloca[%slice_idx, %c0]
}

This is inserted during the lowering to LLVM as spilling/reloading registers is a very low-level concept, that can't really be modeled correctly at a high level in MLIR.

Note: This is always doing the worst case full-tile swap. This could be optimized to only spill/load data the tile op will use, which could be just a slice. It's also not making any use of liveness, which could allow reusing tiles. But these is not seen as important as correct code should only use the available number of tiles.

@llvmbot
Copy link
Member

llvmbot commented Dec 20, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-sme

Author: Benjamin Maxwell (MacDue)

Changes

This adds very basic (and inelegant) support for something like spilling and reloading tiles, if you use more SME tiles than physically exist.

This is purely implemented to prevent the compiler from aborting if a function uses too many tiles (i.e. due to bad unrolling), but is expected to perform very poorly.

Currently, this works in two stages:

During tile allocation, if we run out of tiles instead of giving up, we switch to allocating 'in-memory' tile IDs. These are tile IDs that start at 16 (which is higher than any real tile ID). A warning will also be emitted for each (root) tile op assigned an in-memory tile ID:

warning: failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation

Everything after this works like normal until -convert-arm-sme-to-llvm

Here the in-memory tile op:

arm_sme.tile_op { tile_id = &lt;IN MEMORY TILE&gt; }

Is lowered to:

// At function entry:
%alloca = memref.alloca ... : memref&lt;?x?xty&gt;

// Around the op:
// Swap the contents of %alloca and tile 0.
scf.for %slice_idx {
  %current_slice = "arm_sme.intr.read.horiz" ... &lt;{tile_id = 0 : i32}&gt;
  "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx)  &lt;{tile_id = 0 : i32}&gt;
  vector.store %current_slice, %alloca[%slice_idx, %c0]
}
// Execute op using tile 0.
arm_sme.tile_op { tile_id = 0 }
// Swap the contents of %alloca and tile 0.
// This restores tile 0 to its original state.
scf.for %slice_idx {
  %current_slice = "arm_sme.intr.read.horiz" ... &lt;{tile_id = 0 : i32}&gt;
  "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx)  &lt;{tile_id = 0 : i32}&gt;
  vector.store %current_slice, %alloca[%slice_idx, %c0]
}

This is inserted during the lowering to LLVM as spilling/reloading registers is a very low-level concept, that can't really be modeled correctly at a high level in MLIR.

Note: This is always doing the worst case full-tile swap. This could be optimized to only spill/load data the tile op will use, which could be just a slice. It's also not making any use of liveness, which could allow reusing tiles. But these is not seen as important as correct code should only use the available number of tiles.


Patch is 44.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76086.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h (+2-1)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+33)
  • (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+338-120)
  • (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (-1)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp (+33-15)
  • (added) mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir (+96)
  • (modified) mlir/test/Dialect/ArmSME/tile-allocation.mlir (+9-9)
  • (added) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir (+78)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index 9982d4278b6033..c507cea5357a74 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -25,8 +25,9 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 namespace mlir::arm_sme {
+static constexpr unsigned kInMemoryTileIdBase = 16;
 #include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
-}
+} // namespace mlir::arm_sme
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index f7cc1d3fe7517f..adb3fae87e1017 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -97,6 +97,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
         // This operation does not allocate a tile.
         return std::nullopt;
       }]
+    >,
+    InterfaceMethod<
+      "Returns the VectorType of the tile used by this operation.",
+      /*returnType=*/"VectorType",
+      /*methodName=*/"getTileType",
+      /*arguments=*/(ins),
+      /*methodBody=*/[{}]
     >
   ];
 
@@ -117,6 +124,11 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
       rewriter.replaceOp($_op, newOp);
       return newOp;
     }
+
+    bool isInMemoryTile() {
+      auto tileId = getTileId();
+      return tileId && tileId.getInt() >= kInMemoryTileIdBase;
+    }
   }];
 
   let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
@@ -316,6 +328,9 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
     std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
       return arm_sme::getSMETileType(getVectorType());
     }
+    VectorType getTileType() {
+      return getVectorType();
+    }
   }];
   let assemblyFormat = "attr-dict `:` type($res)";
 }
@@ -392,6 +407,9 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
     std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
       return arm_sme::getSMETileType(getVectorType());
     }
+    VectorType getTileType() {
+      return getVectorType();
+    }
   }];
 
   let builders = [
@@ -460,6 +478,9 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getValueToStore().getType());
     }
+    VectorType getTileType() {
+      return getVectorType();
+    }
   }];
 
   let builders = [
@@ -524,6 +545,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getResult().getType());
     }
+    VectorType getTileType() {
+      return getVectorType();
+    }
   }];
 
   let assemblyFormat = [{
@@ -581,6 +605,9 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getTile().getType());
     }
+    VectorType getTileType() {
+      return getVectorType();
+    }
   }];
 
   let assemblyFormat = [{
@@ -673,6 +700,9 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
 
   let extraClassDeclaration = [{
     VectorType getSliceType() { return getResult().getType(); }
+    VectorType getTileType() {
+      return ::llvm::cast<VectorType>(getTile().getType());
+    }
   }];
 
   let assemblyFormat = [{
@@ -765,6 +795,9 @@ let arguments = (ins
         return arm_sme::getSMETileType(getResultType());
       return std::nullopt;
     }
+    VectorType getTileType() {
+      return getResultType();
+    }
   }];
 }
 
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index f9d6f04a811f3e..131f734b4c7485 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -32,6 +33,97 @@ using namespace mlir;
 
 namespace {
 
+static constexpr StringLiteral kInMemoryTileId("arm_sme.in_memory_tile_id");
+
+/// Helper to create a arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
+static Operation *createLoadTileSliceIntrinsic(
+    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
+    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
+    IntegerAttr tileId, Value tileSliceI32) {
+  if (layout == arm_sme::TileSliceLayout::Horizontal) {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  } else {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+      break;
+    }
+  }
+}
+
+/// Helper to create a arm_sme.intr.st1*.(horiz|vert)' intrinsic.
+static Operation *createStoreTileSliceIntrinsic(
+    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
+    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
+    IntegerAttr tileId, Value tileSliceI32) {
+  if (layout == arm_sme::TileSliceLayout::Horizontal) {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  } else {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  }
+}
+
 IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
   auto tileId = op.getTileId();
   if (!tileId)
@@ -40,6 +132,209 @@ IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
   return tileId;
 }
 
+/// Creates a alloca matching the size of tile used by `tileOp`. The alloca is
+/// placed in the first block of the function.
+static memref::AllocaOp
+createAllocaForTile(RewriterBase &rewriter, Location loc,
+                    FunctionOpInterface func,
+                    arm_sme::ArmSMETileOpInterface tileOp) {
+  RewriterBase::InsertionGuard g(rewriter);
+  // Move to the first operation in the function.
+  rewriter.setInsertionPoint(&func.getBlocks().front().front());
+  // Create an alloca matching the tile size of the `tileOp`.
+  auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+  auto tileElementType =
+      llvm::cast<VectorType>(tileOp.getTileType()).getElementType();
+  auto memrefType = MemRefType::get(
+      {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
+  auto minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
+  auto minElementsOp =
+      rewriter.create<arith::ConstantIndexOp>(loc, minElements);
+  auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
+  auto alloca = rewriter.create<memref::AllocaOp>(
+      loc, memrefType, ValueRange{vectorLen, vectorLen});
+  return alloca;
+}
+
+/// Finds or creates an alloca for a spill of a tile.
+static memref::AllocaOp
+getOrCreateTileMemory(RewriterBase &rewriter, Location loc,
+                      FunctionOpInterface func,
+                      arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
+  // Find an alloca at the top of the function tagged with a
+  // 'arm_sme.in_memory_tile_id' that matches `tileId`.
+  for (auto &op : func.getBlocks().front()) {
+    auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
+    if (!alloca)
+      continue;
+    auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
+        alloca->getDiscardableAttr(kInMemoryTileId));
+    if (!inMemoryTileId)
+      continue;
+    if (inMemoryTileId.getInt() == tileId)
+      return alloca;
+  }
+  // Otherwise, create a new alloca:
+  auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
+  alloca->setDiscardableAttr(kInMemoryTileId,
+                             rewriter.getI32IntegerAttr(tileId));
+  return alloca;
+}
+
+/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
+/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
+/// the op to tile 0, then emitting a full tile swap between ZA and memory
+/// before + after the tile op.
+///
+/// Example:
+///
+///    arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
+///
+/// is converted to:
+///     // At function entry:
+///     %alloca = memref.alloca ... : memref<?x?xty>
+///
+///     // Around op:
+///     scf.for %slice_idx {
+///       %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
+///       "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx)  <{tile_id = 0 : i32}>
+///       vector.store %current_slice, %alloca[%slice_idx, %c0]
+///     }
+///     arm_sme.tile_op { tile_id = 0 }
+///     scf.for %slice_idx {
+///       %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
+///       "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx)  <{tile_id = 0 : i32}>
+///       vector.store %current_slice, %alloca[%slice_idx, %c0]
+///     }
+///
+/// Note that these spills/fills are not inserted earlier as concept of a
+/// register, and the need to swap the contents, can't really be represented
+/// correctly at a high level in MLIR.
+///
+/// TODO: Reduce the spills/reloads to single slices where possible.
+struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
+
+  ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
+                                    const LLVMTypeConverter &typeConverter,
+                                    PatternBenefit benefit)
+      : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
+                             typeConverter, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
+    // Tile has a real (hardware) tile. No spills/reloads required.
+    if (!tileOp.isInMemoryTile())
+      return failure();
+
+    // Step 1. Create an alloca for the tile at the top of the function (if one
+    // does not already exist).
+    auto loc = tileOp.getLoc();
+    auto func = tileOp->getParentOfType<FunctionOpInterface>();
+    auto tileAlloca = getOrCreateTileMemory(rewriter, loc, func, tileOp,
+                                            tileOp.getTileId().getInt());
+
+    // Step 2. Assign the op a real tile ID.
+    // For simplicity, we always use tile 0.
+    auto zeroTileId = rewriter.getI32IntegerAttr(0);
+    {
+      rewriter.startRootUpdate(tileOp);
+      tileOp.setTileId(zeroTileId);
+      rewriter.finalizeRootUpdate(tileOp);
+    }
+
+    VectorType tileVectorType = tileOp.getTileType();
+    auto sliceType = VectorType::Builder(tileOp.getTileType()).dropDim(0);
+    auto emitTileSwap = [&] {
+      emitFullTileSwap(rewriter, loc, tileAlloca,
+                       *arm_sme::getSMETileType(tileVectorType), sliceType,
+                       zeroTileId);
+    };
+
+    // Step 3. Emit tile swaps before and after the op.
+    // TODO: Reduce the amount spilled to the amount of data the `tileOp`
+    // touches (i.e. a single tile slice).
+    {
+      rewriter.setInsertionPoint(op);
+      // Swap the in-memory tile's contents into ZA before the op.
+      emitTileSwap();
+      rewriter.setInsertionPointAfter(op);
+      // Swap the tile back out to memory again after the op.
+      emitTileSwap();
+    }
+
+    return success();
+  }
+
+  /// Extracts a pointer to a slice of an in-memory tile.
+  Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
+                                Value tileMemory, Value sliceIndex) const {
+    auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
+    auto descriptor =
+        rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
+    auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
+    auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI64Type(), sliceIndex);
+    return getStridedElementPtr(
+        loc, llvm::cast<MemRefType>(tileMemory.getType()),
+        descriptor.getResult(0), {sliceIndexI64, zero},
+        static_cast<ConversionPatternRewriter &>(rewriter));
+  }
+
+  /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
+  /// tile-sized memref (`tileAlloca`).
+  void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
+                     arm_sme::ArmSMETileType tileType, VectorType sliceType,
+                     IntegerAttr tileId, Value sliceIndex) const {
+    // Cast the slice index to an i32.
+    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI32Type(), sliceIndex);
+    // Create an all-true predicate for the slice.
+    auto predicateType = sliceType.clone(rewriter.getI1Type());
+    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(predicateType, true));
+    // Create zero padding vector (never used due to all-true predicate).
+    auto zeroVector = rewriter.create<arith::ConstantOp>(
+        loc, sliceType, rewriter.getZeroAttr(sliceType));
+    // Get a pointer to the current slice.
+    auto slicePtr =
+        getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
+    // Read the value of the current slice from ZA.
+    auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
+        loc, sliceType, zeroVector, allTruePredicate, tileId, sliceIndexI32);
+    // Load the new tile slice back from memory into ZA.
+    createLoadTileSliceIntrinsic(
+        rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
+        allTruePredicate, slicePtr, tileId, sliceIndexI32);
+    // Store the current tile slice to memory.
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
+                                     ValueRange{sliceIndex, zero});
+  }
+
+  /// Emits a full in-place swap of the contents of a tile in ZA and a
+  /// tile-sized memref (`tileAlloca`).
+  void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
+                        arm_sme::ArmSMETileType tileType, VectorType sliceType,
+                        IntegerAttr tileId) const {
+    RewriterBase::InsertionGuard guard(rewriter);
+    // Create an scf.for over all tile slices.
+    auto minNumElts =
+        rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = rewriter.create<arith::MulIOp>(
+        loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    // Emit a swap for each tile slice.
+    rewriter.setInsertionPointToStart(forOp.getBody());
+    auto sliceIndex = forOp.getInductionVar();
+    emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
+                  sliceIndex);
+  }
+};
+
 struct GetTileConversion : public ConvertOpToLLVMPattern<arm_sme::GetTileOp> {
   using ConvertOpToLLVMPattern<arm_sme::GetTileOp>::ConvertOpToLLVMPattern;
 
@@ -75,8 +370,8 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = zero.getLoc();
 
-    unsigned tileElementWidth =
-        zero.getVectorType().getElementType().getIntOrFloatBitWidth();
+    arm_sme::ArmSMETileType tileType =
+        *arm_sme::getSMETileType(zero.getVectorType());
 
     auto tileId = getTileIdOrError(zero);
     if (!tileId)
@@ -87,22 +382,22 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
     // These masks are derived from:
     // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
     auto baseMaskForSize = [&] {
-      switch (tileElementWidth) {
-      case 8:
+      switch (tileType) {
+      case arm_sme::ArmSMETileType::ZAB:
         // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
         // 64-bit element tiles named ZA0.D to ZA7.D.
         return 0b1111'1111;
-      case 16:
-        // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit element
-        // tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D.
-        // Shift this left once for ZA1.H.
+      case arm_sme::ArmSMETileType::ZAH:
+        // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
+        // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
+        // once for ZA1.H.
         return 0b0101'0101;
-      case 32:
+      case arm_sme::ArmSMETileType::ZAS:
         // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
         // element tiles named ZA0.D and ZA4.D.
         // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
         return 0b0001'0001;
-      case 64:
+      case arm_sme::ArmSMETileType::ZAD:
         // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
         // setting the bit for that tile.
         return 0b0000'0001;
@@ -172,63 +467,14 @@ struct LoadTileSliceConversion
     // Create all active predicate mask.
     auto maskOp = loadTileSliceOp.getMask();
 
-    auto tileType = loadTileSliceOp.getVectorType();
-    auto tileElementType = tileType.getElementType();
-    unsigned tileE...
[truncated]

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the patch Ben! I've not gone through it all yet, but left some comments for now. Will have another pass later, cheers

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! Mostly makes sense, but I have a few small questions/suggestions.

From summary:

Here the in-memory tile op:

arm_sme.tile_op { tile_id = }

Could you clarify what "IN MEMORY TILE" is in practice? IIUC, it's an integer >= 16?

This could be optimized to only spill/load data the tile op will use, which could be just a slice.

Could you add a test that demonstrates excessive spilling?

It's also not making any use of liveness, which could allow reusing tiles. But these is not seen as important as correct code should only use the available number of tiles.

Please document these as next steps for optimisation. In particular, this is important, but we may not have the bandwidth to prioritise this just yet. Functional correctness first, performance next ;-)

Comment on lines 752 to 762
// Add spill/fill conversions with a very high benefit to ensure they are
// lowered first.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if the benefit is much lower? As in, what breaks?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the op is lowered to an intrinsic before the spills/fills are added, then there's no op to add spills/fills to.

@MacDue MacDue force-pushed the arm_sme_rudimentary_tile_spills branch from 8112e8e to 590de4e Compare December 21, 2023 12:01
return
}

func.func @get_svl() -> index attributes { enable_arm_streaming_ignore, arm_locally_streaming }{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enable_arm_streaming_ignore isn't necessary with only-if-required-by-ops

Suggested change
func.func @get_svl() -> index attributes { enable_arm_streaming_ignore, arm_locally_streaming }{
func.func @get_svl() -> index attributes { arm_locally_streaming }{

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know, I'm just being explicit as I've manually set the streaming mode to ensure this returns SVL.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah I see, perhaps there's value in us adding an op for this at some point, there's synthetic intrinsics in the backend for this the ACLE uses: https://arm-software.github.io/acle/main/acle.html#rdsvl

  • int_aarch64_sme_cntsb
  • int_aarch64_sme_cntsh
  • int_aarch64_sme_cntsw
  • int_aarch64_sme_cntsd

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those would probably be nicer to use than the current vscale * base_size we're currently doing :)
Especially if we had a op like:

%svl = arm_sme.streaming_vl <bytes|half-words|words|double-words>

@MacDue MacDue force-pushed the arm_sme_rudimentary_tile_spills branch from d375d37 to 1460c70 Compare December 21, 2023 17:52
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is quite involved, but very well crafted and actually relatively easy to follow. I've left a few minor comments, but otherwise LGTM!

Btw, you've written a lot of useful documentation while working on this - could you summarise this in https://mlir.llvm.org/docs/Dialects/ArmSME/? That shouldn't be too much work given that you've already done most of it :)

MacDue added a commit that referenced this pull request Jan 10, 2024
This operation provides a convenient way to query the streaming vector
length regardless of the streaming mode. This most useful for functions
that call/pass data to streaming functions, but are not streaming
themselves.

Example:
```mlir
%svl_w = arm_sme.streaming_vl <word>
```

Created based on discussion here:
#76086 (comment)
@MacDue MacDue force-pushed the arm_sme_rudimentary_tile_spills branch from 1bf0903 to 6734f38 Compare January 10, 2024 10:53
MacDue added 11 commits January 10, 2024 14:14
This adds very basic and inelegant support for something like spilling
and reloading tiles if you use more SME tiles than physically exist.

This is purely implemented to prevent the compiler from aborting if a
function uses too many tiles (i.e. due to bad unrolling), but is
expected to perform very poorly.

Currenly, this works in two stages:

During tile allocation, if we run out of tiles instead of giving up, we
switch to allocating 'in-memory' tile IDs. These are tile IDs that start
at 16 (which is higher than any real tile ID). A warning will also be
emitted for each (root) tile op assigned an in-memory tile ID:

```
warning: failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation
```

Everything after this works like normal until `-convert-arm-sme-to-llvm`

Here the in-memory tile op:

```mlir
arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
```

Is lowered to:

```mlir
// At function entry:
%alloca = memref.alloca ... : memref<?x?xty>

// Around the op:
// Swap the contents of %alloca and tile 0.
scf.for %slice_idx {
  %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
  "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx)  <{tile_id = 0 : i32}>
  vector.store %current_slice, %alloca[%slice_idx, %c0]
}
// Execute op using tile 0.
arm_sme.tile_op { tile_id = 0 }
// Swap the contents of %alloca and tile 0.
// This restores tile 0 to its original state.
scf.for %slice_idx {
  %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
  "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx)  <{tile_id = 0 : i32}>
  vector.store %current_slice, %alloca[%slice_idx, %c0]
}
```

This is inserted during the lowering to LLVM as spilling/reloading
registers is a very low-level concept, that can't really be modeled
correctly at a high level in MLIR.

Note: This is always doing the worst case full-tile swap. This could be
optimized to only spill/load data the tile op will use, which could be
just a slice. It's also not making any use of liveness, which could
allow reusing tiles. But these is not seen as important as correct code
should only use the available number of tiles.
- Show alloca usage in tests
- Add test showing some very excessive spills
- Document a possible API to reduce spills
This switches the ArmSME -> LLVM conversion patterns to use a new
`ConvertArmSMEOpToLLVMPattern` base class. Using this implicitly adds
the spills/fills conversion pattern, unless the `requiresSpillsAndFills`
template parameteris set to `RequiresSpillsAndFills::No`.
@MacDue MacDue force-pushed the arm_sme_rudimentary_tile_spills branch from 6734f38 to 1373bbe Compare January 10, 2024 14:14
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers Ben!

I know you mentioned it's not exactly optimal or elegant, and I agree, but we have to play the hand we're dealt and in this case it's a set of intrinsics that force us to do things we probably shouldn't be doing in MLIR. Hopefully a bit further down the road these intrinsics will either evolve or a new set will come along where the tile is a proper type with spilling/filling support in the backend and dataflow is correctly modeled.

In the meantime, the functionality you've implemented here allows code that uses more tiles than exist at the architecture level to correctly run and there's value in that. It's much better than a crash anyway! 😄

@MacDue MacDue merged commit 5417a5f into llvm:main Jan 12, 2024
@MacDue MacDue deleted the arm_sme_rudimentary_tile_spills branch January 12, 2024 14:51
@ronlieb
Copy link
Contributor

ronlieb commented Jan 13, 2024

hi, we have SLES 15 SP4 systems with gcc 7.5 (llvm minimum supported version of gnu is 7.4) and see build time failures with this patch applied.

RHEL8 and ubuntu20 compile code ok,however much more recent gnu toolchain.

llvm-project/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp: In lambda function:
llvm-project/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp:377:67: error: parameter packs not expanded with ‘...’:
if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~
std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
~~~

@MacDue
Copy link
Member Author

MacDue commented Jan 13, 2024

@ronlieb Hi, could you test if #78046 fixes the issue? I tried something similar on compiler explorer and it seemed to work for GCC 7.4, but I don't have such an old compiler setup to build LLVM locally.

searlmc1 pushed a commit to ROCm/llvm-project that referenced this pull request Jan 13, 2024
… stack (llvm#76086)"

breaks SLES 15 SP4 gcc 7.5 builds

This reverts commit 5417a5f.

Change-Id: I1d78d877481f9445ef74ee59b0987630e7279844
@ronlieb
Copy link
Contributor

ronlieb commented Jan 13, 2024

@ronlieb Hi, could you test if #78046 fixes the issue? I tried something similar compiler explorer and it seemed to work for GCC 7.4, but I don't have such an old compiler setup to build LLVM locally.

thx for fix, i will give it a go, it will need to grind through our CI, so likely will take about 7 hours

@ronlieb
Copy link
Contributor

ronlieb commented Jan 14, 2024

@ronlieb Hi, could you test if #78046 fixes the issue? I tried something similar compiler explorer and it seemed to work for GCC 7.4, but I don't have such an old compiler setup to build LLVM locally.

thx for fix, i will give it a go, it will need to grind through our CI, so likely will take about 7 hours

the patch worked, thanks! SLES build succeeded

searlmc1 pushed a commit to ROCm/llvm-project that referenced this pull request Jan 14, 2024
…lvm#76086)

Folds in subsequent fix
  MacDue@7d17f32

This adds very basic (and inelegant) support for something like spilling
and reloading tiles, if you use more SME tiles than physically exist.

This is purely implemented to prevent the compiler from aborting if a
function uses too many tiles (i.e. due to bad unrolling), but is
expected to perform very poorly.

Currently, this works in two stages:

During tile allocation, if we run out of tiles instead of giving up, we
switch to allocating 'in-memory' tile IDs. These are tile IDs that start
at 16 (which is higher than any real tile ID). A warning will also be
emitted for each (root) tile op assigned an in-memory tile ID:

```
warning: failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance
```

Everything after this works like normal until `-convert-arm-sme-to-llvm`

Here the in-memory tile op:

```mlir
arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
```

Is lowered to:

```mlir
// At function entry:
%alloca = memref.alloca ... : memref<?x?xty>

// Around the op:
// Swap the contents of %alloca and tile 0.
scf.for %slice_idx {
  %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
  "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx)  <{tile_id = 0 : i32}>
  vector.store %current_slice, %alloca[%slice_idx, %c0]
}
// Execute op using tile 0.
arm_sme.tile_op { tile_id = 0 }
// Swap the contents of %alloca and tile 0.
// This restores tile 0 to its original state.
scf.for %slice_idx {
  %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
  "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx)  <{tile_id = 0 : i32}>
  vector.store %current_slice, %alloca[%slice_idx, %c0]
}
```

This is inserted during the lowering to LLVM as spilling/reloading
registers is a very low-level concept, that can't really be modeled
correctly at a high level in MLIR.

Note: This is always doing the worst case full-tile swap. This could be
optimized to only spill/load data the tile op will use, which could be
just a slice. It's also not making any use of liveness, which could
allow reusing tiles. But these is not seen as important as correct code
should only use the available number of tiles.

Change-Id: I44f462d9803d020d97225cad9ea84ee42ed60093
@MacDue
Copy link
Member Author

MacDue commented Jan 14, 2024

@ronlieb Hi, could you test if #78046 fixes the issue? I tried something similar compiler explorer and it seemed to work for GCC 7.4, but I don't have such an old compiler setup to build LLVM locally.

thx for fix, i will give it a go, it will need to grind through our CI, so likely will take about 7 hours

the patch worked, thanks! SLES build succeeded

Landed the fix now, thanks for testing! 👍

justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
This operation provides a convenient way to query the streaming vector
length regardless of the streaming mode. This most useful for functions
that call/pass data to streaming functions, but are not streaming
themselves.

Example:
```mlir
%svl_w = arm_sme.streaming_vl <word>
```

Created based on discussion here:
llvm#76086 (comment)
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…lvm#76086)

This adds very basic (and inelegant) support for something like spilling
and reloading tiles, if you use more SME tiles than physically exist.

This is purely implemented to prevent the compiler from aborting if a
function uses too many tiles (i.e. due to bad unrolling), but is
expected to perform very poorly.

Currently, this works in two stages:

During tile allocation, if we run out of tiles instead of giving up, we
switch to allocating 'in-memory' tile IDs. These are tile IDs that start
at 16 (which is higher than any real tile ID). A warning will also be
emitted for each (root) tile op assigned an in-memory tile ID:

```
warning: failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance
```

Everything after this works like normal until `-convert-arm-sme-to-llvm`

Here the in-memory tile op:

```mlir
arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
```

Is lowered to:

```mlir
// At function entry:
%alloca = memref.alloca ... : memref<?x?xty>

// Around the op:
// Swap the contents of %alloca and tile 0.
scf.for %slice_idx {
  %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
  "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx)  <{tile_id = 0 : i32}>
  vector.store %current_slice, %alloca[%slice_idx, %c0]
}
// Execute op using tile 0.
arm_sme.tile_op { tile_id = 0 }
// Swap the contents of %alloca and tile 0.
// This restores tile 0 to its original state.
scf.for %slice_idx {
  %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
  "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx)  <{tile_id = 0 : i32}>
  vector.store %current_slice, %alloca[%slice_idx, %c0]
}
```

This is inserted during the lowering to LLVM as spilling/reloading
registers is a very low-level concept, that can't really be modeled
correctly at a high level in MLIR.

Note: This is always doing the worst case full-tile swap. This could be
optimized to only spill/load data the tile op will use, which could be
just a slice. It's also not making any use of liveness, which could
allow reusing tiles. But these is not seen as important as correct code
should only use the available number of tiles.
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants