Skip to content

Commit d3af653

Browse files
authored
[mlir][sparse] introduce MapRef, unify conversion/codegen for reader (#68360)
This revision introduces a MapRef, which will support a future generalization beyond permutations (e.g. block sparsity). This revision also unifies the conversion/codegen paths for the sparse_tensor.new operation from file (eg. the readers). Note that more unification is planned as well as general affine dim2lvl and lvl2dim (all marked with TODOs).
1 parent f045f2c commit d3af653

File tree

14 files changed

+437
-483
lines changed

14 files changed

+437
-483
lines changed

mlir/include/mlir/ExecutionEngine/SparseTensor/File.h

+58-98
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_FILE_H
2121
#define MLIR_EXECUTIONENGINE_SPARSETENSOR_FILE_H
2222

23+
#include "mlir/ExecutionEngine/SparseTensor/MapRef.h"
2324
#include "mlir/ExecutionEngine/SparseTensor/Storage.h"
2425

2526
#include <fstream>
@@ -75,6 +76,10 @@ inline V readValue(char **linePtr, bool isPattern) {
7576

7677
} // namespace detail
7778

79+
//===----------------------------------------------------------------------===//
80+
//
81+
// Reader class.
82+
//
7883
//===----------------------------------------------------------------------===//
7984

8085
/// This class abstracts over the information stored in file headers,
@@ -132,6 +137,7 @@ class SparseTensorReader final {
132137
/// Reads and parses the file's header.
133138
void readHeader();
134139

140+
/// Returns the stored value kind.
135141
ValueKind getValueKind() const { return valueKind_; }
136142

137143
/// Checks if a header has been successfully read.
@@ -185,58 +191,37 @@ class SparseTensorReader final {
185191
/// valid after parsing the header.
186192
void assertMatchesShape(uint64_t rank, const uint64_t *shape) const;
187193

188-
/// Reads a sparse tensor element from the next line in the input file and
189-
/// returns the value of the element. Stores the coordinates of the element
190-
/// to the `dimCoords` array.
191-
template <typename V>
192-
V readElement(uint64_t dimRank, uint64_t *dimCoords) {
193-
assert(dimRank == getRank() && "rank mismatch");
194-
char *linePtr = readCoords(dimCoords);
195-
return detail::readValue<V>(&linePtr, isPattern());
196-
}
197-
198-
/// Allocates a new COO object for `lvlSizes`, initializes it by reading
199-
/// all the elements from the file and applying `dim2lvl` to their
200-
/// dim-coordinates, and then closes the file. Templated on V only.
201-
template <typename V>
202-
SparseTensorCOO<V> *readCOO(uint64_t lvlRank, const uint64_t *lvlSizes,
203-
const uint64_t *dim2lvl);
204-
205194
/// Allocates a new sparse-tensor storage object with the given encoding,
206195
/// initializes it by reading all the elements from the file, and then
207196
/// closes the file. Templated on P, I, and V.
208197
template <typename P, typename I, typename V>
209198
SparseTensorStorage<P, I, V> *
210199
readSparseTensor(uint64_t lvlRank, const uint64_t *lvlSizes,
211-
const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
212-
const uint64_t *dim2lvl) {
213-
auto *lvlCOO = readCOO<V>(lvlRank, lvlSizes, dim2lvl);
200+
const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
201+
const uint64_t *lvl2dim) {
202+
const uint64_t dimRank = getRank();
203+
MapRef map(dimRank, lvlRank, dim2lvl, lvl2dim);
204+
auto *coo = readCOO<V>(map, lvlSizes);
214205
auto *tensor = SparseTensorStorage<P, I, V>::newFromCOO(
215-
getRank(), getDimSizes(), lvlRank, lvlTypes, lvl2dim, *lvlCOO);
216-
delete lvlCOO;
206+
dimRank, getDimSizes(), lvlRank, lvlTypes, lvl2dim, *coo);
207+
delete coo;
217208
return tensor;
218209
}
219210

220211
/// Reads the COO tensor from the file, stores the coordinates and values to
221212
/// the given buffers, returns a boolean value to indicate whether the COO
222213
/// elements are sorted.
223-
/// Precondition: the buffers should have enough space to hold the elements.
224214
template <typename C, typename V>
225215
bool readToBuffers(uint64_t lvlRank, const uint64_t *dim2lvl,
226-
C *lvlCoordinates, V *values);
216+
const uint64_t *lvl2dim, C *lvlCoordinates, V *values);
227217

228218
private:
229-
/// Attempts to read a line from the file. Is private because there's
230-
/// no reason for client code to call it.
219+
/// Attempts to read a line from the file.
231220
void readLine();
232221

233222
/// Reads the next line of the input file and parses the coordinates
234223
/// into the `dimCoords` argument. Returns the position in the `line`
235-
/// buffer where the element's value should be parsed from. This method
236-
/// has been factored out from `readElement` to minimize code bloat
237-
/// for the generated library.
238-
///
239-
/// Precondition: `dimCoords` is valid for `getRank()`.
224+
/// buffer where the element's value should be parsed from.
240225
template <typename C>
241226
char *readCoords(C *dimCoords) {
242227
readLine();
@@ -251,24 +236,20 @@ class SparseTensorReader final {
251236
return linePtr;
252237
}
253238

254-
/// The internal implementation of `readCOO`. We template over
255-
/// `IsPattern` in order to perform LICM without needing to duplicate the
256-
/// source code.
257-
//
258-
// TODO: We currently take the `dim2lvl` argument as a `PermutationRef`
259-
// since that's what `readCOO` creates. Once we update `readCOO` to
260-
// functionalize the mapping, then this helper will just take that
261-
// same function.
239+
/// Reads all the elements from the file while applying the given map.
240+
template <typename V>
241+
SparseTensorCOO<V> *readCOO(const MapRef &map, const uint64_t *lvlSizes);
242+
243+
/// The implementation of `readCOO` that is templated `IsPattern` in order
244+
/// to perform LICM without needing to duplicate the source code.
262245
template <typename V, bool IsPattern>
263-
void readCOOLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl,
264-
SparseTensorCOO<V> *lvlCOO);
246+
void readCOOLoop(const MapRef &map, SparseTensorCOO<V> *coo);
265247

266-
/// The internal implementation of `readToBuffers`. We template over
267-
/// `IsPattern` in order to perform LICM without needing to duplicate the
268-
/// source code.
248+
/// The internal implementation of `readToBuffers`. We template over
249+
/// `IsPattern` in order to perform LICM without needing to duplicate
250+
/// the source code.
269251
template <typename C, typename V, bool IsPattern>
270-
bool readToBuffersLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl,
271-
C *lvlCoordinates, V *values);
252+
bool readToBuffersLoop(const MapRef &map, C *lvlCoordinates, V *values);
272253

