Skip to content

DRAFT: Modify iterators to resolve precompute bugs #470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/index_notation/index_notation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1514,6 +1514,14 @@ IndexStmt IndexStmt::precompute(IndexExpr expr, IndexVar i, IndexVar iw, TensorV
IndexStmt transformed = *this;
string reason;

if (i != iw) {
IndexVarRel rel = IndexVarRel(new PrecomputeRelNode(i, iw));
transformed = Transformation(AddSuchThatPredicates({rel})).apply(transformed, &reason);
if (!transformed.defined()) {
taco_uerror << reason;
}
}

transformed = Transformation(Precompute(expr, i, iw, workspace)).apply(transformed, &reason);
if (!transformed.defined()) {
taco_uerror << reason;
Expand Down
9 changes: 6 additions & 3 deletions src/index_notation/provenance_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1120,9 +1120,12 @@ bool ProvenanceGraph::isAvailable(IndexVar indexVar, std::set<IndexVar> defined)

bool ProvenanceGraph::isRecoverable(taco::IndexVar indexVar, std::set<taco::IndexVar> defined) const {
// all children are either defined or recoverable from their children
for (const IndexVar& child : getChildren(indexVar)) {
if (!defined.count(child) && (isFullyDerived(child) || !isRecoverable(child, defined))) {
return false;
// precompute relations are always recoverable since their children never appear in the same loop
if (!(childRelMap.count(indexVar) && childRelMap.at(indexVar).getRelType() == IndexVarRelType::PRECOMPUTE)) {
for (const IndexVar& child : getChildren(indexVar)) {
if (!defined.count(child) && (isFullyDerived(child) || !isRecoverable(child, defined))) {
return false;
}
}
}
return true;
Expand Down
9 changes: 9 additions & 0 deletions src/lower/iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,15 @@ Iterators::Iterators(IndexStmt stmt, const map<TensorVar, Expr>& tensorVars)
underivedAdded.insert(underived);
}
}

// Insert all children of current index variable into iterators as well
for (const IndexVar& child : provGraph.getChildren(n->indexVar)) {
if (!underivedAdded.count(child)) {
content->modeIterators.insert({child, child});
underivedAdded.insert(child);
}
}

m->match(n->stmt);
})
);
Expand Down
3 changes: 0 additions & 3 deletions src/lower/lowerer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ LowererImpl::lower(IndexStmt stmt, string name,
}
}
argumentsIR.insert(argumentsIR.begin(), indexSetArgs.begin(), indexSetArgs.end());

// Create variables for temporaries
// TODO Remove this
for (auto& temp : temporaries) {
Expand Down Expand Up @@ -815,8 +814,6 @@ Stmt LowererImpl::lowerForall(Forall forall)
loops = lowerMergeLattice(lattice, underivedAncestors[0],
forall.getStmt(), reducedAccesses);
}
// taco_iassert(loops.defined());

if (!generateComputeCode() && !hasStores(loops)) {
// If assembly loop does not modify output arrays, then it can be safely
// omitted.
Expand Down