@@ -210,12 +210,16 @@ struct FatRawBufferCastLowering
210
210
else
211
211
offset = descriptor.offset (rewriter, loc);
212
212
213
+ bool hasSizes = memrefType.getRank () > 0 ;
213
214
// No need to unpack() and pack() all the individual sizes and strides,
214
215
// so we'll just extract the arrays.
215
- Value sizes = rewriter.create <LLVM::ExtractValueOp>(
216
- loc, descriptor, kSizePosInMemRefDescriptor );
217
- Value strides = rewriter.create <LLVM::ExtractValueOp>(
218
- loc, descriptor, kStridePosInMemRefDescriptor );
216
+ Value sizes = hasSizes ? rewriter.create <LLVM::ExtractValueOp>(
217
+ loc, descriptor, kSizePosInMemRefDescriptor )
218
+ : Value{};
219
+ Value strides = hasSizes
220
+ ? rewriter.create <LLVM::ExtractValueOp>(
221
+ loc, descriptor, kStridePosInMemRefDescriptor )
222
+ : Value{};
219
223
220
224
Value rsrc = makeBufferRsrc (rewriter, loc, basePointer, numRecords,
221
225
adaptor.getBoundsCheck (), chipset,
@@ -232,11 +236,12 @@ struct FatRawBufferCastLowering
232
236
loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor );
233
237
result = rewriter.create <LLVM::InsertValueOp>(loc, result, offset,
234
238
kOffsetPosInMemRefDescriptor );
235
- result = rewriter.create <LLVM::InsertValueOp>(loc, result, sizes,
236
- kSizePosInMemRefDescriptor );
237
- result = rewriter.create <LLVM::InsertValueOp>(loc, result, strides,
238
- kStridePosInMemRefDescriptor );
239
-
239
+ if (hasSizes) {
240
+ result = rewriter.create <LLVM::InsertValueOp>(loc, result, sizes,
241
+ kSizePosInMemRefDescriptor );
242
+ result = rewriter.create <LLVM::InsertValueOp>(
243
+ loc, result, strides, kStridePosInMemRefDescriptor );
244
+ }
240
245
rewriter.replaceOp (op, result);
241
246
return success ();
242
247
}
0 commit comments