273254
/// Reads the MME header of a general sparse matrix of type real.
274255
void readMMEHeader();
@@ -288,96 +269,76 @@ class SparseTensorReader final {
288269
char line[kColWidth];
289270
};
290271

272+
//===----------------------------------------------------------------------===//
273+
//
274+
// Reader class methods.
275+
//
291276
//===----------------------------------------------------------------------===//
292277

293278
template <typename V>
294-
SparseTensorCOO<V> *SparseTensorReader::readCOO(uint64_t lvlRank,
295-
const uint64_t *lvlSizes,
296-
const uint64_t *dim2lvl) {
279+
SparseTensorCOO<V> *SparseTensorReader::readCOO(const MapRef &map,
280+
const uint64_t *lvlSizes) {
297281
assert(isValid() && "Attempt to readCOO() before readHeader()");
298-
const uint64_t dimRank = getRank();
299-
assert(lvlRank == dimRank && "Rank mismatch");
300-
detail::PermutationRef d2l(dimRank, dim2lvl);
301282
// Prepare a COO object with the number of stored elems as initial capacity.
302-
auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, getNSE());
303-
// Do some manual LICM, to avoid assertions in the for-loop.
304-
const bool IsPattern = isPattern();
305-
if (IsPattern)
306-
readCOOLoop<V, true>(lvlRank, d2l, lvlCOO);
283+
auto *coo = new SparseTensorCOO<V>(map.getLvlRank(), lvlSizes, getNSE());
284+
// Enter the reading loop.
285+
if (isPattern())
286+
readCOOLoop<V, true>(map, coo);
307287
else
308-
readCOOLoop<V, false>(lvlRank, d2l, lvlCOO);
288+
readCOOLoop<V, false>(map, coo);
309289
// Close the file and return the COO.
310290
closeFile();
311-
return lvlCOO;
291+
return coo;
312292
}
313293

