[mlir][sparse] implement sparse_tensor.reorder_coo #68916

Merged · 8 commits · Oct 12, 2023
Changes from 3 commits
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -151,6 +151,8 @@ enum class Action : uint32_t {
kToCOO = 5,
kToIterator = 6,
kPack = 7,
// Sort an unordered COO in place.
kSortCOOInPlace = 8,
};

/// This enum defines all the sparse representations supportable by
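
The new action value is what the compiler passes to the runtime to request an in-place sort of an unordered COO tensor. Below is a minimal sketch of how a runtime entry point might dispatch on it; the Storage type and applyAction function are hypothetical stand-ins, and only the Action::kSortCOOInPlace value itself comes from this patch.

#include <cstdint>
#include <stdexcept>

// Mirrors the Action values shown above (other members omitted).
enum class Action : uint32_t { kPack = 7, kSortCOOInPlace = 8 };

// Hypothetical stand-in for SparseTensorStorage (the real class lives in
// mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h).
struct Storage {
  void sortInPlace() { /* sort coordinates and values in place */ }
};

// Sketch of a runtime dispatch: kSortCOOInPlace sorts the existing storage
// and returns it, rather than allocating a new tensor.
Storage *applyAction(Action action, Storage *tensor) {
  switch (action) {
  case Action::kSortCOOInPlace:
    tensor->sortInPlace();
    return tensor;
  default:
    throw std::runtime_error("unsupported Action");
  }
}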
@@ -200,10 +200,6 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
// Whether the convert can be done in a single step (either a sort or a
// foreach), or whether it requires a tmp buffer (sort, then foreach).
bool directConvertable();

// Whether the convert is actually a sort coo
// TODO: The method will be removed when sort_coo operation is introduced.
bool isSortCOOConvert();
}];

let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
75 changes: 75 additions & 0 deletions mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -349,6 +349,8 @@ class SparseTensorStorage final : public SparseTensorStorageBase {

~SparseTensorStorage() final = default;

void sortInPlace();

/// Partially specialize these getter methods based on template types.
void getPositions(std::vector<P> **out, uint64_t lvl) final {
assert(out && "Received nullptr for out parameter");
@@ -374,6 +376,24 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
/// Partially specialize lexicographical insertions based on template types.
void lexInsert(const uint64_t *lvlCoords, V val) final {
assert(lvlCoords && "Received nullptr for level-coordinates");
// TODO: get rid of this! canonicalize all-dense "sparse" array into dense
// tensors.
bool allDense = true;
for (DimLevelType lt : getLvlTypes()) {
if (!isDenseDLT(lt)) {
allDense = false;
break;
}
}
if (allDense) {
uint64_t lvlRank = getLvlRank();
uint64_t valIdx = 0;
// Linearize the address
for (size_t lvl = 0; lvl < lvlRank; lvl++)
valIdx = valIdx * getLvlSize(lvl) + lvlCoords[lvl];
values[valIdx] = val;
return;
}
// First, wrap up pending insertion path.
uint64_t diffLvl = 0;
uint64_t full = 0;
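
For reference, the all-dense fast path above is plain row-major linearization: each iteration folds one level's coordinate into the running index. A standalone sketch of the same arithmetic, with made-up sizes and coordinates:

#include <cassert>
#include <cstdint>
#include <vector>

// Row-major linearization, the same loop used by the all-dense branch of
// lexInsert: for three levels it computes ((c0 * s1) + c1) * s2 + c2.
uint64_t linearize(const std::vector<uint64_t> &lvlSizes,
                   const uint64_t *lvlCoords) {
  uint64_t valIdx = 0;
  for (size_t lvl = 0; lvl < lvlSizes.size(); lvl++)
    valIdx = valIdx * lvlSizes[lvl] + lvlCoords[lvl];
  return valIdx;
}

int main() {
  // For a 2x3x4 tensor, coordinate (1, 2, 3) maps to ((1*3)+2)*4+3 == 23.
  std::vector<uint64_t> sizes = {2, 3, 4};
  uint64_t coords[] = {1, 2, 3};
  assert(linearize(sizes, coords) == 23);
}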
@@ -956,6 +976,61 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::packFromLvlBuffers(
return tensor;
}

template <typename P, typename C, typename V>
void SparseTensorStorage<P, C, V>::sortInPlace() {
uint64_t nnz = values.size();
#ifndef NDEBUG
for (uint64_t l = 0; l < getLvlRank(); l++)
assert(nnz == coordinates[l].size());
#endif

// In-place permutation.
auto applyPerm = [this](std::vector<uint64_t> &perm) {
size_t length = perm.size();
size_t lvlRank = getLvlRank();
// Cache for the current level coordinates.
std::vector<C> lvlCrds(lvlRank);
for (size_t i = 0; i < length; i++) {
size_t current = i;
if (i != perm[current]) {
for (size_t l = 0; l < lvlRank; l++)
lvlCrds[l] = coordinates[l][i];
V val = values[i];
// Deals with a permutation cycle.
while (i != perm[current]) {
size_t next = perm[current];
// Swaps the level coordinates and value.
for (size_t l = 0; l < lvlRank; l++)
coordinates[l][current] = coordinates[l][next];
values[current] = values[next];
perm[current] = current;
current = next;
}
for (size_t l = 0; l < lvlRank; l++)
coordinates[l][current] = lvlCrds[l];
values[current] = val;
perm[current] = current;
}
}
};

std::vector<uint64_t> sortedIdx(nnz, 0);
for (uint64_t i = 0; i < nnz; i++)
sortedIdx[i] = i;

std::sort(sortedIdx.begin(), sortedIdx.end(),
[this](uint64_t lhs, uint64_t rhs) {
for (uint64_t l = 0; l < getLvlRank(); l++) {
if (coordinates[l][lhs] == coordinates[l][rhs])
continue;
return coordinates[l][lhs] < coordinates[l][rhs];
}
assert(false && "duplicate coordinates");
return false;
});

applyPerm(sortedIdx);
}
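
The applyPerm lambda above sorts without allocating reordered copies of the coordinate and value buffers: it walks each permutation cycle once, caching the element displaced at the cycle's start. A self-contained sketch of the same cycle-walking scheme on a single array (the version above performs the identical walk over lvlRank coordinate arrays plus the values array):

#include <cassert>
#include <cstddef>
#include <vector>

// Applies perm in place, where perm[k] holds the index of the element that
// belongs at position k. Each cycle is walked once; perm[k] == k means "done".
void applyPerm(std::vector<size_t> &perm, std::vector<int> &data) {
  for (size_t i = 0; i < perm.size(); i++) {
    size_t current = i;
    if (i != perm[current]) {
      int saved = data[i]; // Cache the element this cycle displaces.
      while (i != perm[current]) {
        size_t next = perm[current];
        data[current] = data[next];
        perm[current] = current; // Mark the slot as settled.
        current = next;
      }
      data[current] = saved;
      perm[current] = current;
    }
  }
}

int main() {
  std::vector<int> data = {30, 10, 20};
  std::vector<size_t> perm = {1, 2, 0}; // argsort result: 10, 20, 30
  applyPerm(perm, data);
  assert(data == (std::vector<int>{10, 20, 30}));
}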

template <typename P, typename C, typename V>
SparseTensorStorage<P, C, V>::SparseTensorStorage(
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
19 changes: 1 addition & 18 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1060,20 +1060,12 @@ LogicalResult ConvertOp::verify() {
}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
Type dstType = getType();
// Fold trivial dense-to-dense convert and leave trivial sparse-to-sparse
// convert for codegen to remove. This is because we use trivial
// sparse-to-sparse convert to tell bufferization that the sparse codegen
// will expand the tensor buffer into sparse tensor storage.
if (!getSparseTensorEncoding(dstType) && dstType == getSource().getType())
if (getType() == getSource().getType())
return getSource();
return {};
}

bool ConvertOp::directConvertable() {
if (isSortCOOConvert())
return false;

SparseTensorType srcStt = getSparseTensorType(getSource());
SparseTensorType dstStt = getSparseTensorType(getDest());

@@ -1099,15 +1091,6 @@ bool ConvertOp::directConvertable() {
return false;
}

bool ConvertOp::isSortCOOConvert() {
// TODO: we should instead use a different sort_coo operation to handle
// the conversion between COOs (but with different ordering).
return isUniqueCOOType(getSource().getType()) &&
isUniqueCOOType(getDest().getType()) &&
!getSparseTensorType(getSource()).isAllOrdered() &&
getSparseTensorType(getDest()).isAllOrdered();
}

LogicalResult ToPositionsOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
29 changes: 10 additions & 19 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -680,31 +680,26 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
};

// TODO: use a new SortCOO operation here instead of reusing convert op.
struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ConvertOp op, ConvertOpAdaptor adaptor,
matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Direct conversion should have already been lowered.
if (!op.isSortCOOConvert())
return failure();

Location loc = op.getLoc();
MLIRContext *ctx = op.getContext();

SparseTensorType srcStt = getSparseTensorType(op.getSource());
SparseTensorType dstStt = getSparseTensorType(op.getDest());
SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());

// TODO: This should be verification rules for sort_coo operation.
// Should have been verified.
assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
isUniqueCOOType(srcStt.getRankedTensorType()) &&
isUniqueCOOType(dstStt.getRankedTensorType()));

assert(dstStt.hasSameDimToLvl(srcStt));

// We don't need a mutable descriptor here as we perform sorting in-place.
auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getSource());
auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
auto crd = desc.getAOSMemRef();
auto val = desc.getValMemRef();

@@ -715,12 +710,11 @@ struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);

rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
rewriter.getIndexAttr(0),
SparseTensorSortKind::HybridQuickSort);
rewriter.getIndexAttr(0), op.getAlgorithm());

// Since we do in-place sorting, the destination tensor will have the same
// set of memrefs as the source tensor.
rewriter.replaceOp(op, adaptor.getSource());
rewriter.replaceOp(op, adaptor.getInputCoo());
return success();
}
};
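
The pattern above reuses the input tensor's buffers and emits a SortOp over the AoS (array-of-structs) coordinate memref, in which the rank coordinates of nonzero i sit contiguously at offset i * rank. A rough C++ analogue of the sort the generated code performs, using hypothetical std::vector buffers (the real lowering operates on memrefs in generated IR):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Lexicographically sorts an AoS coordinate buffer, moving values in tandem.
// crd stores rank coordinates per nonzero, back to back.
void sortCoo(uint64_t rank, std::vector<uint64_t> &crd,
             std::vector<double> &val) {
  uint64_t nnz = val.size();
  std::vector<uint64_t> idx(nnz);
  std::iota(idx.begin(), idx.end(), 0);
  std::sort(idx.begin(), idx.end(), [&](uint64_t a, uint64_t b) {
    return std::lexicographical_compare(
        crd.begin() + a * rank, crd.begin() + (a + 1) * rank,
        crd.begin() + b * rank, crd.begin() + (b + 1) * rank);
  });
  // Gather into fresh buffers for clarity; the in-place cycle walk in
  // sortInPlace avoids this extra allocation.
  std::vector<uint64_t> newCrd(crd.size());
  std::vector<double> newVal(nnz);
  for (uint64_t k = 0; k < nnz; k++) {
    std::copy(crd.begin() + idx[k] * rank,
              crd.begin() + (idx[k] + 1) * rank, newCrd.begin() + k * rank);
    newVal[k] = val[idx[k]];
  }
  crd = std::move(newCrd);
  val = std::move(newVal);
}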
@@ -1147,9 +1141,6 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
LogicalResult
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op.isSortCOOConvert())
return failure();

SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
SparseTensorEncodingAttr encSrc =
getSparseTensorEncoding(op.getSource().getType());
@@ -1603,7 +1594,7 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseCastConverter, SparseExtractSliceConverter,
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,
SparseSortCOOConverter,
SparseReorderCOOConverter,
SparseSliceGetterOpConverter<ToSliceOffsetOp,
StorageSpecifierKind::DimOffset>,
SparseSliceGetterOpConverter<ToSliceStrideOp,