Skip to content

lower: properly fix #355 #500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions src/lower/lowerer_impl_imperative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,9 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
inParallelLoopDepth++;
}

// Record that we might have some fresh locators that need to be recovered.
std::vector<Iterator> freshLocateIterators;

// Recover any available parents that were not recoverable previously
vector<Stmt> recoverySteps;
for (const IndexVar& varToRecover : provGraph.newlyRecoverableParents(forall.getIndexVar(), definedIndexVars)) {
Expand Down Expand Up @@ -634,17 +637,16 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
// the accessors for those locator variables as part of the recovery process.
// This is necessary after a fuse transformation, for example: If we fuse
// two index variables (i, j) into f, then after we've generated the loop for
// f, all locate accessors for i and j are now available for use.
// f, all locate accessors for i and j are now available for use. So, remember
// that we have some new locate iterators that should be recovered.
std::vector<Iterator> itersForVar;
for (auto& iters : iterators.levelIterators()) {
// Collect all level iterators that have locate and iterate over
// the recovered index variable.
if (iters.second.getIndexVar() == varToRecover && iters.second.hasLocate()) {
itersForVar.push_back(iters.second);
freshLocateIterators.push_back(iters.second);
}
}
// Finally, declare all of the collected iterators' position access variables.
recoverySteps.push_back(this->declLocatePosVars(itersForVar));

// place underived guard
std::vector<ir::Expr> iterBounds = provGraph.deriveIterBounds(varToRecover, definedIndexVarsOrdered, underivedBounds, indexVarToExprMap, iterators);
Expand Down Expand Up @@ -799,7 +801,15 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
}
// Emit dimension coordinate iteration loop
else if (iterator.isDimensionIterator()) {
loops = lowerForallDimension(forall, point.locators(),
// A proper fix to #355. Adding information that those locate iterators are now ready is the
// correct way to recover them, rather than blindly duplicating the emitted locators.
auto locatorsCopy = std::vector<Iterator>(point.locators());
for (auto it : freshLocateIterators) {
if (!util::contains(locatorsCopy, it)) {
locatorsCopy.push_back(it);
}
}
loops = lowerForallDimension(forall, locatorsCopy,
inserters, appenders, reducedAccesses, recoveryStmt);
}
// Emit position iteration loop
Expand Down Expand Up @@ -1772,14 +1782,19 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt,
const set<Access>& reducedAccesses) {
Stmt initVals = resizeAndInitValues(appenders, reducedAccesses);

// Inserter positions
Stmt declInserterPosVars = declLocatePosVars(inserters);

// Locate positions
Stmt declLocatorPosVars = declLocatePosVars(locators);
// There can be overlaps between the inserters and locators, which results in
// duplicate emitting of variable declarations. We'll fix that here.
std::vector<Iterator> itersWithLocators;
for (auto it : inserters) {
if (!util::contains(itersWithLocators, it)) { itersWithLocators.push_back(it); }
}
for (auto it : locators) {
if (!util::contains(itersWithLocators, it)) { itersWithLocators.push_back(it); }
}
auto declPosVars = declLocatePosVars(itersWithLocators);

if (captureNextLocatePos) {
capturedLocatePos = Block::make(declInserterPosVars, declLocatorPosVars);
capturedLocatePos = declPosVars;
captureNextLocatePos = false;
}

Expand All @@ -1792,8 +1807,7 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt,
// TODO: Emit code to insert coordinates

return Block::make(initVals,
declInserterPosVars,
declLocatorPosVars,
declPosVars,
body,
appendCoords);
}
Expand Down