Skip to content

Commit f527fdf

Browse files
committed
[mlir][sparse] Code cleanup for SparseTensorConversion
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D115004
1 parent 0e0f1b2 commit f527fdf

File tree

1 file changed

+58
-55
lines changed

1 file changed

+58
-55
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 58 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,19 @@ constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
142142
return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
143143
}
144144

145+
/// Returns the equivalent of `void*` for opaque arguments to the
146+
/// execution engine.
147+
static Type getOpaquePointerType(PatternRewriter &rewriter) {
148+
return LLVM::LLVMPointerType::get(rewriter.getI8Type());
149+
}
150+
145151
/// Returns a function reference (first hit also inserts into module). Sets
146152
/// the "_emit_c_interface" on the function declaration when requested,
147153
/// so that LLVM lowering generates a wrapper function that takes care
148154
/// of ABI complications with passing in and returning MemRefs to C functions.
149155
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
150156
TypeRange resultType, ValueRange operands,
151-
bool emitCInterface = false) {
157+
bool emitCInterface) {
152158
MLIRContext *context = op->getContext();
153159
auto module = op->getParentOfType<ModuleOp>();
154160
auto result = SymbolRefAttr::get(context, name);
@@ -165,6 +171,24 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
165171
return result;
166172
}
167173

174+
/// Creates a `CallOp` to the function reference returned by `getFunc()`.
175+
static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
176+
TypeRange resultType, ValueRange operands,
177+
bool emitCInterface = false) {
178+
auto fn = getFunc(op, name, resultType, operands, emitCInterface);
179+
return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
180+
}
181+
182+
/// Replaces the `op` with a `CallOp` to the function reference returned
183+
/// by `getFunc()`.
184+
static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
185+
StringRef name, TypeRange resultType,
186+
ValueRange operands,
187+
bool emitCInterface = false) {
188+
auto fn = getFunc(op, name, resultType, operands, emitCInterface);
189+
return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
190+
}
191+
168192
/// Generates dimension size call.
169193
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
170194
SparseTensorEncodingAttr &enc, Value src,
@@ -173,25 +197,20 @@ static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
173197
if (AffineMap p = enc.getDimOrdering())
174198
idx = p.getPermutedPosition(idx);
175199
// Generate the call.
176-
Location loc = op->getLoc();
177200
StringRef name = "sparseDimSize";
178-
SmallVector<Value, 2> params;
179-
params.push_back(src);
180-
params.push_back(constantIndex(rewriter, loc, idx));
201+
SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
181202
Type iTp = rewriter.getIndexType();
182-
auto fn = getFunc(op, name, iTp, params);
183-
return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
203+
return createFuncCall(rewriter, op, name, iTp, params).getResult(0);
184204
}
185205

186206
/// Generates a call into the "swiss army knife" method of the sparse runtime
187207
/// support library for materializing sparse tensors into the computation.
188208
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
189209
ArrayRef<Value> params) {
190-
Location loc = op->getLoc();
191210
StringRef name = "newSparseTensor";
192-
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
193-
auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
194-
auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
211+
Type pTp = getOpaquePointerType(rewriter);
212+
auto call = createFuncCall(rewriter, op, name, pTp, params,
213+
/*emitCInterface=*/true);
195214
return call.getResult(0);
196215
}
197216

@@ -210,8 +229,8 @@ static void sizesFromType(ConversionPatternRewriter &rewriter,
210229
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
211230
SmallVector<Value, 4> &sizes, Location loc,
212231
Value src) {
213-
ShapedType stp = src.getType().cast<ShapedType>();
214-
for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
232+
unsigned rank = src.getType().cast<ShapedType>().getRank();
233+
for (unsigned i = 0; i < rank; i++)
215234
sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
216235
}
217236

@@ -221,12 +240,13 @@ static void sizesFromPtr(ConversionPatternRewriter &rewriter,
221240
SmallVector<Value, 4> &sizes, Operation *op,
222241
SparseTensorEncodingAttr &enc, ShapedType stp,
223242
Value src) {
243+
Location loc = op->getLoc();
224244
auto shape = stp.getShape();
225245
for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
226246
if (shape[i] == ShapedType::kDynamicSize)
227247
sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
228248
else
229-
sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
249+
sizes.push_back(constantIndex(rewriter, loc, shape[i]));
230250
}
231251

232252
/// Generates an uninitialized temporary buffer of the given size and
@@ -293,16 +313,15 @@ static void newParams(ConversionPatternRewriter &rewriter,
293313
}
294314
params.push_back(genBuffer(rewriter, loc, rev));
295315
// Secondary and primary types encoding.
296-
ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
316+
Type elemTp = op->getResult(0).getType().cast<ShapedType>().getElementType();
297317
params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
298318
params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
299-
params.push_back(
300-
constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
301-
// User action and pointer.
302-
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
303-
if (!ptr)
304-
ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
319+
params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
320+
// User action.
305321
params.push_back(constantAction(rewriter, loc, action));
322+
// Payload pointer.
323+
if (!ptr)
324+
ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter));
306325
params.push_back(ptr);
307326
}
308327

