20
20
#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_FILE_H
21
21
#define MLIR_EXECUTIONENGINE_SPARSETENSOR_FILE_H
22
22
23
+ #include " mlir/ExecutionEngine/SparseTensor/MapRef.h"
23
24
#include " mlir/ExecutionEngine/SparseTensor/Storage.h"
24
25
25
26
#include < fstream>
@@ -75,6 +76,10 @@ inline V readValue(char **linePtr, bool isPattern) {
75
76
76
77
} // namespace detail
77
78
79
+ // ===----------------------------------------------------------------------===//
80
+ //
81
+ // Reader class.
82
+ //
78
83
// ===----------------------------------------------------------------------===//
79
84
80
85
// / This class abstracts over the information stored in file headers,
@@ -132,6 +137,7 @@ class SparseTensorReader final {
132
137
// / Reads and parses the file's header.
133
138
void readHeader ();
134
139
140
+ // / Returns the stored value kind.
135
141
ValueKind getValueKind () const { return valueKind_; }
136
142
137
143
// / Checks if a header has been successfully read.
@@ -185,58 +191,37 @@ class SparseTensorReader final {
185
191
// / valid after parsing the header.
186
192
void assertMatchesShape (uint64_t rank, const uint64_t *shape) const ;
187
193
188
- // / Reads a sparse tensor element from the next line in the input file and
189
- // / returns the value of the element. Stores the coordinates of the element
190
- // / to the `dimCoords` array.
191
- template <typename V>
192
- V readElement (uint64_t dimRank, uint64_t *dimCoords) {
193
- assert (dimRank == getRank () && " rank mismatch" );
194
- char *linePtr = readCoords (dimCoords);
195
- return detail::readValue<V>(&linePtr, isPattern ());
196
- }
197
-
198
- // / Allocates a new COO object for `lvlSizes`, initializes it by reading
199
- // / all the elements from the file and applying `dim2lvl` to their
200
- // / dim-coordinates, and then closes the file. Templated on V only.
201
- template <typename V>
202
- SparseTensorCOO<V> *readCOO (uint64_t lvlRank, const uint64_t *lvlSizes,
203
- const uint64_t *dim2lvl);
204
-
205
194
// / Allocates a new sparse-tensor storage object with the given encoding,
206
195
// / initializes it by reading all the elements from the file, and then
207
196
// / closes the file. Templated on P, I, and V.
208
197
template <typename P, typename I, typename V>
209
198
SparseTensorStorage<P, I, V> *
210
199
readSparseTensor (uint64_t lvlRank, const uint64_t *lvlSizes,
211
- const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
212
- const uint64_t *dim2lvl) {
213
- auto *lvlCOO = readCOO<V>(lvlRank, lvlSizes, dim2lvl);
200
+ const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
201
+ const uint64_t *lvl2dim) {
202
+ const uint64_t dimRank = getRank ();
203
+ MapRef map (dimRank, lvlRank, dim2lvl, lvl2dim);
204
+ auto *coo = readCOO<V>(map, lvlSizes);
214
205
auto *tensor = SparseTensorStorage<P, I, V>::newFromCOO (
215
- getRank () , getDimSizes (), lvlRank, lvlTypes, lvl2dim, *lvlCOO );
216
- delete lvlCOO ;
206
+ dimRank , getDimSizes (), lvlRank, lvlTypes, lvl2dim, *coo );
207
+ delete coo ;
217
208
return tensor;
218
209
}
219
210
220
211
// / Reads the COO tensor from the file, stores the coordinates and values to
221
212
// / the given buffers, and returns a boolean value to indicate whether the COO
222
213
// / elements are sorted.
223
- // / Precondition: the buffers should have enough space to hold the elements.
224
214
template <typename C, typename V>
225
215
bool readToBuffers (uint64_t lvlRank, const uint64_t *dim2lvl,
226
- C *lvlCoordinates, V *values);
216
+ const uint64_t *lvl2dim, C *lvlCoordinates, V *values);
227
217
228
218
private:
229
- // / Attempts to read a line from the file. Is private because there's
230
- // / no reason for client code to call it.
219
+ // / Attempts to read a line from the file.
231
220
void readLine ();
232
221
233
222
// / Reads the next line of the input file and parses the coordinates
234
223
// / into the `dimCoords` argument. Returns the position in the `line`
235
- // / buffer where the element's value should be parsed from. This method
236
- // / has been factored out from `readElement` to minimize code bloat
237
- // / for the generated library.
238
- // /
239
- // / Precondition: `dimCoords` is valid for `getRank()`.
224
+ // / buffer where the element's value should be parsed from.
240
225
template <typename C>
241
226
char *readCoords (C *dimCoords) {
242
227
readLine ();
@@ -251,24 +236,20 @@ class SparseTensorReader final {
251
236
return linePtr;
252
237
}
253
238
254
- // / The internal implementation of `readCOO`. We template over
255
- // / `IsPattern` in order to perform LICM without needing to duplicate the
256
- // / source code.
257
- //
258
- // TODO: We currently take the `dim2lvl` argument as a `PermutationRef`
259
- // since that's what `readCOO` creates. Once we update `readCOO` to
260
- // functionalize the mapping, then this helper will just take that
261
- // same function.
239
+ // / Reads all the elements from the file while applying the given map.
240
+ template <typename V>
241
+ SparseTensorCOO<V> *readCOO (const MapRef &map, const uint64_t *lvlSizes);
242
+
243
+ // / The implementation of `readCOO` that is templated on `IsPattern` in order
244
+ // / to perform LICM without needing to duplicate the source code.
262
245
template <typename V, bool IsPattern>
263
- void readCOOLoop (uint64_t lvlRank, detail::PermutationRef dim2lvl,
264
- SparseTensorCOO<V> *lvlCOO);
246
+ void readCOOLoop (const MapRef &map, SparseTensorCOO<V> *coo);
265
247
266
- // / The internal implementation of `readToBuffers`. We template over
267
- // / `IsPattern` in order to perform LICM without needing to duplicate the
268
- // / source code.
248
+ // / The internal implementation of `readToBuffers`. We template over
249
+ // / `IsPattern` in order to perform LICM without needing to duplicate
250
+ // / the source code.
269
251
template <typename C, typename V, bool IsPattern>
270
- bool readToBuffersLoop (uint64_t lvlRank, detail::PermutationRef dim2lvl,
271
- C *lvlCoordinates, V *values);
252
+ bool readToBuffersLoop (const MapRef &map, C *lvlCoordinates, V *values);
272
253
273
254
// / Reads the MME header of a general sparse matrix of type real.
274
255
void readMMEHeader ();
@@ -288,96 +269,76 @@ class SparseTensorReader final {
288
269
char line[kColWidth ];
289
270
};
290
271
272
+ // ===----------------------------------------------------------------------===//
273
+ //
274
+ // Reader class methods.
275
+ //
291
276
// ===----------------------------------------------------------------------===//
292
277
293
278
template <typename V>
294
- SparseTensorCOO<V> *SparseTensorReader::readCOO (uint64_t lvlRank,
295
- const uint64_t *lvlSizes,
296
- const uint64_t *dim2lvl) {
279
+ SparseTensorCOO<V> *SparseTensorReader::readCOO (const MapRef &map,
280
+ const uint64_t *lvlSizes) {
297
281
assert (isValid () && " Attempt to readCOO() before readHeader()" );
298
- const uint64_t dimRank = getRank ();
299
- assert (lvlRank == dimRank && " Rank mismatch" );
300
- detail::PermutationRef d2l (dimRank, dim2lvl);
301
282
// Prepare a COO object with the number of stored elems as initial capacity.
302
- auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, getNSE ());
303
- // Do some manual LICM, to avoid assertions in the for-loop.
304
- const bool IsPattern = isPattern ();
305
- if (IsPattern)
306
- readCOOLoop<V, true >(lvlRank, d2l, lvlCOO);
283
+ auto *coo = new SparseTensorCOO<V>(map.getLvlRank (), lvlSizes, getNSE ());
284
+ // Enter the reading loop.
285
+ if (isPattern ())
286
+ readCOOLoop<V, true >(map, coo);
307
287
else
308
- readCOOLoop<V, false >(lvlRank, d2l, lvlCOO );
288
+ readCOOLoop<V, false >(map, coo );
309
289
// Close the file and return the COO.
310
290
closeFile ();
311
- return lvlCOO ;
291
+ return coo ;
312
292
}
313
293
314
294
template <typename V, bool IsPattern>
315
- void SparseTensorReader::readCOOLoop (uint64_t lvlRank,
316
- detail::PermutationRef dim2lvl,
317
- SparseTensorCOO<V> *lvlCOO) {
318
- const uint64_t dimRank = getRank ();
295
+ void SparseTensorReader::readCOOLoop (const MapRef &map,
296
+ SparseTensorCOO<V> *coo) {
297
+ const uint64_t dimRank = map.getDimRank ();
298
+ const uint64_t lvlRank = map.getLvlRank ();
299
+ assert (dimRank == getRank ());
319
300
std::vector<uint64_t > dimCoords (dimRank);
320
301
std::vector<uint64_t > lvlCoords (lvlRank);
321
- for (uint64_t nse = getNSE (), k = 0 ; k < nse; ++k) {
322
- // We inline `readElement` here in order to avoid redundant
323
- // assertions, since they're guaranteed by the call to `isValid()`
324
- // and the construction of `dimCoords` above.
302
+ for (uint64_t k = 0 , nse = getNSE (); k < nse; k++) {
325
303
char *linePtr = readCoords (dimCoords.data ());
326
304
const V value = detail::readValue<V, IsPattern>(&linePtr);
327
- dim2lvl.pushforward (dimRank, dimCoords.data (), lvlCoords.data ());
328
- // TODO: <https://github.com/llvm/llvm-project/issues/54179>
329
- lvlCOO->add (lvlCoords, value);
305
+ map.pushforward (dimCoords.data (), lvlCoords.data ());
306
+ coo->add (lvlCoords, value);
330
307
}
331
308
}
332
309
333
310
template <typename C, typename V>
334
311
bool SparseTensorReader::readToBuffers (uint64_t lvlRank,
335
312
const uint64_t *dim2lvl,
313
+ const uint64_t *lvl2dim,
336
314
C *lvlCoordinates, V *values) {
337
315
assert (isValid () && " Attempt to readCOO() before readHeader()" );
338
- // Construct a `PermutationRef` for the `pushforward` below.
339
- // TODO: This specific implementation does not generalize to arbitrary
340
- // mappings, but once we functionalize the `dim2lvl` argument we can
341
- // simply use that function instead.
342
- const uint64_t dimRank = getRank ();
343
- assert (lvlRank == dimRank && " Rank mismatch" );
344
- detail::PermutationRef d2l (dimRank, dim2lvl);
345
- // Do some manual LICM, to avoid assertions in the for-loop.
316
+ MapRef map (getRank (), lvlRank, dim2lvl, lvl2dim);
346
317
bool isSorted =
347
- isPattern ()
348
- ? readToBuffersLoop<C, V, true >(lvlRank, d2l, lvlCoordinates, values)
349
- : readToBuffersLoop<C, V, false >(lvlRank, d2l, lvlCoordinates,
350
- values);
351
-
352
- // Close the file and return isSorted.
318
+ isPattern () ? readToBuffersLoop<C, V, true >(map, lvlCoordinates, values)
319
+ : readToBuffersLoop<C, V, false >(map, lvlCoordinates, values);
353
320
closeFile ();
354
321
return isSorted;
355
322
}
356
323
357
324
template <typename C, typename V, bool IsPattern>
358
- bool SparseTensorReader::readToBuffersLoop (uint64_t lvlRank ,
359
- detail::PermutationRef dim2lvl,
360
- C *lvlCoordinates, V *values) {
361
- const uint64_t dimRank = getRank ();
325
+ bool SparseTensorReader::readToBuffersLoop (const MapRef &map, C *lvlCoordinates ,
326
+ V *values) {
327
+ const uint64_t dimRank = map. getDimRank ();
328
+ const uint64_t lvlRank = map. getLvlRank ();
362
329
const uint64_t nse = getNSE ();
330
+ assert (dimRank == getRank ());
363
331
std::vector<C> dimCoords (dimRank);
364
- // Read the first element with isSorted=false as a way to avoid accessing its
365
- // previous element.
366
332
bool isSorted = false ;
367
333
char *linePtr;
368
- // We inline `readElement` here in order to avoid redundant assertions,
369
- // since they're guaranteed by the call to `isValid()` and the construction
370
- // of `dimCoords` above.
371
334
const auto readNextElement = [&]() {
372
335
linePtr = readCoords<C>(dimCoords.data ());
373
- dim2lvl .pushforward (dimRank, dimCoords.data (), lvlCoordinates);
336
+ map .pushforward (dimCoords.data (), lvlCoordinates);
374
337
*values = detail::readValue<V, IsPattern>(&linePtr);
375
338
if (isSorted) {
376
- // Note that isSorted was set to false while reading the first element,
339
+ // Note that isSorted is set to false when reading the first element,
377
340
// to guarantee the safety of using prevLvlCoords.
378
341
C *prevLvlCoords = lvlCoordinates - lvlRank;
379
- // TODO: define a new CoordsLT which is like ElementLT but doesn't have
380
- // the V parameter, and use it here.
381
342
for (uint64_t l = 0 ; l < lvlRank; ++l) {
382
343
if (prevLvlCoords[l] != lvlCoordinates[l]) {
383
344
if (prevLvlCoords[l] > lvlCoordinates[l])
@@ -393,7 +354,6 @@ bool SparseTensorReader::readToBuffersLoop(uint64_t lvlRank,
393
354
isSorted = true ;
394
355
for (uint64_t n = 1 ; n < nse; ++n)
395
356
readNextElement ();
396
-
397
357
return isSorted;
398
358
}
399
359
0 commit comments