314294
template <typename V, bool IsPattern>
315-
void SparseTensorReader::readCOOLoop(uint64_t lvlRank,
316-
detail::PermutationRef dim2lvl,
317-
SparseTensorCOO<V> *lvlCOO) {
318-
const uint64_t dimRank = getRank();
295+
void SparseTensorReader::readCOOLoop(const MapRef &map,
296+
SparseTensorCOO<V> *coo) {
297+
const uint64_t dimRank = map.getDimRank();
298+
const uint64_t lvlRank = map.getLvlRank();
299+
assert(dimRank == getRank());
319300
std::vector<uint64_t> dimCoords(dimRank);
320301
std::vector<uint64_t> lvlCoords(lvlRank);
321-
for (uint64_t nse = getNSE(), k = 0; k < nse; ++k) {
322-
// We inline `readElement` here in order to avoid redundant
323-
// assertions, since they're guaranteed by the call to `isValid()`
324-
// and the construction of `dimCoords` above.
302+
for (uint64_t k = 0, nse = getNSE(); k < nse; k++) {
325303
char *linePtr = readCoords(dimCoords.data());
326304
const V value = detail::readValue<V, IsPattern>(&linePtr);
327-
dim2lvl.pushforward(dimRank, dimCoords.data(), lvlCoords.data());
328-
// TODO: <https://github.com/llvm/llvm-project/issues/54179>
329-
lvlCOO->add(lvlCoords, value);
305+
map.pushforward(dimCoords.data(), lvlCoords.data());
306+
coo->add(lvlCoords, value);
330307
}
331308
}
332309

333310
template <typename C, typename V>
334311
bool SparseTensorReader::readToBuffers(uint64_t lvlRank,
335312
const uint64_t *dim2lvl,
313+
const uint64_t *lvl2dim,
336314
C *lvlCoordinates, V *values) {
337315
assert(isValid() && "Attempt to readCOO() before readHeader()");
338-
// Construct a `PermutationRef` for the `pushforward` below.
339-
// TODO: This specific implementation does not generalize to arbitrary
340-
// mappings, but once we functionalize the `dim2lvl` argument we can
341-
// simply use that function instead.
342-
const uint64_t dimRank = getRank();
343-
assert(lvlRank == dimRank && "Rank mismatch");
344-
detail::PermutationRef d2l(dimRank, dim2lvl);
345-
// Do some manual LICM, to avoid assertions in the for-loop.
316+
MapRef map(getRank(), lvlRank, dim2lvl, lvl2dim);
346317
bool isSorted =
347-
isPattern()
348-
? readToBuffersLoop<C, V, true>(lvlRank, d2l, lvlCoordinates, values)
349-
: readToBuffersLoop<C, V, false>(lvlRank, d2l, lvlCoordinates,
350-
values);
351-
352-
// Close the file and return isSorted.
318+
isPattern() ? readToBuffersLoop<C, V, true>(map, lvlCoordinates, values)
319+
: readToBuffersLoop<C, V, false>(map, lvlCoordinates, values);
353320
closeFile();
354321
return isSorted;
355322
}
356323

357324
template <typename C, typename V, bool IsPattern>
358-
bool SparseTensorReader::readToBuffersLoop(uint64_t lvlRank,
359-
detail::PermutationRef dim2lvl,
360-
C *lvlCoordinates, V *values) {
361-
const uint64_t dimRank = getRank();
325+
bool SparseTensorReader::readToBuffersLoop(const MapRef &map, C *lvlCoordinates,
326+
V *values) {
327+
const uint64_t dimRank = map.getDimRank();
328+
const uint64_t lvlRank = map.getLvlRank();
362329
const uint64_t nse = getNSE();
330+
assert(dimRank == getRank());
363331
std::vector<C> dimCoords(dimRank);
364-
// Read the first element with isSorted=false as a way to avoid accessing its
365-
// previous element.
366332
bool isSorted = false;
367333
char *linePtr;
368-
// We inline `readElement` here in order to avoid redundant assertions,
369-
// since they're guaranteed by the call to `isValid()` and the construction
370-
// of `dimCoords` above.
371334
const auto readNextElement = [&]() {
372335
linePtr = readCoords<C>(dimCoords.data());
373-
dim2lvl.pushforward(dimRank, dimCoords.data(), lvlCoordinates);
336+
map.pushforward(dimCoords.data(), lvlCoordinates);
374337
*values = detail::readValue<V, IsPattern>(&linePtr);
375338
if (isSorted) {
376-
// Note that isSorted was set to false while reading the first element,
339+
// Note that isSorted is set to false when reading the first element,
377340
// to guarantee the safeness of using prevLvlCoords.
378341
C *prevLvlCoords = lvlCoordinates - lvlRank;
379-
// TODO: define a new CoordsLT which is like ElementLT but doesn't have
380-
// the V parameter, and use it here.
381342
for (uint64_t l = 0; l < lvlRank; ++l) {
382343
if (prevLvlCoords[l] != lvlCoordinates[l]) {
383344
if (prevLvlCoords[l] > lvlCoordinates[l])
@@ -393,7 +354,6 @@ bool SparseTensorReader::readToBuffersLoop(uint64_t lvlRank,
393354
isSorted = true;
394355
for (uint64_t n = 1; n < nse; ++n)
395356
readNextElement();
396-
397357
return isSorted;
398358
}
399359

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//===- MapRef.h - A dim2lvl/lvl2dim map encoding ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// A dim2lvl/lvl2dim map encoding class, with utility methods.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_MAPREF_H
14+
#define MLIR_EXECUTIONENGINE_SPARSETENSOR_MAPREF_H
15+
16+
#include <cinttypes>
17+
18+
#include <cassert>
19+
#include <vector>
20+
21+
namespace mlir {
22+
namespace sparse_tensor {
23+
24+
/// A class for capturing the sparse tensor type map with a compact encoding.
25+
///
26+
/// Currently, the following situations are supported:
27+
/// (1) map is a permutation
28+
/// (2) map has affine ops (restricted set)
29+
///
30+
/// The pushforward/backward operations are fast for (1) but incur some obvious
31+
/// overhead for situation (2).
32+
///
33+
class MapRef final {
34+
public:
35+
MapRef(uint64_t d, uint64_t l, const uint64_t *d2l, const uint64_t *l2d);
36+
37+
//
38+
// Push forward maps from dimensions to levels.
39+
//
40+
41+
template <typename T>
42+
inline void pushforward(const T *in, T *out) const {
43+
if (isPermutation) {
44+
for (uint64_t i = 0; i < lvlRank; ++i)
45+
out[i] = in[lvl2dim[i]];
46+
} else {
47+
assert(0 && "coming soon");
48+
}
49+
}
50+
51+
//
52+
// Push backward maps from levels to dimensions.
53+
//
54+
55+
template <typename T>
56+
inline void pushbackward(const T *in, T *out) const {
57+
if (isPermutation) {
58+
for (uint64_t i = 0; i < dimRank; ++i)
59+
out[i] = in[dim2lvl[i]];
60+
} else {
61+
assert(0 && "coming soon");
62+
}
63+
}
64+
65+
uint64_t getDimRank() const { return dimRank; }
66+
uint64_t getLvlRank() const { return lvlRank; }
67+
68+
private:
69+
bool isPermutationMap() const;
70+
71+
const uint64_t dimRank;
72+
const uint64_t lvlRank;
73+
const uint64_t *const dim2lvl; // non-owning pointer
74+
const uint64_t *const lvl2dim; // non-owning pointer
75+
const bool isPermutation;
76+
};
77+
78+
} // namespace sparse_tensor
79+
} // namespace mlir
80+
81+
#endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_MAPREF_H

0 commit comments

Comments
 (0)