diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index bd91d892e..51fb8770c 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -1514,6 +1514,14 @@ IndexStmt IndexStmt::precompute(IndexExpr expr, IndexVar i, IndexVar iw, TensorV IndexStmt transformed = *this; string reason; + if (i != iw) { + IndexVarRel rel = IndexVarRel(new PrecomputeRelNode(i, iw)); + transformed = Transformation(AddSuchThatPredicates({rel})).apply(transformed, &reason); + if (!transformed.defined()) { + taco_uerror << reason; + } + } + transformed = Transformation(Precompute(expr, i, iw, workspace)).apply(transformed, &reason); if (!transformed.defined()) { taco_uerror << reason; diff --git a/src/index_notation/provenance_graph.cpp b/src/index_notation/provenance_graph.cpp index bc0a72cc2..54c7ab372 100644 --- a/src/index_notation/provenance_graph.cpp +++ b/src/index_notation/provenance_graph.cpp @@ -1120,9 +1120,12 @@ bool ProvenanceGraph::isAvailable(IndexVar indexVar, std::set defined) bool ProvenanceGraph::isRecoverable(taco::IndexVar indexVar, std::set defined) const { // all children are either defined or recoverable from their children - for (const IndexVar& child : getChildren(indexVar)) { - if (!defined.count(child) && (isFullyDerived(child) || !isRecoverable(child, defined))) { - return false; + // precompute relations are always recoverable since their children never appear in the same loop + if (!(childRelMap.count(indexVar) && childRelMap.at(indexVar).getRelType() == IndexVarRelType::PRECOMPUTE)) { + for (const IndexVar& child : getChildren(indexVar)) { + if (!defined.count(child) && (isFullyDerived(child) || !isRecoverable(child, defined))) { + return false; + } } } return true; diff --git a/src/lower/iterator.cpp b/src/lower/iterator.cpp index 6fc337c46..0f0c024c5 100644 --- a/src/lower/iterator.cpp +++ b/src/lower/iterator.cpp @@ -531,6 +531,15 @@ Iterators::Iterators(IndexStmt stmt, const map& tensorVars) underivedAdded.insert(underived); } } + + // Insert all children of current index variable into iterators as well + for (const IndexVar& child : provGraph.getChildren(n->indexVar)) { + if (!underivedAdded.count(child)) { + content->modeIterators.insert({child, child}); + underivedAdded.insert(child); + } + } + m->match(n->stmt); }) ); diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index d4d8f8362..3e73dfb85 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -253,7 +253,6 @@ LowererImpl::lower(IndexStmt stmt, string name, } } argumentsIR.insert(argumentsIR.begin(), indexSetArgs.begin(), indexSetArgs.end()); - // Create variables for temporaries // TODO Remove this for (auto& temp : temporaries) { @@ -815,8 +814,6 @@ Stmt LowererImpl::lowerForall(Forall forall) loops = lowerMergeLattice(lattice, underivedAncestors[0], forall.getStmt(), reducedAccesses); } -// taco_iassert(loops.defined()); - if (!generateComputeCode() && !hasStores(loops)) { // If assembly loop does not modify output arrays, then it can be safely // omitted.