Skip to content

Commit c004447

Browse files
committed
Enable parallelization of forall statements with results assembled by ungrouped insertion
1 parent 295b1b9 commit c004447

File tree

6 files changed

+103
-24
lines changed

6 files changed

+103
-24
lines changed

include/taco/index_notation/index_notation.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,10 @@ std::vector<ir::Expr> createVars(const std::vector<TensorVar>& tensorVars,
10111011
std::map<TensorVar, ir::Expr>* vars,
10121012
bool isParameter=false);
10131013

1014+
/// Convert index notation tensor variables in the index statement to IR
1015+
/// pointer variables.
1016+
std::map<TensorVar,ir::Expr> createIRTensorVars(IndexStmt stmt);
1017+
10141018

10151019
/// Simplify an index expression by setting the zeroed Access expressions to
10161020
/// zero and then propagating and removing zeroes.

include/taco/index_notation/index_notation_rewriter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,5 +128,9 @@ IndexStmt replace(IndexStmt stmt,
128128
IndexStmt replace(IndexStmt stmt,
129129
const std::map<TensorVar,TensorVar>& substitutions);
130130

131+
/// Rewrites the statement to replace each index variable that is a key in `substitutions` with the index variable it maps to.
132+
IndexStmt replace(IndexStmt stmt,
133+
const std::map<IndexVar,IndexVar>& substitutions);
134+
131135
}
132136
#endif

src/index_notation/index_notation.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2653,6 +2653,25 @@ vector<ir::Expr> createVars(const vector<TensorVar>& tensorVars,
26532653
return irVars;
26542654
}
26552655

2656+
std::map<TensorVar,ir::Expr> createIRTensorVars(IndexStmt stmt)
2657+
{
2658+
std::map<TensorVar,ir::Expr> tensorVars;
2659+
2660+
// Collect the result, argument and temporary tensor variables of the statement
2661+
vector<TensorVar> results = getResults(stmt);
2662+
vector<TensorVar> arguments = getArguments(stmt);
2663+
vector<TensorVar> temporaries = getTemporaries(stmt);
2664+
2665+
// Convert tensor results, arguments and temporaries to IR variables
2666+
map<TensorVar, ir::Expr> resultVars;
2667+
vector<ir::Expr> resultsIR = createVars(results, &resultVars);
2668+
tensorVars.insert(resultVars.begin(), resultVars.end());
2669+
vector<ir::Expr> argumentsIR = createVars(arguments, &tensorVars);
2670+
vector<ir::Expr> temporariesIR = createVars(temporaries, &tensorVars);
2671+
2672+
return tensorVars;
2673+
}
2674+
26562675
struct Zero : public IndexNotationRewriterStrict {
26572676
public:
26582677
Zero(const set<Access>& zeroed) : zeroed(zeroed) {}

src/index_notation/index_notation_rewriter.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,29 @@ struct ReplaceIndexVars : public IndexNotationRewriter {
349349
}
350350
}
351351

352-
// TODO: Replace in assignments
352+
void visit(const AssignmentNode* op) {
353+
IndexExpr rhs = rewrite(op->rhs);
354+
Access lhs = to<Access>(rewrite(op->lhs));
355+
if (rhs == op->rhs && lhs == op->lhs) {
356+
stmt = op;
357+
}
358+
else {
359+
stmt = new AssignmentNode(lhs, rhs, op->op);
360+
}
361+
}
362+
363+
void visit(const ForallNode* op) {
364+
IndexStmt s = rewrite(op->stmt);
365+
IndexVar iv = util::contains(substitutions, op->indexVar)
366+
? substitutions.at(op->indexVar) : op->indexVar;
367+
if (s == op->stmt && iv == op->indexVar) {
368+
stmt = op;
369+
}
370+
else {
371+
stmt = new ForallNode(iv, s, op->parallel_unit, op->output_race_strategy,
372+
op->unrollFactor);
373+
}
374+
}
353375
};
354376

355377
struct ReplaceTensorVars : public IndexNotationRewriter {
@@ -404,4 +426,9 @@ IndexStmt replace(IndexStmt stmt,
404426
return ReplaceTensorVars(substitutions).rewrite(stmt);
405427
}
406428

429+
IndexStmt replace(IndexStmt stmt,
430+
const std::map<IndexVar,IndexVar>& substitutions) {
431+
return ReplaceIndexVars(substitutions).rewrite(stmt);
432+
}
433+
407434
}

