Skip to content

Commit fede6b0

Browse files
Merge pull request tensor-compiler#391 from tensor-compiler/assembly-v2
Partial support for parallel assembly with compressed formats
2 parents f111251 + 62498ea commit fede6b0

39 files changed

+2196
-312
lines changed

apps/tensor_times_vector/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
cmake_minimum_required(VERSION 2.8)
1+
cmake_minimum_required(VERSION 2.8.12)
22
project(tensor_times_vector)
33

44
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")

include/taco/codegen/module.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <vector>
66
#include <string>
77
#include <utility>
8+
#include <random>
89

910
#include "taco/target.h"
1011
#include "taco/ir/ir.h"
@@ -80,6 +81,10 @@ class Module {
8081

8182
void setJITLibname();
8283
void setJITTmpdir();
84+
85+
static std::string chars;
86+
static std::default_random_engine gen;
87+
static std::uniform_int_distribution<int> randint;
8388
};
8489

8590
} // namespace ir

include/taco/format.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ namespace taco {
1212
class ModeFormat;
1313
class ModeFormatPack;
1414
class ModeFormatImpl;
15+
class AttrQuery;
16+
class IndexVar;
1517

1618

1719
/// A Format describes the data layout of a tensor, and the sparse index data
@@ -135,6 +137,16 @@ class ModeFormat {
135137
bool hasInsert() const;
136138
bool hasAppend() const;
137139

140+
/// Returns true if a mode format has ungrouped insertion functions with
141+
/// specific attributes, false otherwise
142+
bool hasSeqInsertEdge() const;
143+
bool hasInsertCoord() const;
144+
bool isYieldPosPure() const;
145+
146+
std::vector<AttrQuery> getAttrQueries(
147+
std::vector<IndexVar> parentCoords,
148+
std::vector<IndexVar> childCoords) const;
149+
138150
/// Returns true if mode format is defined, false otherwise. An undefined mode
139151
/// type can be used to indicate a mode whose format is not (yet) known.
140152
bool defined() const;

include/taco/index_notation/index_notation.h

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
#include "taco/index_notation/intrinsic.h"
2121
#include "taco/index_notation/index_notation_nodes_abstract.h"
2222
#include "taco/ir_tags.h"
23-
#include "taco/lower/iterator.h"
2423
#include "taco/index_notation/provenance_graph.h"
2524

2625
namespace taco {
@@ -57,6 +56,7 @@ struct YieldNode;
5756
struct ForallNode;
5857
struct WhereNode;
5958
struct SequenceNode;
59+
struct AssembleNode;
6060
struct MultiNode;
6161
struct SuchThatNode;
6262

@@ -224,16 +224,22 @@ class Access : public IndexExpr {
224224
Access() = default;
225225
Access(const Access&) = default;
226226
Access(const AccessNode*);
227-
Access(const TensorVar& tensorVar,
228-
const std::vector<IndexVar>& indices={},
229-
const std::map<int, std::shared_ptr<IndexVarIterationModifier>>& modifiers={});
227+
Access(const TensorVar& tensorVar, const std::vector<IndexVar>& indices={},
228+
const std::map<int, std::shared_ptr<IndexVarIterationModifier>>& modifiers={},
229+
bool isAccessingStructure=false);
230230

231231
/// Return the Access expression's TensorVar.
232232
const TensorVar &getTensorVar() const;
233233

234234
/// Returns the index variables used to index into the Access's TensorVar.
235235
const std::vector<IndexVar>& getIndexVars() const;
236236

237+
/// Returns whether access expression returns sparsity pattern of tensor.
238+
/// If true, the access expression returns 1 for every physically stored
239+
/// component. If false, the access expression returns the value that is
240+
/// stored for each corresponding component.
241+
bool isAccessingStructure() const;
242+
237243
/// hasWindowedModes returns true if any accessed modes are windowed.
238244
bool hasWindowedModes() const;
239245

@@ -675,6 +681,8 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
675681
/// integer number of iterations
676682
/// Preconditions: unrollFactor is a positive nonzero integer
677683
IndexStmt unroll(IndexVar i, size_t unrollFactor) const;
684+
685+
IndexStmt assemble(TensorVar result, AssembleStrategy strategy) const;
678686
};
679687

680688
/// Check if two index statements are isomorphic.
@@ -818,6 +826,27 @@ class Sequence : public IndexStmt {
818826
Sequence sequence(IndexStmt definition, IndexStmt mutation);
819827

820828

829+
class Assemble : public IndexStmt {
830+
public:
831+
typedef std::map<TensorVar,std::vector<std::vector<TensorVar>>> AttrQueryResults;
832+
833+
Assemble() = default;
834+
Assemble(const AssembleNode*);
835+
Assemble(IndexStmt queries, IndexStmt compute, AttrQueryResults results);
836+
837+
IndexStmt getQueries() const;
838+
IndexStmt getCompute() const;
839+
840+
const AttrQueryResults& getAttrQueryResults() const;
841+
842+
typedef AssembleNode Node;
843+
};
844+
845+
/// Create an assemble index statement.
846+
Assemble assemble(IndexStmt queries, IndexStmt compute,
847+
Assemble::AttrQueryResults results);
848+
849+
821850
/// A multi statement has two statements that are executed separately, and let
822851
/// us compute more than one tensor in a concrete index notation statement.
823852
class Multi : public IndexStmt {
@@ -1095,10 +1124,18 @@ std::vector<TensorVar> getArguments(IndexStmt stmt);
10951124
/// Returns the temporaries in the index statement, in the order they appear.
10961125
std::vector<TensorVar> getTemporaries(IndexStmt stmt);
10971126

1127+
/// Returns the attribute query results in the index statement, in the order
1128+
/// they appear.
1129+
std::vector<TensorVar> getAttrQueryResults(IndexStmt stmt);
1130+
10981131
// [Olivia]
10991132
/// Returns the temporaries in the index statement, in the order they appear.
11001133
std::map<Forall, Where> getTemporaryLocations(IndexStmt stmt);
11011134

1135+
/// Returns the results in the index statement that should be assembled by
1136+
/// ungrouped insertion.
1137+
std::vector<TensorVar> getAssembledByUngroupedInsertion(IndexStmt stmt);
1138+
11021139
/// Returns the tensors in the index statement.
11031140
std::vector<TensorVar> getTensorVars(IndexStmt stmt);
11041141

@@ -1123,6 +1160,10 @@ std::vector<ir::Expr> createVars(const std::vector<TensorVar>& tensorVars,
11231160
std::map<TensorVar, ir::Expr>* vars,
11241161
bool isParameter=false);
11251162

1163+
/// Convert index notation tensor variables in the index statement to IR
1164+
/// pointer variables.
1165+
std::map<TensorVar,ir::Expr> createIRTensorVars(IndexStmt stmt);
1166+
11261167

11271168
/// Simplify an index expression by setting the zeroed Access expressions to
11281169
/// zero and then propagating and removing zeroes.

include/taco/index_notation/index_notation_nodes.h

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,12 @@ struct IndexSet : IndexVarIterationModifier {
7171
};
7272

7373
struct AccessNode : public IndexExprNode {
74-
AccessNode(TensorVar tensorVar,
75-
const std::vector<IndexVar> &indices,
76-
const std::map<int, std::shared_ptr<IndexVarIterationModifier>> &modifiers = {})
77-
: IndexExprNode(tensorVar.getType().getDataType()), tensorVar(tensorVar), indexVars(indices) {
74+
AccessNode(TensorVar tensorVar, const std::vector<IndexVar>& indices,
75+
const std::map<int, std::shared_ptr<IndexVarIterationModifier>> &modifiers,
76+
bool isAccessingStructure)
77+
: IndexExprNode(isAccessingStructure ? Bool : tensorVar.getType().getDataType()),
78+
tensorVar(tensorVar), indexVars(indices),
79+
isAccessingStructure(isAccessingStructure) {
7880
// Unpack the input modifiers into the appropriate maps for each mode.
7981
for (auto &it : modifiers) {
8082
IndexVarIterationModifier::match(it.second, [&](std::shared_ptr<AccessWindow> w) {
@@ -108,11 +110,14 @@ struct AccessNode : public IndexExprNode {
108110
std::vector<IndexVar> indexVars;
109111
std::map<int, AccessWindow> windowedModes;
110112
std::map<int, IndexSet> indexSetModes;
113+
bool isAccessingStructure;
111114

112115
protected:
113116
/// Initialize an AccessNode with just a TensorVar. If this constructor is used,
114117
/// then indexVars must be set afterwards.
115-
explicit AccessNode(TensorVar tensorVar) : IndexExprNode(tensorVar.getType().getDataType()), tensorVar(tensorVar) {}
118+
explicit AccessNode(TensorVar tensorVar) :
119+
IndexExprNode(tensorVar.getType().getDataType()),
120+
tensorVar(tensorVar), isAccessingStructure(false) {}
116121
};
117122

118123
struct LiteralNode : public IndexExprNode {
@@ -360,6 +365,20 @@ struct SequenceNode : public IndexStmtNode {
360365
IndexStmt mutation;
361366
};
362367

368+
struct AssembleNode : public IndexStmtNode {
369+
AssembleNode(IndexStmt queries, IndexStmt compute,
370+
Assemble::AttrQueryResults results)
371+
: queries(queries), compute(compute), results(results) {}
372+
373+
void accept(IndexStmtVisitorStrict* v) const {
374+
v->visit(this);
375+
}
376+
377+
IndexStmt queries;
378+
IndexStmt compute;
379+
Assemble::AttrQueryResults results;
380+
};
381+
363382

364383
/// Returns true if expression e is of type E.
365384
template <typename E>

include/taco/index_notation/index_notation_printer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class IndexNotationPrinter : public IndexNotationVisitorStrict {
3535
void visit(const WhereNode*);
3636
void visit(const MultiNode*);
3737
void visit(const SequenceNode*);
38+
void visit(const AssembleNode*);
3839
void visit(const SuchThatNode*);
3940

4041
private:

include/taco/index_notation/index_notation_rewriter.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class IndexStmtRewriterStrict : public IndexStmtVisitorStrict {
5656
virtual void visit(const ForallNode* op) = 0;
5757
virtual void visit(const WhereNode* op) = 0;
5858
virtual void visit(const SequenceNode* op) = 0;
59+
virtual void visit(const AssembleNode* op) = 0;
5960
virtual void visit(const MultiNode* op) = 0;
6061
virtual void visit(const SuchThatNode* op) = 0;
6162
};
@@ -101,6 +102,7 @@ class IndexNotationRewriter : public IndexNotationRewriterStrict {
101102
virtual void visit(const ForallNode* op);
102103
virtual void visit(const WhereNode* op);
103104
virtual void visit(const SequenceNode* op);
105+
virtual void visit(const AssembleNode* op);
104106
virtual void visit(const MultiNode* op);
105107
virtual void visit(const SuchThatNode* op);
106108
};
@@ -126,5 +128,9 @@ IndexStmt replace(IndexStmt stmt,
126128
IndexStmt replace(IndexStmt stmt,
127129
const std::map<TensorVar,TensorVar>& substitutions);
128130

131+
/// Rewrites the statement to replace an index variable with a new variable.
132+
IndexStmt replace(IndexStmt stmt,
133+
const std::map<IndexVar,IndexVar>& substitutions);
134+
129135
}
130136
#endif

include/taco/index_notation/index_notation_visitor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ struct ForallNode;
3131
struct WhereNode;
3232
struct MultiNode;
3333
struct SequenceNode;
34+
struct AssembleNode;
3435
struct SuchThatNode;
3536

3637
/// Visit the nodes in an expression. This visitor provides some type safety
@@ -65,6 +66,7 @@ class IndexStmtVisitorStrict {
6566
virtual void visit(const ForallNode*) = 0;
6667
virtual void visit(const WhereNode*) = 0;
6768
virtual void visit(const SequenceNode*) = 0;
69+
virtual void visit(const AssembleNode*) = 0;
6870
virtual void visit(const MultiNode*) = 0;
6971
virtual void visit(const SuchThatNode*) = 0;
7072
};
@@ -107,6 +109,7 @@ class IndexNotationVisitor : public IndexNotationVisitorStrict {
107109
virtual void visit(const ForallNode* node);
108110
virtual void visit(const WhereNode* node);
109111
virtual void visit(const SequenceNode* node);
112+
virtual void visit(const AssembleNode* node);
110113
virtual void visit(const MultiNode* node);
111114
virtual void visit(const SuchThatNode* node);
112115
};
@@ -176,6 +179,7 @@ class Matcher : public IndexNotationVisitor {
176179
RULE(WhereNode)
177180
RULE(MultiNode)
178181
RULE(SequenceNode)
182+
RULE(AssembleNode)
179183
RULE(SuchThatNode)
180184
};
181185

include/taco/index_notation/provenance_graph.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#ifndef TACO_PROVENANCE_GRAPH_H
22
#define TACO_PROVENANCE_GRAPH_H
33

4+
#include "taco/lower/iterator.h"
5+
46
namespace taco {
57
struct IndexVarRelNode;
68
enum IndexVarRelType {UNDEFINED, SPLIT, DIVIDE, POS, FUSE, BOUND, PRECOMPUTE};

include/taco/index_notation/transformations.h

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class ForAllReplace;
2121
class AddSuchThatPredicates;
2222
class Parallelize;
2323
class TopoReorder;
24+
class SetAssembleStrategy;
2425

2526
/// A transformation is an optimization that transforms a statement in the
2627
/// concrete index notation into a new statement that computes the same result
@@ -34,6 +35,7 @@ class Transformation {
3435
Transformation(Parallelize);
3536
Transformation(TopoReorder);
3637
Transformation(AddSuchThatPredicates);
38+
Transformation(SetAssembleStrategy);
3739

3840
IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;
3941

@@ -108,6 +110,7 @@ class Precompute : public TransformationInterface {
108110
/// Print a precompute command.
109111
std::ostream &operator<<(std::ostream &, const Precompute &);
110112

113+
111114
/// Replaces all occurrences of directly nested forall nodes of pattern with
112115
/// directly nested loops of replacement
113116
class ForAllReplace : public TransformationInterface {
@@ -129,6 +132,10 @@ class ForAllReplace : public TransformationInterface {
129132
std::shared_ptr<Content> content;
130133
};
131134

135+
/// Print a ForAllReplace command.
136+
std::ostream &operator<<(std::ostream &, const ForAllReplace &);
137+
138+
132139
/// Adds a SuchThat node if it does not exist and adds the given IndexVarRels
133140
class AddSuchThatPredicates : public TransformationInterface {
134141
public:
@@ -147,6 +154,9 @@ class AddSuchThatPredicates : public TransformationInterface {
147154
std::shared_ptr<Content> content;
148155
};
149156

157+
std::ostream& operator<<(std::ostream&, const AddSuchThatPredicates&);
158+
159+
150160
/// The parallelize optimization tags a Forall as parallelized
151161
/// after checking for preconditions
152162
class Parallelize : public TransformationInterface {
@@ -169,13 +179,28 @@ class Parallelize : public TransformationInterface {
169179
std::shared_ptr<Content> content;
170180
};
171181

172-
/// Print a ForAllReplace command.
173-
std::ostream &operator<<(std::ostream &, const ForAllReplace &);
174-
175182
/// Print a parallelize command.
176183
std::ostream& operator<<(std::ostream&, const Parallelize&);
177184

178-
std::ostream& operator<<(std::ostream&, const AddSuchThatPredicates&);
185+
186+
class SetAssembleStrategy : public TransformationInterface {
187+
public:
188+
SetAssembleStrategy(TensorVar result, AssembleStrategy strategy);
189+
190+
TensorVar getResult() const;
191+
AssembleStrategy getAssembleStrategy() const;
192+
193+
IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;
194+
195+
void print(std::ostream &os) const;
196+
197+
private:
198+
struct Content;
199+
std::shared_ptr<Content> content;
200+
};
201+
202+
/// Print a SetAssembleStrategy command.
203+
std::ostream &operator<<(std::ostream &, const SetAssembleStrategy&);
179204

180205
// Autoscheduling functions
181206

@@ -207,5 +232,6 @@ IndexStmt scalarPromote(IndexStmt stmt);
207232
* 1. The result is a is scattered into but does not support random insert.
208233
*/
209234
IndexStmt insertTemporaries(IndexStmt stmt);
235+
210236
}
211237
#endif

0 commit comments

Comments
 (0)