1818#include " CodegenUtils.h"
1919#include " SparseTensorDescriptor.h"
2020
21- #include " llvm/Support/FormatVariadic.h"
22-
2321#include " mlir/Dialect/Arith/Utils/Utils.h"
2422#include " mlir/Dialect/Bufferization/IR/Bufferization.h"
2523#include " mlir/Dialect/Func/IR/FuncOps.h"
@@ -116,31 +114,36 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
116114 const SparseTensorType stt (desc.getRankedTensorType ());
117115 Value linear = constantIndex (builder, loc, 1 );
118116 const Level lvlRank = stt.getLvlRank ();
119- for (Level l = startLvl; l < lvlRank; l ++) {
120- const auto dlt = stt.getLvlType (l );
121- if (isCompressedDLT (dlt)) {
117+ for (Level lvl = startLvl; lvl < lvlRank; lvl ++) {
118+ const auto dlt = stt.getLvlType (lvl );
119+ if (isCompressedDLT (dlt) || isLooseCompressedDLT (dlt) ) {
122120 // Append linear x positions, initialized to zero. Since each compressed
123121 // dimension initially already has a single zero entry, this maintains
124- // the desired "linear + 1" length property at all times.
122+ // the desired "linear + 1" length property at all times. For loose
123+ // compression, we multiply linear by two in order to append both the
124+ // lo/hi positions.
125125 Value posZero = constantZero (builder, loc, stt.getPosType ());
126- createPushback (builder, loc, desc, SparseTensorFieldKind::PosMemRef, l,
127- posZero, linear);
126+ if (isLooseCompressedDLT (dlt)) {
127+ Value two = constantIndex (builder, loc, 2 );
128+ linear = builder.create <arith::MulIOp>(loc, linear, two);
129+ }
130+ createPushback (builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
131+ /* value=*/ posZero, /* repeat=*/ linear);
128132 return ;
129- }
130- if (isSingletonDLT (dlt)) {
133+ } else if (isSingletonDLT (dlt) || is2OutOf4DLT (dlt)) {
131134 return ; // nothing to do
132135 }
133136 // Keep compounding the size, but nothing needs to be initialized
134137 // at this level. We will eventually reach a compressed level or
135138 // otherwise the values array for the from-here "all-dense" case.
136139 assert (isDenseDLT (dlt));
137- Value size = desc.getLvlSize (builder, loc, l );
140+ Value size = desc.getLvlSize (builder, loc, lvl );
138141 linear = builder.create <arith::MulIOp>(loc, linear, size);
139142 }
140143 // Reached values array so prepare for an insertion.
141144 Value valZero = constantZero (builder, loc, stt.getElementType ());
142145 createPushback (builder, loc, desc, SparseTensorFieldKind::ValMemRef,
143- std::nullopt , valZero, linear);
146+ std::nullopt , /* value= */ valZero, /* repeat= */ linear);
144147}
145148
146149// / Creates allocation operation.
@@ -157,12 +160,9 @@ static Value createAllocation(OpBuilder &builder, Location loc,
157160}
158161
159162// / Creates allocation for each field in sparse tensor type. Note that
160- // / for all dynamic memrefs, the memory size is really the capacity of
161- // / the "vector", while the actual size resides in the sizes array.
162- // /
163- // / TODO: for efficiency, we will need heuristics to make educated guesses
164- // / on the required capacities (see heuristic variable).
165- // /
163+ // / for all dynamic memrefs in the sparse tensor stroage layout, the
164+ // / memory size is really the capacity of the "vector", while the actual
165+ // / size resides in the sizes array.
166166static void createAllocFields (OpBuilder &builder, Location loc,
167167 SparseTensorType stt, ValueRange dynSizes,
168168 bool enableInit, SmallVectorImpl<Value> &fields,
@@ -206,6 +206,8 @@ static void createAllocFields(OpBuilder &builder, Location loc,
206206 constantIndex (builder, loc, 16 );
207207 }
208208
209+ // Initializes all fields. An initial storage specifier and allocated
210+ // positions/coordinates/values memrefs (with heuristic capacity).
209211 foreachFieldAndTypeInSparseTensor (
210212 stt,
211213 [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
@@ -218,14 +220,16 @@ static void createAllocFields(OpBuilder &builder, Location loc,
218220 field = SparseTensorSpecifier::getInitValue (builder, loc, stt);
219221 break ;
220222 case SparseTensorFieldKind::PosMemRef:
223+ field = createAllocation (builder, loc, cast<MemRefType>(fType ),
224+ posHeuristic, enableInit);
225+ break ;
221226 case SparseTensorFieldKind::CrdMemRef:
227+ field = createAllocation (builder, loc, cast<MemRefType>(fType ),
228+ crdHeuristic, enableInit);
229+ break ;
222230 case SparseTensorFieldKind::ValMemRef:
223- field = createAllocation (
224- builder, loc, cast<MemRefType>(fType ),
225- (fKind == SparseTensorFieldKind::PosMemRef) ? posHeuristic
226- : (fKind == SparseTensorFieldKind::CrdMemRef) ? crdHeuristic
227- : valHeuristic,
228- enableInit);
231+ field = createAllocation (builder, loc, cast<MemRefType>(fType ),
232+ valHeuristic, enableInit);
229233 break ;
230234 }
231235 assert (field);
@@ -234,21 +238,19 @@ static void createAllocFields(OpBuilder &builder, Location loc,
234238 return true ;
235239 });
236240
241+ // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
242+ // and gives all position fields an initial zero entry, so that it is
243+ // easier to maintain the "linear + 1" length property.
237244 MutSparseTensorDescriptor desc (stt, fields);
238-
239- // Initialize the storage scheme to an empty tensor. Initialized memSizes
240- // to all zeros, sets the dimSizes to known values and gives all position
241- // fields an initial zero entry, so that it is easier to maintain the
242- // "linear + 1" length property.
243245 Value posZero = constantZero (builder, loc, stt.getPosType ());
244- for (Level lvlRank = stt.getLvlRank (), l = 0 ; l < lvlRank; l++) {
245- // Fills dim sizes array.
246+ for (Level lvl = 0 , lvlRank = stt.getLvlRank (); lvl < lvlRank; lvl++) {
246247 // FIXME: `toOrigDim` is deprecated.
247- desc.setLvlSize (builder, loc, l, dimSizes[toOrigDim (stt.getEncoding (), l)]);
248- // Pushes a leading zero to positions memref.
249- if (stt.isCompressedLvl (l))
250- createPushback (builder, loc, desc, SparseTensorFieldKind::PosMemRef, l,
251- posZero);
248+ desc.setLvlSize (builder, loc, lvl,
249+ dimSizes[toOrigDim (stt.getEncoding (), lvl)]);
250+ const auto dlt = stt.getLvlType (lvl);
251+ if (isCompressedDLT (dlt) || isLooseCompressedDLT (dlt))
252+ createPushback (builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
253+ /* value=*/ posZero);
252254 }
253255 allocSchemeForRank (builder, loc, desc, /* rank=*/ 0 );
254256}
@@ -347,7 +349,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
347349 Value mszp1 = builder.create <arith::AddIOp>(loc, msz, one);
348350 genStore (builder, loc, mszp1, positionsAtLvl, pp1);
349351 createPushback (builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
350- lvlCoords[lvl]);
352+ /* value= */ lvlCoords[lvl]);
351353 // Prepare the next level "as needed".
352354 if ((lvl + 1 ) < lvlRank)
353355 allocSchemeForRank (builder, loc, desc, lvl + 1 );
@@ -371,8 +373,6 @@ static void genEndInsert(OpBuilder &builder, Location loc,
371373 const Level lvlRank = stt.getLvlRank ();
372374 for (Level l = 0 ; l < lvlRank; l++) {
373375 const auto dlt = stt.getLvlType (l);
374- if (isLooseCompressedDLT (dlt))
375- llvm_unreachable (" TODO: Not yet implemented" );
376376 if (isCompressedDLT (dlt)) {
377377 // Compressed dimensions need a position cleanup for all entries
378378 // that were not visited during the insertion pass.
@@ -407,7 +407,8 @@ static void genEndInsert(OpBuilder &builder, Location loc,
407407 builder.setInsertionPointAfter (loop);
408408 }
409409 } else {
410- assert (isDenseDLT (dlt) || isSingletonDLT (dlt));
410+ assert (isDenseDLT (dlt) || isLooseCompressedDLT (dlt) ||
411+ isSingletonDLT (dlt) || is2OutOf4DLT (dlt));
411412 }
412413 }
413414}
@@ -483,33 +484,37 @@ class SparseInsertGenerator
483484 Value value = args.back ();
484485 Value parentPos = constantZero (builder, loc, builder.getIndexType ());
485486 // Generate code for every level.
486- for (Level l = 0 ; l < lvlRank; l ++) {
487- const auto dlt = stt.getLvlType (l );
488- if (isCompressedDLT (dlt)) {
487+ for (Level lvl = 0 ; lvl < lvlRank; lvl ++) {
488+ const auto dlt = stt.getLvlType (lvl );
489+ if (isCompressedDLT (dlt) || isLooseCompressedDLT (dlt) ) {
489490 // Create:
490491 // if (!present) {
491- // coordinates[l ].push_back(coords[l ])
492- // <update positions and prepare level l + 1>
492+ // coordinates[lvl ].push_back(coords[lvl ])
493+ // <update positions and prepare level lvl + 1>
493494 // }
494- // positions[l] = coordinates.size() - 1
495- // <insert @ positions[l] at next level l + 1>
495+ // positions[lvl] = coordinates.size() - 1
496+ // <insert @ positions[lvl] at next level lvl + 1>
497+ if (isLooseCompressedDLT (dlt)) {
498+ Value two = constantIndex (builder, loc, 2 );
499+ parentPos = builder.create <arith::MulIOp>(loc, parentPos, two);
500+ }
496501 parentPos =
497- genCompressed (builder, loc, desc, coords, value, parentPos, l );
498- } else if (isSingletonDLT (dlt)) {
502+ genCompressed (builder, loc, desc, coords, value, parentPos, lvl );
503+ } else if (isSingletonDLT (dlt) || is2OutOf4DLT (dlt) ) {
499504 // Create:
500- // coordinates[l ].push_back(coords[l ])
501- // positions[l ] = positions[l -1]
502- // <insert @ positions[l ] at next level l + 1>
503- createPushback (builder, loc, desc, SparseTensorFieldKind::CrdMemRef, l,
504- coords[l ]);
505+ // coordinates[lvl ].push_back(coords[lvl ])
506+ // positions[lvl ] = positions[lvl -1]
507+ // <insert @ positions[lvl ] at next level lvl + 1>
508+ createPushback (builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
509+ lvl, /* value= */ coords[lvl ]);
505510 } else {
506511 assert (isDenseDLT (dlt));
507512 // Construct the new position as:
508- // positions[l ] = size * positions[l -1] + coords[l ]
509- // <insert @ positions[l ] at next level l + 1>
510- Value size = desc.getLvlSize (builder, loc, l );
513+ // positions[lvl ] = size * positions[lvl -1] + coords[lvl ]
514+ // <insert @ positions[lvl ] at next level lvl + 1>
515+ Value size = desc.getLvlSize (builder, loc, lvl );
511516 Value mult = builder.create <arith::MulIOp>(loc, size, parentPos);
512- parentPos = builder.create <arith::AddIOp>(loc, mult, coords[l ]);
517+ parentPos = builder.create <arith::AddIOp>(loc, mult, coords[lvl ]);
513518 }
514519 }
515520 // Reached the actual value append/insert.
@@ -526,7 +531,6 @@ class SparseInsertGenerator
526531 // <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
527532 constexpr const char kInsertFuncNamePrefix [] = " _insert_" ;
528533 const SparseTensorType stt (llvm::cast<RankedTensorType>(rtp));
529-
530534 SmallString<32 > nameBuffer;
531535 llvm::raw_svector_ostream nameOstream (nameBuffer);
532536 nameOstream << kInsertFuncNamePrefix ;
@@ -543,8 +547,8 @@ class SparseInsertGenerator
543547 // Static dim sizes are used in the generated code while dynamic sizes are
544548 // loaded from the dimSizes buffer. This is the reason for adding the shape
545549 // to the function name.
546- for (const auto sh : stt.getDimShape ())
547- nameOstream << sh << " _" ;
550+ for (const auto sz : stt.getDimShape ())
551+ nameOstream << sz << " _" ;
548552 // Permutation information is also used in generating insertion.
549553 if (!stt.isIdentity ())
550554 nameOstream << stt.getDimToLvl () << " _" ;
@@ -607,7 +611,6 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
607611 assert (retOffset < newCall.getNumResults ());
608612 auto retType = ret.getType ();
609613 if (failed (typeConverter->convertType (retType, sparseFlat)))
610- // This should never happen.
611614 llvm_unreachable (" Failed to convert type in sparse tensor codegen" );
612615
613616 // Converted types can not be empty when the type conversion succeed.
@@ -755,9 +758,7 @@ class SparseTensorAllocConverter
755758 const auto resType = getSparseTensorType (op);
756759 if (!resType.hasEncoding ())
757760 return failure ();
758-
759- // Construct allocation for each field.
760- const Location loc = op.getLoc ();
761+ Location loc = op.getLoc ();
761762 if (op.getCopy ()) {
762763 auto desc = getDescriptorFromTensorTuple (adaptor.getCopy ());
763764 SmallVector<Value> fields;
@@ -778,18 +779,18 @@ class SparseTensorAllocConverter
778779 return success ();
779780 }
780781
781- const Value sizeHint = op.getSizeHint ();
782- const ValueRange dynSizes = adaptor.getDynamicSizes ();
782+ // Construct allocation for each field.
783+ Value sizeHint = op.getSizeHint ();
784+ ValueRange dynSizes = adaptor.getDynamicSizes ();
783785 const size_t found = dynSizes.size ();
784786 const int64_t expected = resType.getNumDynamicDims ();
785787 if (found != static_cast <size_t >(expected))
786- return rewriter.notifyMatchFailure (
787- op, llvm::formatv (
788- " Got wrong number of dynamic sizes: Found={0}, Expected={1}" ,
789- found, expected));
788+ return rewriter.notifyMatchFailure (op,
789+ " Got wrong number of dynamic sizes" );
790790 SmallVector<Value> fields;
791791 createAllocFields (rewriter, loc, resType, dynSizes,
792792 enableBufferInitialization, fields, sizeHint);
793+
793794 // Replace operation with resulting memrefs.
794795 rewriter.replaceOp (op, genTuple (rewriter, loc, resType, fields));
795796 return success ();
@@ -817,19 +818,18 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
817818 return failure ();
818819
819820 // Construct allocation for each field.
820- const Location loc = op.getLoc ();
821- const Value sizeHint; // none
821+ Location loc = op.getLoc ();
822+ Value sizeHint; // none
822823 const ValueRange dynSizes = adaptor.getDynamicSizes ();
823824 const size_t found = dynSizes.size ();
824825 const int64_t expected = resType.getNumDynamicDims ();
825826 if (found != static_cast <size_t >(expected))
826- return rewriter.notifyMatchFailure (
827- op, llvm::formatv (
828- " Got wrong number of dynamic sizes: Found={0}, Expected={1}" ,
829- found, expected));
827+ return rewriter.notifyMatchFailure (op,
828+ " Got wrong number of dynamic sizes" );
830829 SmallVector<Value> fields;
831830 createAllocFields (rewriter, loc, resType, dynSizes,
832831 enableBufferInitialization, fields, sizeHint);
832+
833833 // Replace operation with resulting memrefs.
834834 rewriter.replaceOp (op, genTuple (rewriter, loc, resType, fields));
835835 return success ();
@@ -1496,7 +1496,6 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
14961496 SmallVector<Value> fields;
14971497 createAllocFields (rewriter, loc, dstTp, dynSizes, /* enableInit=*/ false ,
14981498 fields, nse);
1499- MutSparseTensorDescriptor desc (dstTp, fields);
15001499
15011500 // Now construct the dim2lvl and lvl2dim buffers.
15021501 Value dim2lvlBuffer;
@@ -1505,6 +1504,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
15051504 dim2lvlBuffer, lvl2dimBuffer);
15061505
15071506 // Read the COO tensor data.
1507+ MutSparseTensorDescriptor desc (dstTp, fields);
15081508 Value xs = desc.getAOSMemRef ();
15091509 Value ys = desc.getValMemRef ();
15101510 const Type boolTp = rewriter.getIntegerType (1 );
0 commit comments