Skip to content

Commit 0e12ca0

Browse files
committed
Fixes to AMDGPUToROCDL PR
1 parent dd4b110 commit 0e12ca0

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
7171
structured indexing that is primarily seen in graphics applications. This
7272
is also incompatible with the simple indexing model supported by memref.
7373
}];
74-
let assemblyFormat = [{ `<` $value `>` }];
74+
let assemblyFormat = "`<` $value `>`";
7575
}
7676

7777
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,12 +210,16 @@ struct FatRawBufferCastLowering
210210
else
211211
offset = descriptor.offset(rewriter, loc);
212212

213+
bool hasSizes = memrefType.getRank() > 0;
213214
// No need to unpack() and pack() all the individual sizes and strides,
214215
// so we'll just extract the arrays.
215-
Value sizes = rewriter.create<LLVM::ExtractValueOp>(
216-
loc, descriptor, kSizePosInMemRefDescriptor);
217-
Value strides = rewriter.create<LLVM::ExtractValueOp>(
218-
loc, descriptor, kStridePosInMemRefDescriptor);
216+
Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
217+
loc, descriptor, kSizePosInMemRefDescriptor)
218+
: Value{};
219+
Value strides = hasSizes
220+
? rewriter.create<LLVM::ExtractValueOp>(
221+
loc, descriptor, kStridePosInMemRefDescriptor)
222+
: Value{};
219223

220224
Value rsrc = makeBufferRsrc(rewriter, loc, basePointer, numRecords,
221225
adaptor.getBoundsCheck(), chipset,
@@ -232,11 +236,12 @@ struct FatRawBufferCastLowering
232236
loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
233237
result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
234238
kOffsetPosInMemRefDescriptor);
235-
result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
236-
kSizePosInMemRefDescriptor);
237-
result = rewriter.create<LLVM::InsertValueOp>(loc, result, strides,
238-
kStridePosInMemRefDescriptor);
239-
239+
if (hasSizes) {
240+
result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
241+
kSizePosInMemRefDescriptor);
242+
result = rewriter.create<LLVM::InsertValueOp>(
243+
loc, result, strides, kStridePosInMemRefDescriptor);
244+
}
240245
rewriter.replaceOp(op, result);
241246
return success();
242247
}

0 commit comments

Comments
 (0)