@@ -142,13 +142,19 @@ constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
142142 return constantI8 (rewriter, loc, static_cast <uint8_t >(dlt2));
143143}
144144
145+ // / Returns the equivalent of `void*` for opaque arguments to the
146+ // / execution engine.
147+ static Type getOpaquePointerType (PatternRewriter &rewriter) {
148+ return LLVM::LLVMPointerType::get (rewriter.getI8Type ());
149+ }
150+
145151// / Returns a function reference (first hit also inserts into module). Sets
146152// / the "_emit_c_interface" on the function declaration when requested,
147153// / so that LLVM lowering generates a wrapper function that takes care
148154// / of ABI complications with passing in and returning MemRefs to C functions.
149155static FlatSymbolRefAttr getFunc (Operation *op, StringRef name,
150156 TypeRange resultType, ValueRange operands,
151- bool emitCInterface = false ) {
157+ bool emitCInterface) {
152158 MLIRContext *context = op->getContext ();
153159 auto module = op->getParentOfType <ModuleOp>();
154160 auto result = SymbolRefAttr::get (context, name);
@@ -165,6 +171,24 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
165171 return result;
166172}
167173
174+ // / Creates a `CallOp` to the function reference returned by `getFunc()`.
175+ static CallOp createFuncCall (OpBuilder &builder, Operation *op, StringRef name,
176+ TypeRange resultType, ValueRange operands,
177+ bool emitCInterface = false ) {
178+ auto fn = getFunc (op, name, resultType, operands, emitCInterface);
179+ return builder.create <CallOp>(op->getLoc (), resultType, fn, operands);
180+ }
181+
182+ // / Replaces the `op` with a `CallOp` to the function reference returned
183+ // / by `getFunc()`.
184+ static CallOp replaceOpWithFuncCall (PatternRewriter &rewriter, Operation *op,
185+ StringRef name, TypeRange resultType,
186+ ValueRange operands,
187+ bool emitCInterface = false ) {
188+ auto fn = getFunc (op, name, resultType, operands, emitCInterface);
189+ return rewriter.replaceOpWithNewOp <CallOp>(op, resultType, fn, operands);
190+ }
191+
168192// / Generates dimension size call.
169193static Value genDimSizeCall (ConversionPatternRewriter &rewriter, Operation *op,
170194 SparseTensorEncodingAttr &enc, Value src,
@@ -173,25 +197,20 @@ static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
173197 if (AffineMap p = enc.getDimOrdering ())
174198 idx = p.getPermutedPosition (idx);
175199 // Generate the call.
176- Location loc = op->getLoc ();
177200 StringRef name = " sparseDimSize" ;
178- SmallVector<Value, 2 > params;
179- params.push_back (src);
180- params.push_back (constantIndex (rewriter, loc, idx));
201+ SmallVector<Value, 2 > params{src, constantIndex (rewriter, op->getLoc (), idx)};
181202 Type iTp = rewriter.getIndexType ();
182- auto fn = getFunc (op, name, iTp, params);
183- return rewriter.create <CallOp>(loc, iTp, fn, params).getResult (0 );
203+ return createFuncCall (rewriter, op, name, iTp, params).getResult (0 );
184204}
185205
186206// / Generates a call into the "swiss army knife" method of the sparse runtime
187207// / support library for materializing sparse tensors into the computation.
188208static Value genNewCall (ConversionPatternRewriter &rewriter, Operation *op,
189209 ArrayRef<Value> params) {
190- Location loc = op->getLoc ();
191210 StringRef name = " newSparseTensor" ;
192- Type pTp = LLVM::LLVMPointerType::get (rewriter. getI8Type () );
193- auto fn = getFunc ( op, name, pTp, params, /* emitCInterface= */ true );
194- auto call = rewriter. create <CallOp>(loc, pTp, fn, params );
211+ Type pTp = getOpaquePointerType (rewriter);
212+ auto call = createFuncCall (rewriter, op, name, pTp, params,
213+ /* emitCInterface= */ true );
195214 return call.getResult (0 );
196215}
197216
@@ -210,8 +229,8 @@ static void sizesFromType(ConversionPatternRewriter &rewriter,
210229static void sizesFromSrc (ConversionPatternRewriter &rewriter,
211230 SmallVector<Value, 4 > &sizes, Location loc,
212231 Value src) {
213- ShapedType stp = src.getType ().cast <ShapedType>();
214- for (unsigned i = 0 , rank = stp. getRank () ; i < rank; i++)
232+ unsigned rank = src.getType ().cast <ShapedType>(). getRank ();
233+ for (unsigned i = 0 ; i < rank; i++)
215234 sizes.push_back (linalg::createOrFoldDimOp (rewriter, loc, src, i));
216235}
217236
@@ -221,12 +240,13 @@ static void sizesFromPtr(ConversionPatternRewriter &rewriter,
221240 SmallVector<Value, 4 > &sizes, Operation *op,
222241 SparseTensorEncodingAttr &enc, ShapedType stp,
223242 Value src) {
243+ Location loc = op->getLoc ();
224244 auto shape = stp.getShape ();
225245 for (unsigned i = 0 , rank = stp.getRank (); i < rank; i++)
226246 if (shape[i] == ShapedType::kDynamicSize )
227247 sizes.push_back (genDimSizeCall (rewriter, op, enc, src, i));
228248 else
229- sizes.push_back (constantIndex (rewriter, op-> getLoc () , shape[i]));
249+ sizes.push_back (constantIndex (rewriter, loc , shape[i]));
230250}
231251
232252// / Generates an uninitialized temporary buffer of the given size and
@@ -293,16 +313,15 @@ static void newParams(ConversionPatternRewriter &rewriter,
293313 }
294314 params.push_back (genBuffer (rewriter, loc, rev));
295315 // Secondary and primary types encoding.
296- ShapedType resType = op->getResult (0 ).getType ().cast <ShapedType>();
316+ Type elemTp = op->getResult (0 ).getType ().cast <ShapedType>(). getElementType ();
297317 params.push_back (constantPointerTypeEncoding (rewriter, loc, enc));
298318 params.push_back (constantIndexTypeEncoding (rewriter, loc, enc));
299- params.push_back (
300- constantPrimaryTypeEncoding (rewriter, loc, resType.getElementType ()));
301- // User action and pointer.
302- Type pTp = LLVM::LLVMPointerType::get (rewriter.getI8Type ());
303- if (!ptr)
304- ptr = rewriter.create <LLVM::NullOp>(loc, pTp);
319+ params.push_back (constantPrimaryTypeEncoding (rewriter, loc, elemTp));
320+ // User action.
305321 params.push_back (constantAction (rewriter, loc, action));
322+ // Payload pointer.
323+ if (!ptr)
324+ ptr = rewriter.create <LLVM::NullOp>(loc, getOpaquePointerType (rewriter));
306325 params.push_back (ptr);
307326}
308327
@@ -352,7 +371,6 @@ static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
352371static void genAddEltCall (ConversionPatternRewriter &rewriter, Operation *op,
353372 Type eltType, Value ptr, Value val, Value ind,
354373 Value perm) {
355- Location loc = op->getLoc ();
356374 StringRef name;
357375 if (eltType.isF64 ())
358376 name = " addEltF64" ;
@@ -368,14 +386,9 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
368386 name = " addEltI8" ;
369387 else
370388 llvm_unreachable (" Unknown element type" );
371- SmallVector<Value, 8 > params;
372- params.push_back (ptr);
373- params.push_back (val);
374- params.push_back (ind);
375- params.push_back (perm);
376- Type pTp = LLVM::LLVMPointerType::get (rewriter.getI8Type ());
377- auto fn = getFunc (op, name, pTp, params, /* emitCInterface=*/ true );
378- rewriter.create <CallOp>(loc, pTp, fn, params);
389+ SmallVector<Value, 4 > params{ptr, val, ind, perm};
390+ Type pTp = getOpaquePointerType (rewriter);
391+ createFuncCall (rewriter, op, name, pTp, params, /* emitCInterface=*/ true );
379392}
380393
381394// / Generates a call to `iter->getNext()`. If there is a next element,
@@ -384,7 +397,6 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
384397// / the memory for `iter` is freed and the return value is false.
385398static Value genGetNextCall (ConversionPatternRewriter &rewriter, Operation *op,
386399 Value iter, Value ind, Value elemPtr) {
387- Location loc = op->getLoc ();
388400 Type elemTp = elemPtr.getType ().cast <ShapedType>().getElementType ();
389401 StringRef name;
390402 if (elemTp.isF64 ())
@@ -401,13 +413,10 @@ static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
401413 name = " getNextI8" ;
402414 else
403415 llvm_unreachable (" Unknown element type" );
404- SmallVector<Value, 3 > params;
405- params.push_back (iter);
406- params.push_back (ind);
407- params.push_back (elemPtr);
416+ SmallVector<Value, 3 > params{iter, ind, elemPtr};
408417 Type i1 = rewriter.getI1Type ();
409- auto fn = getFunc ( op, name, i1, params, /* emitCInterface= */ true );
410- auto call = rewriter. create <CallOp>(loc, i1, fn, params );
418+ auto call = createFuncCall (rewriter, op, name, i1, params,
419+ /* emitCInterface= */ true );
411420 return call.getResult (0 );
412421}
413422
@@ -461,7 +470,7 @@ static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc,
461470 }
462471 Value mem = rewriter.create <memref::AllocOp>(loc, memTp, dynamicSizes);
463472 Value zero = constantZero (rewriter, loc, elemTp);
464- rewriter.create <linalg::FillOp>(loc, zero, mem). result () ;
473+ rewriter.create <linalg::FillOp>(loc, zero, mem);
465474 return mem;
466475}
467476
@@ -754,9 +763,8 @@ class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
754763 matchAndRewrite (ReleaseOp op, OpAdaptor adaptor,
755764 ConversionPatternRewriter &rewriter) const override {
756765 StringRef name = " delSparseTensor" ;
757- TypeRange none;
758- auto fn = getFunc (op, name, none, adaptor.getOperands ());
759- rewriter.create <CallOp>(op.getLoc (), none, fn, adaptor.getOperands ());
766+ TypeRange noTp;
767+ createFuncCall (rewriter, op, name, noTp, adaptor.getOperands ());
760768 rewriter.eraseOp (op);
761769 return success ();
762770 }
@@ -785,9 +793,8 @@ class SparseTensorToPointersConverter
785793 name = " sparsePointers8" ;
786794 else
787795 return failure ();
788- auto fn = getFunc (op, name, resType, adaptor.getOperands (),
789- /* emitCInterface=*/ true );
790- rewriter.replaceOpWithNewOp <CallOp>(op, resType, fn, adaptor.getOperands ());
796+ replaceOpWithFuncCall (rewriter, op, name, resType, adaptor.getOperands (),
797+ /* emitCInterface=*/ true );
791798 return success ();
792799 }
793800};
@@ -814,9 +821,8 @@ class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
814821 name = " sparseIndices8" ;
815822 else
816823 return failure ();
817- auto fn = getFunc (op, name, resType, adaptor.getOperands (),
818- /* emitCInterface=*/ true );
819- rewriter.replaceOpWithNewOp <CallOp>(op, resType, fn, adaptor.getOperands ());
824+ replaceOpWithFuncCall (rewriter, op, name, resType, adaptor.getOperands (),
825+ /* emitCInterface=*/ true );
820826 return success ();
821827 }
822828};
@@ -845,9 +851,8 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
845851 name = " sparseValuesI8" ;
846852 else
847853 return failure ();
848- auto fn = getFunc (op, name, resType, adaptor.getOperands (),
849- /* emitCInterface=*/ true );
850- rewriter.replaceOpWithNewOp <CallOp>(op, resType, fn, adaptor.getOperands ());
854+ replaceOpWithFuncCall (rewriter, op, name, resType, adaptor.getOperands (),
855+ /* emitCInterface=*/ true );
851856 return success ();
852857 }
853858};
@@ -863,8 +868,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
863868 // Finalize any pending insertions.
864869 StringRef name = " endInsert" ;
865870 TypeRange noTp;
866- auto fn = getFunc (op, name, noTp, adaptor.getOperands ());
867- rewriter.create <CallOp>(op.getLoc (), noTp, fn, adaptor.getOperands ());
871+ createFuncCall (rewriter, op, name, noTp, adaptor.getOperands ());
868872 }
869873 rewriter.replaceOp (op, adaptor.getOperands ());
870874 return success ();
@@ -896,9 +900,8 @@ class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
896900 else
897901 llvm_unreachable (" Unknown element type" );
898902 TypeRange noTp;
899- auto fn =
900- getFunc (op, name, noTp, adaptor.getOperands (), /* emitCInterface=*/ true );
901- rewriter.replaceOpWithNewOp <CallOp>(op, noTp, fn, adaptor.getOperands ());
903+ replaceOpWithFuncCall (rewriter, op, name, noTp, adaptor.getOperands (),
904+ /* emitCInterface=*/ true );
902905 return success ();
903906 }
904907};
0 commit comments