@@ -352,7 +371,6 @@ static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
352371
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
353372
Type eltType, Value ptr, Value val, Value ind,
354373
Value perm) {
355-
Location loc = op->getLoc();
356374
StringRef name;
357375
if (eltType.isF64())
358376
name = "addEltF64";
@@ -368,14 +386,9 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
368386
name = "addEltI8";
369387
else
370388
llvm_unreachable("Unknown element type");
371-
SmallVector<Value, 8> params;
372-
params.push_back(ptr);
373-
params.push_back(val);
374-
params.push_back(ind);
375-
params.push_back(perm);
376-
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
377-
auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
378-
rewriter.create<CallOp>(loc, pTp, fn, params);
389+
SmallVector<Value, 4> params{ptr, val, ind, perm};
390+
Type pTp = getOpaquePointerType(rewriter);
391+
createFuncCall(rewriter, op, name, pTp, params, /*emitCInterface=*/true);
379392
}
380393

381394
/// Generates a call to `iter->getNext()`. If there is a next element,
@@ -384,7 +397,6 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
384397
/// the memory for `iter` is freed and the return value is false.
385398
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
386399
Value iter, Value ind, Value elemPtr) {
387-
Location loc = op->getLoc();
388400
Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
389401
StringRef name;
390402
if (elemTp.isF64())
@@ -401,13 +413,10 @@ static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
401413
name = "getNextI8";
402414
else
403415
llvm_unreachable("Unknown element type");
404-
SmallVector<Value, 3> params;
405-
params.push_back(iter);
406-
params.push_back(ind);
407-
params.push_back(elemPtr);
416+
SmallVector<Value, 3> params{iter, ind, elemPtr};
408417
Type i1 = rewriter.getI1Type();
409-
auto fn = getFunc(op, name, i1, params, /*emitCInterface=*/true);
410-
auto call = rewriter.create<CallOp>(loc, i1, fn, params);
418+
auto call = createFuncCall(rewriter, op, name, i1, params,
419+
/*emitCInterface=*/true);
411420
return call.getResult(0);
412421
}
413422

@@ -461,7 +470,7 @@ static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc,
461470
}
462471
Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
463472
Value zero = constantZero(rewriter, loc, elemTp);
464-
rewriter.create<linalg::FillOp>(loc, zero, mem).result();
473+
rewriter.create<linalg::FillOp>(loc, zero, mem);
465474
return mem;
466475
}
467476

@@ -754,9 +763,8 @@ class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
754763
matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
755764
ConversionPatternRewriter &rewriter) const override {
756765
StringRef name = "delSparseTensor";
757-
TypeRange none;
758-
auto fn = getFunc(op, name, none, adaptor.getOperands());
759-
rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
766+
TypeRange noTp;
767+
createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
760768
rewriter.eraseOp(op);
761769
return success();
762770
}
@@ -785,9 +793,8 @@ class SparseTensorToPointersConverter
785793
name = "sparsePointers8";
786794
else
787795
return failure();
788-
auto fn = getFunc(op, name, resType, adaptor.getOperands(),
789-
/*emitCInterface=*/true);
790-
rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
796+
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
797+
/*emitCInterface=*/true);
791798
return success();
792799
}
793800
};
@@ -814,9 +821,8 @@ class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
814821
name = "sparseIndices8";
815822
else
816823
return failure();
817-
auto fn = getFunc(op, name, resType, adaptor.getOperands(),
818-
/*emitCInterface=*/true);
819-
rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
824+
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
825+
/*emitCInterface=*/true);
820826
return success();
821827
}
822828
};
@@ -845,9 +851,8 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
845851
name = "sparseValuesI8";
846852
else
847853
return failure();
848-
auto fn = getFunc(op, name, resType, adaptor.getOperands(),
849-
/*emitCInterface=*/true);
850-
rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
854+
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
855+
/*emitCInterface=*/true);
851856
return success();
852857
}
853858
};
@@ -863,8 +868,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
863868
// Finalize any pending insertions.
864869
StringRef name = "endInsert";
865870
TypeRange noTp;
866-
auto fn = getFunc(op, name, noTp, adaptor.getOperands());
867-
rewriter.create<CallOp>(op.getLoc(), noTp, fn, adaptor.getOperands());
871+
createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
868872
}
869873
rewriter.replaceOp(op, adaptor.getOperands());
870874
return success();
@@ -896,9 +900,8 @@ class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
896900
else
897901
llvm_unreachable("Unknown element type");
898902
TypeRange noTp;
899-
auto fn =
900-
getFunc(op, name, noTp, adaptor.getOperands(), /*emitCInterface=*/true);
901-
rewriter.replaceOpWithNewOp<CallOp>(op, noTp, fn, adaptor.getOperands());
903+
replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
904+
/*emitCInterface=*/true);
902905
return success();
903906
}
904907
};

0 commit comments

Comments
 (0)