src/index_notation/transformations.cpp

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -483,20 +483,27 @@ IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const {
483483

484484
Parallelize parallelize;
485485
ProvenanceGraph provGraph;
486+
map<TensorVar,ir::Expr> tensorVars;
487+
vector<ir::Expr> assembledByUngroupedInsert;
486488
set<IndexVar> definedIndexVars;
487489
set<ParallelUnit> parentParallelUnits;
488490
std::string reason = "";
489491

490492
IndexStmt rewriteParallel(IndexStmt stmt) {
491493
provGraph = ProvenanceGraph(stmt);
494+
tensorVars = createIRTensorVars(stmt);
495+
assembledByUngroupedInsert.clear();
496+
for (const auto& result : getAssembledByUngroupedInsertion(stmt)) {
497+
assembledByUngroupedInsert.push_back(tensorVars[result]);
498+
}
492499
return rewrite(stmt);
493500
}
494501

495502
void visit(const ForallNode* node) {
496503
Forall foralli(node);
497504
IndexVar i = parallelize.geti();
498505

499-
Iterators iterators(foralli);
506+
Iterators iterators(foralli, tensorVars);
500507
definedIndexVars.insert(foralli.getIndexVar());
501508
MergeLattice lattice = MergeLattice::make(foralli, iterators, provGraph, definedIndexVars);
502509
// Precondition 3: No parallelization of variables under a reduction
@@ -517,6 +524,9 @@ IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const {
517524

518525
// Precondition 2: Every result iterator must have insert capability
519526
for (Iterator iterator : lattice.results()) {
527+
if (util::contains(assembledByUngroupedInsert, iterator.getTensor())) {
528+
continue;
529+
}
520530
while (true) {
521531
if (!iterator.hasInsert()) {
522532
reason = "Precondition failed: The output tensor must allow inserts";
@@ -614,6 +624,34 @@ IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const {
614624
}
615625
IndexNotationRewriter::visit(node);
616626
}
627+
628+
void visit(const AssembleNode* op) {
629+
IndexVar i = parallelize.geti();
630+
IndexStmt queries = util::contains(op->queries.getIndexVars(), i)
631+
? rewrite(op->queries) : op->queries;
632+
IndexStmt compute = util::contains(op->compute.getIndexVars(), i)
633+
? rewrite(op->compute) : op->compute;
634+
if (queries == op->queries && compute == op->compute) {
635+
stmt = op;
636+
}
637+
else {
638+
stmt = new AssembleNode(queries, compute, op->results);
639+
}
640+
}
641+
642+
void visit(const WhereNode* op) {
643+
IndexVar i = parallelize.geti();
644+
IndexStmt producer = util::contains(op->producer.getIndexVars(), i)
645+
? rewrite(op->producer) : op->producer;
646+
IndexStmt consumer = util::contains(op->consumer.getIndexVars(), i)
647+
? rewrite(op->consumer) : op->consumer;
648+
if (producer == op->producer && consumer == op->consumer) {
649+
stmt = op;
650+
}
651+
else {
652+
stmt = new WhereNode(consumer, producer);
653+
}
654+
}
617655
};
618656

619657
ParallelizeRewriter rewriter;
@@ -663,6 +701,12 @@ AssembleStrategy SetAssembleStrategy::getAssembleStrategy() const {
663701
IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
664702
INIT_REASON(reason);
665703

704+
std::map<IndexVar,IndexVar> ivReplacements;
705+
for (const auto& indexVar : getIndexVars(stmt)) {
706+
ivReplacements[indexVar] = IndexVar("q" + indexVar.getName());
707+
}
708+
IndexStmt loweredQueries = replace(stmt, ivReplacements);
709+
666710
// Tracks all tensors that correspond to attribute query results or that are
667711
// used to compute attribute queries
668712
std::set<TensorVar> insertedResults;
@@ -812,8 +856,8 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
812856
expr = op;
813857
}
814858
};
815-
IndexStmt loweredQueries =
816-
LowerAttrQuery(queryResults, insertedResults).lower(stmt);
859+
loweredQueries =
860+
LowerAttrQuery(queryResults, insertedResults).lower(loweredQueries);
817861
std::cout << loweredQueries << std::endl;
818862

819863
struct ReduceToAssign : public IndexNotationRewriter {
@@ -956,6 +1000,7 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
9561000
std::cout << loweredQueries << std::endl;
9571001

9581002
//loweredQueries = parallelizeOuterLoop(loweredQueries);
1003+
//stmt = parallelizeOuterLoop(stmt);
9591004
//std::cout << loweredQueries << std::endl;
9601005

9611006
return Assemble(loweredQueries, stmt, queryResults);

src/lower/iterator.cpp

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -426,26 +426,6 @@ Iterators::Iterators()
426426
}
427427

428428

429-
static std::map<TensorVar, ir::Expr> createIRTensorVars(IndexStmt stmt)
430-
{
431-
std::map<TensorVar, ir::Expr> tensorVars;
432-
433-
// Create result and parameter variables
434-
vector<TensorVar> results = getResults(stmt);
435-
vector<TensorVar> arguments = getArguments(stmt);
436-
vector<TensorVar> temporaries = getTemporaries(stmt);
437-
438-
// Convert tensor results, arguments and temporaries to IR variables
439-
map<TensorVar, Expr> resultVars;
440-
vector<Expr> resultsIR = createVars(results, &resultVars);
441-
tensorVars.insert(resultVars.begin(), resultVars.end());
442-
vector<Expr> argumentsIR = createVars(arguments, &tensorVars);
443-
vector<Expr> temporariesIR = createVars(temporaries, &tensorVars);
444-
445-
return tensorVars;
446-
}
447-
448-
449429
Iterators::Iterators(IndexStmt stmt) : Iterators(stmt, createIRTensorVars(stmt))
450430
{
451431
}

0 commit comments

Comments
 (0)