diff --git a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h index cc32e97084830..b550980c4ad01 100644 --- a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h +++ b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h @@ -16,18 +16,26 @@ namespace mlir { class LLVMTypeConverter; class RewritePatternSet; +class TypeConverter; class Pass; #define GEN_PASS_DECL_CONVERTAMDGPUTOROCDLPASS #include "mlir/Conversion/Passes.h.inc" -/// Note: The ROCDL target does not support the LLVM bfloat type at this time -/// and so this function will add conversions to change all `bfloat` uses -/// to `i16`. -void populateAMDGPUToROCDLConversionPatterns(const LLVMTypeConverter &converter, +/// Note: This function will also add conversions for the AMDGPU-specific +/// address spaces, but those can be added separately using +/// populateAMDGPUMemorySpaceAttributeConversions(). +void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset); +/// Remap AMDGPU memory spaces to LLVM address spaces +/// by mapping amdgpu::AddressSpace::fat_raw_buffer to ptr addrspace(7), +/// amdgpu::AddressSpace::buffer_rsrc to ptr addrspace(8), and +/// amdgpu::AddressSpace::fat_strided_buffer to ptr addrspace(9). 
+void populateAMDGPUMemorySpaceAttributeConversions( + TypeConverter &typeConverter); + } // namespace mlir #endif // MLIR_CONVERSION_AMDGPUTOROCDL_AMDGPUTOROCDL_H_ diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index f795dd89b79a1..ef33dc43f1d9e 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -9,8 +9,11 @@ #ifndef AMDGPU #define AMDGPU +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/IR/EnumAttr.td" +include "mlir/IR/Properties.td" include "mlir/IR/OpBase.td" def AMDGPU_Dialect : Dialect { @@ -32,6 +35,45 @@ def AMDGPU_Dialect : Dialect { let useDefaultAttributePrinterParser = 1; } +//===----------------------------------------------------------------------===// +// AMDGPU general attribute definitions +//===----------------------------------------------------------------------===// + +def AMDGPU_AddressSpace : I32EnumAttr<"AddressSpace", + "AMDGPU-specific address spaces", + [ + I32EnumAttrCase<"FatRawBuffer", 0, "fat_raw_buffer">, + I32EnumAttrCase<"BufferRsrc", 1, "buffer_rsrc">, + I32EnumAttrCase<"FatStructuredBuffer", 2, "fat_structured_buffer">, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::amdgpu"; +} + +def AMDGPU_AddressSpaceAttr : EnumAttr { + let description = [{ + AMDGPU-specific memory spaces that may not have exact analogues on other + GPU targets or backends. + + - `fat_raw_buffer` is the memory space used when a memref is stored as + as a "buffer fat pointer" - that is, a buffer resource (that is set up to + use raw byte-level indexing) along with its offset. The AMDGPU backend + implements `ptr addrspace(7)` to represent these fat pointers so that + buffer resources (which allow advanced features like bounds checking or + cache swizzling) can be used like ordinary LLVM pointers or memrefs. 
+ See also the `fat_raw_buffer_cast` operation + - `buffer_rsrc` is the memory space for `ptr addrspace(8)`, representing a + buffer resource. It should not be used for memrefs, since it does not support + indexing + - `fat_structured_buffer` represents `ptr addrspace(9)`, a buffer resource + that carries both an index and offset field, which are used for complex + structured indexing that is primarily seen in graphics applications. This + is also incompatible with the simple indexing model supported by memref. + }]; + let assemblyFormat = "`<` $value `>`"; +} + //===----------------------------------------------------------------------===// // AMDGPU Op definitions //===----------------------------------------------------------------------===// @@ -118,6 +160,69 @@ def AMDGPU_PackedStochRoundFp8Op : let hasVerifier = 1; } +def AMDGPU_FatRawBufferCastOp : + AMDGPU_Op<"fat_raw_buffer_cast", + [Pure, + DeclareOpInterfaceMethods, + ViewLikeOpInterface, AttrSizedOperandSegments]>, + Arguments<(ins AnyMemRef:$source, + Optional:$validBytes, + Optional>:$cacheSwizzleStride, + DefaultValuedProp:$boundsCheck, + UnitProp:$resetOffset)>, + Results<(outs AnyMemRef:$result)> { + let summary = "Create a raw buffer fat pointer that matches `memref`"; + let description = [{ + Wraps the memory pointed to by `source` as a raw buffer fat pointer, or, + in LLVM terms, a `ptr addrspace(7)`, returning a memref that has the same + sizes and layout but the `#amdgpu.address_space` + address space. + + This memref can be used with standard memref operations like `memref.load`, + `memref.store`, and `memref.atomicrmw`, which will be lowered to the relevant + buffer intrinsics. (`vector.masked_load/store` will work once there's backend + support for lowering them, and then this document will be updated) + + If `validBytes` is given, it is the number of bytes that will be valid as + an offset to `out`. If it is not provided, this will be inferred from + the size of the memref during lowering. 
This size is + max_{d = 0 upto rank(source)} (sizes[d] * strides[d]) * sizeof(element type). + + The flags of the buffer descriptor will be set up to enable raw usage - + for example, stride = 0, add_tid = 0, and so on. The `boundsCheck` + property determines if bounds checking is enabled or not (on architectures + where this can be controlled - that is, on RDNA chips). + + If `cacheSwizzleStride` is provided, L1 cache swizzling will be enabled + on architectures that support it. This swizzling, unlike the main swizzling + mode (whose usage makes a buffer non-raw) does not affect index calculation, + but does affect cache behavior. Mixing access between cache-swizzled raw + buffers and other forms of memory access, like ordinary pointer loads or + unswizzled buffer pointers can cause incorrect behavior and must be avoided. + + This operation preserves the sizes, strides, and offset of the input + memref - they'll be added in by `memref.load` later. However, if + `resetOffset` is set, that offset will be added to the base pointer. + If the value of the memref's offset is not uniform (independent of the lane/thread ID), + this will lead to substantially decreased performance due to the need for + a waterfall loop on the base address of the buffer resource. 
+ }]; + + let extraClassDeclaration = [{ + Value getViewSource() { return getSource(); } + }]; + + let assemblyFormat = [{ + $source oilist (`validBytes` `(` $validBytes `)` + | `cacheSwizzleStride` `(` $cacheSwizzleStride `)` + | `boundsCheck` `(` $boundsCheck `)` + | `resetOffset` $resetOffset ) + attr-dict `:` type($source) `to` type($result) + }]; + + let hasVerifier = 1; +} + /// Raw buffer load def AMDGPU_RawBufferLoadOp : AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>, diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h index 0a2e6bb5e9fe4..3de57c923178a 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h @@ -18,7 +18,9 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h.inc" diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h index 8dd5ff1a4b198..c3ae7930e8ec8 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h @@ -21,12 +21,15 @@ class ConversionTarget; namespace amdgpu { #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS +#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS #define GEN_PASS_REGISTRATION #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc" void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset); + +void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns); } // namespace amdgpu } // namespace mlir diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td 
index 23f8b8f653b67..6d0bcd6e1066e 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td @@ -31,4 +31,24 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> { "Chipset that these operations will run on">]; } +def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> { + let summary = "Resolve memref.extract_strided_metadata on AMDGPU ops"; + let description = [{ + This pass rrewrites `memref.extract_strided_metadata` operations + targeting the AMDGPU dialect casts. + + The patterns in this pass should normally be run alongside those in + -expand-strided-metadata, and creating a pass that combines those two + sets of patterns is the recommended way to use this functionality. + However, this pass (which will likely need a second -expand-strided-metadata + after it) is provided so that simple usecases do not need to create custom passes. + These patterns have not been added to -expnad-strided-metadata to + prevent the memref dialect from depending on platform-specific code. + }]; + let dependentDialects = [ + "arith::ArithDialect", + "memref::MemRefDialect" + ]; +} + #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_ diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index b29228ef87ea7..26ecc000138d2 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -19,6 +19,8 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" +#include "../LLVMCommon/MemRefDescriptor.h" + #include "llvm/ADT/STLExtras.h" #include @@ -30,6 +32,11 @@ namespace mlir { using namespace mlir; using namespace mlir::amdgpu; +// Define commonly used chipsets versions for convenience. 
+constexpr Chipset kGfx908 = Chipset(9, 0, 8); +constexpr Chipset kGfx90a = Chipset(9, 0, 0xa); +constexpr Chipset kGfx942 = Chipset(9, 4, 2); + /// Convert an unsigned number `val` to i32. static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val) { @@ -76,11 +83,166 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, return index ? index : createI32Constant(rewriter, loc, 0); } +/// Compute the contents of the `num_records` field for a given memref +/// descriptor - that is, the number of bytes that's one element past the +/// greatest possible valid index into the memref. +static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, + MemRefType memrefType, + MemRefDescriptor &memrefDescriptor, + ArrayRef strides, + uint32_t elementByteWidth) { + if (memrefType.hasStaticShape() && + !llvm::any_of(strides, ShapedType::isDynamic)) { + int64_t size = memrefType.getRank() == 0 ? 1 : 0; + ArrayRef shape = memrefType.getShape(); + for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) + size = std::max(shape[i] * strides[i], size); + size = size * elementByteWidth; + assert(size < std::numeric_limits::max() && + "the memref buffer is too large"); + return createI32Constant(rewriter, loc, static_cast(size)); + } + Value maxIndex; + for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) { + Value size = memrefDescriptor.size(rewriter, loc, i); + Value stride = memrefDescriptor.stride(rewriter, loc, i); + Value maxThisDim = rewriter.create(loc, size, stride); + maxIndex = maxIndex + ? 
rewriter.create(loc, maxIndex, maxThisDim) + : maxThisDim; + } + Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex); + Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth); + return rewriter.create(loc, maxIndexI32, byteWidthConst); +} + +static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, + Value basePointer, Value numRecords, + bool boundsCheck, amdgpu::Chipset chipset, + Value cacheSwizzleStride = nullptr, + unsigned addressSpace = 8) { + // The stride value is generally 0. However, on MI-300 and onward, you can + // enable a cache swizzling mode by setting bit 14 of the stride field + // and setting that stride to a cache stride. + Type i16 = rewriter.getI16Type(); + Value stride; + if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) { + Value cacheStrideZext = + rewriter.create(loc, i16, cacheSwizzleStride); + Value swizzleBit = rewriter.create( + loc, i16, rewriter.getI16IntegerAttr(1 << 14)); + stride = rewriter.create(loc, cacheStrideZext, swizzleBit, + /*isDisjoint=*/true); + } else { + stride = rewriter.create(loc, i16, + rewriter.getI16IntegerAttr(0)); + } + // Get the number of elements. 
+ // Flag word: + // bits 0-11: dst sel, ignored by these intrinsics + // bits 12-14: data format (ignored, must be nonzero, 7=float) + // bits 15-18: data format (ignored, must be nonzero, 4=32bit) + // bit 19: In nested heap (0 here) + // bit 20: Behavior on unmap (0 means "return 0 / ignore") + // bits 21-22: Index stride for swizzles (N/A) + // bit 23: Add thread ID (0) + // bit 24: Reserved to 1 (RDNA) or 0 (CDNA) + // bits 25-26: Reserved (0) + // bit 27: Buffer is non-volatile (CDNA only) + // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 = + // none, 3 = either swizzles or testing against offset field) RDNA only + // bits 30-31: Type (must be 0) + uint32_t flags = (7 << 12) | (4 << 15); + if (chipset.majorVersion >= 10) { + flags |= (1 << 24); + uint32_t oob = boundsCheck ? 3 : 2; + flags |= (oob << 28); + } + Value flagsConst = createI32Constant(rewriter, loc, flags); + Type rsrcType = + LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); + Value resource = rewriter.createOrFold( + loc, rsrcType, basePointer, stride, numRecords, flagsConst); + return resource; +} + namespace { -// Define commonly used chipsets versions for convenience. 
-constexpr Chipset kGfx908 = Chipset(9, 0, 8); -constexpr Chipset kGfx90a = Chipset(9, 0, 0xa); -constexpr Chipset kGfx942 = Chipset(9, 4, 2); +struct FatRawBufferCastLowering + : public ConvertOpToLLVMPattern { + FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern(converter), + chipset(chipset) {} + + Chipset chipset; + + LogicalResult + matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value memRef = adaptor.getSource(); + Value unconvertedMemref = op.getSource(); + MemRefType memrefType = cast(unconvertedMemref.getType()); + MemRefDescriptor descriptor(memRef); + + DataLayout dataLayout = DataLayout::closest(op); + int64_t elementByteWidth = + dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8; + + int64_t unusedOffset = 0; + SmallVector strideVals; + if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset))) + return op.emitOpError("Can't lower non-stride-offset memrefs"); + + Value numRecords = adaptor.getValidBytes(); + if (!numRecords) + numRecords = getNumRecords(rewriter, loc, memrefType, descriptor, + strideVals, elementByteWidth); + + Value basePointer = + adaptor.getResetOffset() + ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(), + memrefType) + : descriptor.alignedPtr(rewriter, loc); + + Value offset = adaptor.getResetOffset() + ? rewriter.create( + loc, getIndexType(), rewriter.getIndexAttr(0)) + : descriptor.offset(rewriter, loc); + + bool hasSizes = memrefType.getRank() > 0; + // No need to unpack() and pack() all the individual sizes and strides, + // so we'll just extract the arrays. + Value sizes = hasSizes ? rewriter.create( + loc, descriptor, kSizePosInMemRefDescriptor) + : Value{}; + Value strides = hasSizes + ? 
rewriter.create( + loc, descriptor, kStridePosInMemRefDescriptor) + : Value{}; + + Value fatPtr = makeBufferRsrc( + rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(), + chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7); + + Value result = MemRefDescriptor::poison( + rewriter, loc, + getTypeConverter()->convertType(op.getResult().getType())); + result = rewriter.create( + loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor); + result = rewriter.create( + loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor); + result = rewriter.create(loc, result, offset, + kOffsetPosInMemRefDescriptor); + if (hasSizes) { + result = rewriter.create(loc, result, sizes, + kSizePosInMemRefDescriptor); + result = rewriter.create( + loc, result, strides, kStridePosInMemRefDescriptor); + } + rewriter.replaceOp(op, result); + return success(); + } +}; /// Define lowering patterns for raw buffer ops template @@ -122,7 +284,6 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType); Type i32 = rewriter.getI32Type(); - Type i16 = rewriter.getI16Type(); // Get the type size in bytes. DataLayout dataLayout = DataLayout::closest(gpuOp); @@ -199,60 +360,10 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { Value ptr = memrefDescriptor.bufferPtr( rewriter, loc, *this->getTypeConverter(), memrefType); - // The stride value is always 0 for raw buffers. This also disables - // swizling. - Value stride = rewriter.create( - loc, i16, rewriter.getI16IntegerAttr(0)); - // Get the number of elements. - Value numRecords; - if (memrefType.hasStaticShape() && - !llvm::any_of(strides, ShapedType::isDynamic)) { - int64_t size = memrefType.getRank() == 0 ? 
1 : 0; - ArrayRef shape = memrefType.getShape(); - for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) - size = std::max(shape[i] * strides[i], size); - size = size * elementByteWidth; - assert(size < std::numeric_limits::max() && - "the memref buffer is too large"); - numRecords = createI32Constant(rewriter, loc, static_cast(size)); - } else { - Value maxIndex; - for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) { - Value size = memrefDescriptor.size(rewriter, loc, i); - Value stride = memrefDescriptor.stride(rewriter, loc, i); - Value maxThisDim = rewriter.create(loc, size, stride); - maxIndex = - maxIndex ? rewriter.create(loc, maxIndex, maxThisDim) - : maxThisDim; - } - numRecords = rewriter.create( - loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst); - } - - // Flag word: - // bits 0-11: dst sel, ignored by these intrinsics - // bits 12-14: data format (ignored, must be nonzero, 7=float) - // bits 15-18: data format (ignored, must be nonzero, 4=32bit) - // bit 19: In nested heap (0 here) - // bit 20: Behavior on unmap (0 means "return 0 / ignore") - // bits 21-22: Index stride for swizzles (N/A) - // bit 23: Add thread ID (0) - // bit 24: Reserved to 1 (RDNA) or 0 (CDNA) - // bits 25-26: Reserved (0) - // bit 27: Buffer is non-volatile (CDNA only) - // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 = - // none, 3 = either swizzles or testing against offset field) RDNA only - // bits 30-31: Type (must be 0) - uint32_t flags = (7 << 12) | (4 << 15); - if (chipset.majorVersion >= 10) { - flags |= (1 << 24); - uint32_t oob = adaptor.getBoundsCheck() ? 
3 : 2; - flags |= (oob << 28); - } - Value flagsConst = createI32Constant(rewriter, loc, flags); - Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8); - Value resource = rewriter.createOrFold( - loc, rsrcType, ptr, stride, numRecords, flagsConst); + Value numRecords = getNumRecords( + rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth); + Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords, + adaptor.getBoundsCheck(), chipset); args.push_back(resource); // Indexing (voffset) @@ -1062,11 +1173,32 @@ struct ConvertAMDGPUToROCDLPass }; } // namespace -void mlir::populateAMDGPUToROCDLConversionPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns, - Chipset chipset) { +void mlir::populateAMDGPUMemorySpaceAttributeConversions( + TypeConverter &typeConverter) { + typeConverter.addTypeAttributeConversion( + [](BaseMemRefType type, amdgpu::AddressSpaceAttr as) + -> TypeConverter::AttributeConversionResult { + MLIRContext *ctx = as.getContext(); + Type i64 = IntegerType::get(ctx, 64); + switch (as.getValue()) { + case amdgpu::AddressSpace::FatRawBuffer: + return IntegerAttr::get(i64, 7); + case amdgpu::AddressSpace::BufferRsrc: + return IntegerAttr::get(i64, 8); + case amdgpu::AddressSpace::FatStructuredBuffer: + return IntegerAttr::get(i64, 9); + } + return TypeConverter::AttributeConversionResult::abort(); + }); +} + +void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns, + Chipset chipset) { + populateAMDGPUMemorySpaceAttributeConversions(converter); patterns - .add, + .add, RawBufferOpLowering, RawBufferOpLowering, diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 271ca382e2f0b..d2bfb863244d9 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -59,6 +59,59 @@ LogicalResult PackedStochRoundFp8Op::verify() { return success(); } 
+//===----------------------------------------------------------------------===// +// FatRawBufferCastOp +//===----------------------------------------------------------------------===// + +/// Convert the type `source` to one with the same sizes and strides - and +/// offset, unless `resetOffset` is true, in which case the offset is reset to +/// 0. If the offset should be reset but the layout of `source` isn't either the +/// identity layout or a strided layout, this function fails.
*expectedResultType) + return emitOpError("expected result type to be ") + << *expectedResultType << " but got " << getResult().getType(); + return success(); +} + //===----------------------------------------------------------------------===// // RawBuffer*Op //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt index 5f934714d988a..3d4567bff1e32 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms EmulateAtomics.cpp + ResolveStridedMetadata.cpp ADDITIONAL_HEADER_DIRS {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms @@ -11,6 +12,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms MLIRAMDGPUDialect MLIRAMDGPUUtils MLIRArithDialect + MLIRMemRefDialect MLIRVectorDialect MLIRControlFlowDialect MLIRFuncDialect diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp new file mode 100644 index 0000000000000..4b3d94b4ce2ad --- /dev/null +++ b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp @@ -0,0 +1,79 @@ +//===- ResolveStridedMetadata.cpp - AMDGPU expand_strided_metadata ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/AMDGPU/Transforms/Passes.h" + +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::amdgpu { +#define GEN_PASS_DEF_AMDGPURESOLVESTRIDEDMETADATAPASS +#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc" +} // namespace mlir::amdgpu + +using namespace mlir; +using namespace mlir::amdgpu; + +namespace { +struct AmdgpuResolveStridedMetadataPass + : public amdgpu::impl::AmdgpuResolveStridedMetadataPassBase< + AmdgpuResolveStridedMetadataPass> { + void runOnOperation() override; +}; + +struct ExtractStridedMetadataOnFatRawBufferCastFolder final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp, + PatternRewriter &rewriter) const override { + auto castOp = metadataOp.getSource().getDefiningOp(); + if (!castOp) + return rewriter.notifyMatchFailure(metadataOp, + "not a fat raw buffer cast"); + Location loc = castOp.getLoc(); + auto sourceMetadata = rewriter.create( + loc, castOp.getSource()); + SmallVector results; + if (metadataOp.getBaseBuffer().use_empty()) { + results.push_back(nullptr); + } else { + auto baseBufferType = + cast(metadataOp.getBaseBuffer().getType()); + if (baseBufferType == castOp.getResult().getType()) { + results.push_back(castOp.getResult()); + } else { + results.push_back(rewriter.create( + loc, baseBufferType, castOp.getResult(), /*offset=*/0, + /*sizes=*/ArrayRef{}, /*strides=*/ArrayRef{})); + } + } + if (castOp.getResetOffset()) + results.push_back(rewriter.create(loc, 0)); + else + results.push_back(sourceMetadata.getOffset()); + llvm::append_range(results, sourceMetadata.getSizes()); + llvm::append_range(results, sourceMetadata.getStrides()); + 
rewriter.replaceOp(metadataOp, results); + return success(); + } +}; +} // namespace + +void mlir::amdgpu::populateAmdgpuResolveStridedMetadataPatterns( + RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + +void AmdgpuResolveStridedMetadataPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + populateAmdgpuResolveStridedMetadataPatterns(patterns); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index 062b63c076c3c..ae1b34ef3f8eb 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -1,13 +1,124 @@ // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx908 | FileCheck %s --check-prefixes=CHECK,GFX9,GFX908 // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx90a | FileCheck %s --check-prefixes=CHECK,GFX9,GFX90A +// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9,GFX942 // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10,RDNA // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11,RDNA // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12,RDNA +// Note: #gpu.address_space is hardcoded to `1` here because the +// test pass doesn't set up the GPU address space conversions. 
+ +#gpu_global_addrspace = 1 + +// CHECK-LABEL: func @fat_raw_buffer_cast +func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space> { + // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<8xi32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-DAG: %[[base:.*]] = llvm.extractvalue %[[desc]][1] + // CHECK-DAG: %[[offset:.*]] = llvm.extractvalue %[[desc]][2] + // CHECK-DAG: %[[sizes:.*]] = llvm.extractvalue %[[desc]][3] + // CHECK-DAG: %[[strides:.*]] = llvm.extractvalue %[[desc]][4] + // CHECK-DAG: %[[numRecords:.*]] = llvm.mlir.constant(32 : i32) : i32 + // CHECK-DAG: %[[strideArg:.*]] = llvm.mlir.constant(0 : i16) : i16 + // GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32) + // RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32) + // CHECK: %[[fatBuf:.*]] = rocdl.make.buffer.rsrc %[[base]], %[[strideArg]], %[[numRecords]], %[[flags]] : <1> to <7> + // CHECK: %[[ret0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<7>, ptr<7>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[ret1:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret0]][0] + // CHECK: %[[ret2:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret1]][1] + // CHECK: %[[ret3:.*]] = llvm.insertvalue %[[offset]], %[[ret2]][2] + // CHECK: %[[ret4:.*]] = llvm.insertvalue %[[sizes]], %[[ret3]][3] + // CHECK: %[[ret5:.*]] = llvm.insertvalue %[[strides]], %[[ret4]][4] + // CHECK: builtin.unrealized_conversion_cast %[[ret5]] + %ret = amdgpu.fat_raw_buffer_cast %buf : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space> + return %ret : memref<8xi32, #amdgpu.address_space> +} + +// CHECK-LABEL: func @fat_raw_buffer_cast_0d +func.func @fat_raw_buffer_cast_0d(%buf: memref) -> memref> { + // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref to !llvm.struct<(ptr<1>, ptr<1>, i64)> + // CHECK-DAG: %[[base:.*]] = llvm.extractvalue %[[desc]][1] + // CHECK-DAG: 
%[[offset:.*]] = llvm.extractvalue %[[desc]][2] + // CHECK-DAG: %[[numRecords:.*]] = llvm.mlir.constant(4 : i32) : i32 + // CHECK-DAG: %[[strideArg:.*]] = llvm.mlir.constant(0 : i16) : i16 + // GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32) + // RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32) + // CHECK: %[[fatBuf:.*]] = rocdl.make.buffer.rsrc %[[base]], %[[strideArg]], %[[numRecords]], %[[flags]] + // CHECK: %[[ret0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<7>, ptr<7>, i64)> + // CHECK: %[[ret1:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret0]][0] + // CHECK: %[[ret2:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret1]][1] + // CHECK: %[[ret3:.*]] = llvm.insertvalue %[[offset]], %[[ret2]][2] + // CHECK: builtin.unrealized_conversion_cast %[[ret3]] + %ret = amdgpu.fat_raw_buffer_cast %buf : memref to memref> + return %ret : memref> +} + +// CHECK-LABEL: func @fat_raw_buffer_cast_dyn_size_offset +func.func @fat_raw_buffer_cast_dyn_size_offset(%buf: memref, #gpu_global_addrspace>) -> memref, #amdgpu.address_space> { + // CHECK: %[[size0:.*]] = llvm.extractvalue %{{.*}}[3, 0] + // CHECK: %[[stride0:.*]] = llvm.extractvalue %{{.*}}[4, 0] + // CHECK: %[[maxVals:.*]] = llvm.mul %[[size0]], %[[stride0]] + // CHECK: %[[maxValsI32:.*]] = llvm.trunc %[[maxVals]] : i64 to i32 + // CHECK: %[[byteSize:.*]] = llvm.mlir.constant(4 : i32) : i32 + // CHECK: %[[numRecords:.*]] = llvm.mul %[[maxValsI32]], %[[byteSize]] + // CHECK: %[[offset:.*]] = llvm.extractvalue %{{.*}}[2] + // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}} + // CHECK: llvm.insertvalue %[[offset]], %{{.*}}[2] + %ret = amdgpu.fat_raw_buffer_cast %buf : memref, #gpu_global_addrspace> to memref, #amdgpu.address_space> + return %ret : memref, #amdgpu.address_space> +} + +// CHECK-LABEL: func @fat_raw_buffer_cast_reset_offset +func.func @fat_raw_buffer_cast_reset_offset(%buf: memref, #gpu_global_addrspace>) -> memref, #amdgpu.address_space> { + // CHECK: %[[desc:.*]] = 
builtin.unrealized_conversion_cast %{{.*}} : memref, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-DAG: %[[memRefPtr:.*]] = llvm.extractvalue %[[desc]][1] + // CHECK-DAG: %[[memRefOff:.*]] = llvm.extractvalue %[[desc]][2] + // CHECK-DAG: %[[basePtr:.*]] = llvm.getelementptr %[[memRefPtr]][%[[memRefOff]]] + // CHECK-DAG: %[[zeroOff:.*]] = llvm.mlir.constant(0 : index) : i64 + // CHECK: %[[fatBuf:.*]] = rocdl.make.buffer.rsrc %[[basePtr]], %{{.*}}, %{{.*}}, %{{.*}} + // CHECK: llvm.insertvalue %[[fatBuf]], %{{.*}}[1] + // CHECK: llvm.insertvalue %[[zeroOff]], %{{.*}}[2] + %ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref, #gpu_global_addrspace> to memref, #amdgpu.address_space> + return %ret : memref, #amdgpu.address_space> +} + +// CHECK-LABEL: func @fat_raw_buffer_cast_valid_bytes +func.func @fat_raw_buffer_cast_valid_bytes(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space> { + // CHECK: %[[numRecords:.*]] = arith.constant -1 : i32 + // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}} + %cu32_max = arith.constant 0xffffffff : i32 + %ret = amdgpu.fat_raw_buffer_cast %buf validBytes(%cu32_max) : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space> + return %ret : memref<8xi32, #amdgpu.address_space> +} + +// CHECK-LABEL: func @fat_raw_buffer_cast_bounds_check +func.func @fat_raw_buffer_cast_bounds_check(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space> { + // GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32) + // RDNA: %[[flags:.*]] = llvm.mlir.constant(553807872 : i32) + // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %{{.*}}, %[[flags]] + %ret = amdgpu.fat_raw_buffer_cast %buf boundsCheck(false) : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space> + return %ret : memref<8xi32, #amdgpu.address_space> +} + +// CHECK-LABEL: func 
@fat_raw_buffer_cast_cache_swizzle +// CHECK-SAME: (%{{.*}}: memref<64x64xi32, 1>, %[[stride:.*]]: i14) +func.func @fat_raw_buffer_cast_cache_swizzle(%buf: memref<64x64xi32, #gpu_global_addrspace>, %stride: i14) -> memref<64x64xi32, #amdgpu.address_space> { + // GFX908: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16 + // GFX90A: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16 + // RDNA: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16 + // GFX942: %[[asI16:.*]] = llvm.zext %[[stride]] : i14 to i16 + // GFX942: %[[cacheSwizzleOn:.*]] = llvm.mlir.constant(16384 : i16) : i16 + // GFX942: %[[stride:.*]] = llvm.or disjoint %[[asI16]], %[[cacheSwizzleOn]] + // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %{{.*}}, %{{.*}} + %ret = amdgpu.fat_raw_buffer_cast %buf cacheSwizzleStride(%stride) : memref<64x64xi32, #gpu_global_addrspace> to memref<64x64xi32, #amdgpu.address_space> + return %ret : memref<64x64xi32, #amdgpu.address_space> +} + // CHECK-LABEL: func @gpu_gcn_raw_buffer_load_scalar_i32 func.func @gpu_gcn_raw_buffer_load_scalar_i32(%buf: memref) -> i32 { - // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) + // Extra constant for byte width + // CHECK: llvm.mlir.constant(4 : i32) // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(4 : i32) + // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) // GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32) // RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32) // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8> @@ -19,8 +130,8 @@ func.func @gpu_gcn_raw_buffer_load_scalar_i32(%buf: memref) -> i32 { // CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32 func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 { - // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32) + // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) // GFX9: %[[flags:.*]] = 
llvm.mlir.constant(159744 : i32) // RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32) // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8> @@ -37,7 +148,6 @@ func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<16x16xi32, strided<[ // CHECK: %[[algn_ptr:.*]] = llvm.extractvalue %[[descriptor]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[offset:.*]] = llvm.extractvalue %[[descriptor]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[ptr:.*]] = llvm.getelementptr %[[algn_ptr]][%[[offset]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32 - // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16 // CHECK: %[[sz_i:.*]] = llvm.extractvalue %[[descriptor]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[stride_i:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[ext_i:.*]] = llvm.mul %[[sz_i]], %[[stride_i]] : i64 @@ -46,7 +156,9 @@ func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<16x16xi32, strided<[ // CHECK: %[[ext_j:.*]] = llvm.mul %[[sz_j]], %[[stride_j]] : i64 // CHECK: %[[num_records:.*]] = llvm.intr.umax(%[[ext_i]], %[[ext_j]]) : (i64, i64) -> i64 // CHECK: %[[num_rec_i32:.*]] = llvm.trunc %[[num_records]] : i64 to i32 - // CHECK: %[[num_rec_bytes_i32:.*]] = llvm.mul %[[num_rec_i32]], %[[elem_size]] : i32 + // CHECK: %[[elem_size_2:.*]] = llvm.mlir.constant(4 : i32) : i32 + // CHECK: %[[num_rec_bytes_i32:.*]] = llvm.mul %[[num_rec_i32]], %[[elem_size_2]] : i32 + // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16 // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %[[stride]], %[[num_rec_bytes_i32]], %{{.*}} : !llvm.ptr to <8> // CHECK: %[[stride_i_1:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[stride_i_i32:.*]] = 
llvm.trunc %[[stride_i_1]] : i64 to i32 @@ -289,6 +401,8 @@ func.func @lds_barrier() { // GFX908-SAME: ";;;WARNING: BREAKS DEBUG WATCHES\0As_waitcnt lgkmcnt(0)\0As_barrier" // GFX90A: rocdl.s.waitcnt -7937 // GFX90A-NEXT: rocdl.s.barrier + // GFX942: rocdl.s.waitcnt -7937 + // GFX942-NEXT: rocdl.s.barrier // GFX10: rocdl.s.waitcnt -16129 // GFX10-NEXT: rocdl.s.barrier // GFX11: llvm.inline_asm has_side_effects asm_dialect = att diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-resolve-strided-metadata.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-resolve-strided-metadata.mlir new file mode 100644 index 0000000000000..831bb5f0f66ec --- /dev/null +++ b/mlir/test/Dialect/AMDGPU/amdgpu-resolve-strided-metadata.mlir @@ -0,0 +1,51 @@ +// RUN: mlir-opt -amdgpu-resolve-strided-metadata -split-input-file %s | FileCheck %s + +!tSrc = memref> +!tDst = memref, #amdgpu.address_space> +!tRes = memref> +// CHECK-LABEL: @resolve_metadata_no_offset_reset +// CHECK-SAME: (%[[arg0:.*]]: memref>) +// CHECK-NEXT: %[[cast:.+]] = amdgpu.fat_raw_buffer_cast %[[arg0]] +// CHECK-NEXT: %{{.+}}, %[[offset:.+]], %[[size:.+]]:2, %[[stride:.+]]:2 = memref.extract_strided_metadata %[[arg0]] +// CHECK-NEXT: %[[reinterp:.+]] = memref.reinterpret_cast %[[cast]] +// CHECK-NEXT: return %[[reinterp]], %[[offset]], %[[size]]#0, %[[size]]#1, %[[stride]]#0, %[[stride]]#1 +func.func @resolve_metadata_no_offset_reset(%arg0: !tSrc) -> (!tRes, index, index, index, index, index) { + %cast = amdgpu.fat_raw_buffer_cast %arg0 : !tSrc to !tDst + %base, %offset, %size:2, %stride:2 = memref.extract_strided_metadata %cast : !tDst -> !tRes, index, index, index, index, index + func.return %base, %offset, %size#0, %size#1, %stride#0, %stride#1 : !tRes, index, index, index, index, index +} + +// ----- + +!tSrc = memref> +!tDst = memref, #amdgpu.address_space> +!tRes = memref> +// CHECK-LABEL: @resolve_metadata_offset_reset +// CHECK-SAME: (%[[arg0:.*]]: memref>) +// CHECK-NEXT: %[[offset:.+]] = arith.constant 0 : index +// 
CHECK-NEXT: %[[cast:.+]] = amdgpu.fat_raw_buffer_cast %[[arg0]] +// CHECK-NEXT: %{{.+}}, %{{.+}}, %[[size:.+]]:2, %[[stride:.+]]:2 = memref.extract_strided_metadata %[[arg0]] +// CHECK-NEXT: %[[reinterp:.+]] = memref.reinterpret_cast %[[cast]] +// CHECK-NEXT: return %[[reinterp]], %[[offset]], %[[size]]#0, %[[size]]#1, %[[stride]]#0, %[[stride]]#1 +func.func @resolve_metadata_offset_reset(%arg0: !tSrc) -> (!tRes, index, index, index, index, index) { + %cast = amdgpu.fat_raw_buffer_cast %arg0 resetOffset : !tSrc to !tDst + %base, %offset, %size:2, %stride:2 = memref.extract_strided_metadata %cast : !tDst -> !tRes, index, index, index, index, index + func.return %base, %offset, %size#0, %size#1, %stride#0, %stride#1 : !tRes, index, index, index, index, index +} + +// ----- + +!tSrc = memref> +!tDst = memref, #amdgpu.address_space> +!tRes = memref> +// CHECK-LABEL: @resolve_metadata_no_base_ptr +// CHECK-SAME: (%[[arg0:.*]]: memref>) +// CHECK-NEXT: %[[offset:.+]] = arith.constant 0 : index +// CHECK-NEXT: %[[cast:.+]] = amdgpu.fat_raw_buffer_cast %[[arg0]] +// CHECK-NEXT: %{{.+}}, %{{.+}}, %[[size:.+]]:2, %[[stride:.+]]:2 = memref.extract_strided_metadata %[[arg0]] +// CHECK-NEXT: return %[[cast]], %[[offset]], %[[size]]#0, %[[size]]#1, %[[stride]]#0, %[[stride]]#1 +func.func @resolve_metadata_no_base_ptr(%arg0: !tSrc) -> (!tDst, index, index, index, index, index) { + %cast = amdgpu.fat_raw_buffer_cast %arg0 resetOffset : !tSrc to !tDst + %base, %offset, %size:2, %stride:2 = memref.extract_strided_metadata %cast : !tDst -> !tRes, index, index, index, index, index + func.return %cast, %offset, %size#0, %size#1, %stride#0, %stride#1 : !tDst, index, index, index, index, index +} diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir index 5e1ab79962d2f..7cb16f5259070 100644 --- a/mlir/test/Dialect/AMDGPU/invalid.mlir +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -125,3 +125,28 @@ func.func @wmma(%arg0 : vector<16xf16>, %arg1 : 
vector<8xi32>) -> vector<8xi32> %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32> func.return %0 : vector<8xi32> } + +// ----- + +// Missing `resetOffset` +func.func @fat_raw_buffer_cast_stripped_offset(%m: memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> { + // expected-error@+1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>' but got 'memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>'}} + %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> + func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> +} + +// ----- + +func.func @fat_raw_buffer_cast_wrong_as(%m: memref<8xi32>) -> memref<8xi32, #amdgpu.address_space<buffer_rsrc>> { + // expected-error@+1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>' but got 'memref<8xi32, #amdgpu.address_space<buffer_rsrc>>'}} + %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32> to memref<8xi32, #amdgpu.address_space<buffer_rsrc>> + return %ret : memref<8xi32, #amdgpu.address_space<buffer_rsrc>> +} + +// ----- + +func.func @fat_raw_buffer_cast_stripping_offset_affine_map(%m: memref<8xi32, affine_map<(d0)[s0] -> (d0 + s0)>>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> { + // expected-error@+1 {{'amdgpu.fat_raw_buffer_cast' op source type 'memref<8xi32, affine_map<(d0)[s0] -> (d0 + s0)>>' can't have its offset reset}} + %ret = amdgpu.fat_raw_buffer_cast %m resetOffset : memref<8xi32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> + func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> +} diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index 9457a1b9e4498..567e6498330a3 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -25,6 +25,25 @@ func.func @packed_stoch_round_fp8(%v1: f32, %stoch: i32, %others:
vector<4xf8E5M func.return %ret : vector<4xf8E5M2FNUZ> } + +// CHECK-LABEL: func @fat_raw_buffer_cast_easy
+// CHECK: amdgpu.fat_raw_buffer_cast +func.func @fat_raw_buffer_cast_easy(%m: memref<8xi32>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> { + %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> + func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> +} + +// CHECK-LABEL: func @fat_raw_buffer_cast +// CHECK: amdgpu.fat_raw_buffer_cast +// CHECK-SAME: validBytes(%{{[^)]*}}) +// CHECK-SAME: cacheSwizzleStride(%{{[^)]*}}) +// CHECK-SAME: boundsCheck(false) +// CHECK-SAME: resetOffset +func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i32, %cacheSwizzle: i14) -> memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> { + %ret = amdgpu.fat_raw_buffer_cast %m validBytes(%validBytes) cacheSwizzleStride(%cacheSwizzle) boundsCheck(false) resetOffset + : memref<8xi32, strided<[1], offset: ?>> to memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> + func.return %ret : memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> +} + // CHECK-LABEL: func @raw_buffer_load_f32_from_rank_1 func.func @raw_buffer_load_f32_from_rank_1(%src : memref<128xf32>, %offset : i32, %idx0 : i32) -> f32 { // CHECK: amdgpu.raw_buffer_load {indexOffset = 1 : i32} %{{.*}}[{{.*}}] sgprOffset %{{.*}} : memref<128xf32>, i32 -> f32