Skip to content

Commit c672940

Browse files
authored
precompute-propagate pass (#1179)
Implements #1172: this adds a variant of precompute, "precompute-propagate", which also does constant propagation. Precompute by itself just runs the interpreter on each expression and sees if it is in fact a constant; precompute-propagate also looks at the graph of connections between get and set locals, and propagates those constant values. This helps with cases as noticed in #1168 - while in most cases LLVM will do this already, it's important when inlining, e.g. inlining of the clamping math functions. This new pass is run when inlining, and otherwise only in -O3/-Oz, as it does increase compilation time noticeably if run on everything (and for almost no benefit if LLVM has run). Most of the code here is just refactoring out from the ssa pass the get/set graph computation, so it can now be used by both the ssa pass and precompute-propagate.
1 parent 40f52f2 commit c672940

29 files changed

+1688
-914
lines changed

src/ast/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
SET(ast_SOURCES
22
ExpressionAnalyzer.cpp
33
ExpressionManipulator.cpp
4+
LocalGraph.cpp
45
)
56
ADD_LIBRARY(ast STATIC ${ast_SOURCES})

src/ast/LocalGraph.cpp

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
/*
2+
* Copyright 2017 WebAssembly Community Group participants
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <iterator>
18+
19+
#include <wasm-builder.h>
20+
#include <ast/find_all.h>
21+
#include <ast/local-graph.h>
22+
23+
namespace wasm {
24+
25+
LocalGraph::LocalGraph(Function* func, Module* module) {
26+
walkFunctionInModule(func, module);
27+
}
28+
29+
void LocalGraph::computeInfluences() {
30+
for (auto& pair : locations) {
31+
auto* curr = pair.first;
32+
if (auto* set = curr->dynCast<SetLocal>()) {
33+
FindAll<GetLocal> findAll(set->value);
34+
for (auto* get : findAll.list) {
35+
getInfluences[get].insert(set);
36+
}
37+
} else {
38+
auto* get = curr->cast<GetLocal>();
39+
for (auto* set : getSetses[get]) {
40+
setInfluences[set].insert(get);
41+
}
42+
}
43+
}
44+
}
45+
46+
void LocalGraph::doWalkFunction(Function* func) {
47+
numLocals = func->getNumLocals();
48+
if (numLocals == 0) return; // nothing to do
49+
// We begin with each param being assigned from the incoming value, and the zero-init for the locals,
50+
// so the initial state is the identity permutation
51+
currMapping.resize(numLocals);
52+
for (auto& set : currMapping) {
53+
set = { nullptr };
54+
}
55+
PostWalker<LocalGraph>::walk(func->body);
56+
}
57+
58+
// control flow
59+
60+
void LocalGraph::visitBlock(Block* curr) {
61+
if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) {
62+
auto& infos = breakMappings[curr->name];
63+
infos.emplace_back(std::move(currMapping));
64+
currMapping = std::move(merge(infos));
65+
breakMappings.erase(curr->name);
66+
}
67+
}
68+
69+
void LocalGraph::finishIf() {
70+
// that's it for this if, merge
71+
std::vector<Mapping> breaks;
72+
breaks.emplace_back(std::move(currMapping));
73+
breaks.emplace_back(std::move(mappingStack.back()));
74+
mappingStack.pop_back();
75+
currMapping = std::move(merge(breaks));
76+
}
77+
78+
void LocalGraph::afterIfCondition(LocalGraph* self, Expression** currp) {
79+
self->mappingStack.push_back(self->currMapping);
80+
}
81+
void LocalGraph::afterIfTrue(LocalGraph* self, Expression** currp) {
82+
auto* curr = (*currp)->cast<If>();
83+
if (curr->ifFalse) {
84+
auto afterCondition = std::move(self->mappingStack.back());
85+
self->mappingStack.back() = std::move(self->currMapping);
86+
self->currMapping = std::move(afterCondition);
87+
} else {
88+
self->finishIf();
89+
}
90+
}
91+
void LocalGraph::afterIfFalse(LocalGraph* self, Expression** currp) {
92+
self->finishIf();
93+
}
94+
void LocalGraph::beforeLoop(LocalGraph* self, Expression** currp) {
95+
// save the state before entering the loop, for calculation later of the merge at the loop top
96+
self->mappingStack.push_back(self->currMapping);
97+
self->loopGetStack.push_back({});
98+
}
99+
void LocalGraph::visitLoop(Loop* curr) {
100+
if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) {
101+
auto& infos = breakMappings[curr->name];
102+
infos.emplace_back(std::move(mappingStack.back()));
103+
auto before = infos.back();
104+
auto& merged = merge(infos);
105+
// every local we created a phi for requires us to update get_local operations in
106+
// the loop - the branch back has means that gets in the loop have potentially
107+
// more sets reaching them.
108+
// we can detect this as follows: if a get of oldIndex has the same sets
109+
// as the sets at the entrance to the loop, then it is affected by the loop
110+
// header sets, and we can add to there sets that looped back
111+
auto linkLoopTop = [&](Index i, Sets& getSets) {
112+
auto& beforeSets = before[i];
113+
if (getSets.size() < beforeSets.size()) {
114+
// the get trivially has fewer sets, so it overrode the loop entry sets
115+
return;
116+
}
117+
std::vector<SetLocal*> intersection;
118+
std::set_intersection(beforeSets.begin(), beforeSets.end(),
119+
getSets.begin(), getSets.end(),
120+
std::back_inserter(intersection));
121+
if (intersection.size() < beforeSets.size()) {
122+
// the get has not the same sets as in the loop entry
123+
return;
124+
}
125+
// the get has the entry sets, so add any new ones
126+
for (auto* set : merged[i]) {
127+
getSets.insert(set);
128+
}
129+
};
130+
auto& gets = loopGetStack.back();
131+
for (auto* get : gets) {
132+
linkLoopTop(get->index, getSetses[get]);
133+
}
134+
// and the same for the loop fallthrough: any local that still has the
135+
// entry sets should also have the loop-back sets as well
136+
for (Index i = 0; i < numLocals; i++) {
137+
linkLoopTop(i, currMapping[i]);
138+
}
139+
// finally, breaks still in flight must be updated too
140+
for (auto& iter : breakMappings) {
141+
auto name = iter.first;
142+
if (name == curr->name) continue; // skip our own (which is still in use)
143+
auto& mappings = iter.second;
144+
for (auto& mapping : mappings) {
145+
for (Index i = 0; i < numLocals; i++) {
146+
linkLoopTop(i, mapping[i]);
147+
}
148+
}
149+
}
150+
// now that we are done with using the mappings, erase our own
151+
breakMappings.erase(curr->name);
152+
}
153+
mappingStack.pop_back();
154+
loopGetStack.pop_back();
155+
}
156+
void LocalGraph::visitBreak(Break* curr) {
157+
if (curr->condition) {
158+
breakMappings[curr->name].emplace_back(currMapping);
159+
} else {
160+
breakMappings[curr->name].emplace_back(std::move(currMapping));
161+
setUnreachable(currMapping);
162+
}
163+
}
164+
void LocalGraph::visitSwitch(Switch* curr) {
165+
std::set<Name> all;
166+
for (auto target : curr->targets) {
167+
all.insert(target);
168+
}
169+
all.insert(curr->default_);
170+
for (auto target : all) {
171+
breakMappings[target].emplace_back(currMapping);
172+
}
173+
setUnreachable(currMapping);
174+
}
175+
void LocalGraph::visitReturn(Return *curr) {
176+
setUnreachable(currMapping);
177+
}
178+
void LocalGraph::visitUnreachable(Unreachable *curr) {
179+
setUnreachable(currMapping);
180+
}
181+
182+
// local usage
183+
184+
void LocalGraph::visitGetLocal(GetLocal* curr) {
185+
assert(currMapping.size() == numLocals);
186+
assert(curr->index < numLocals);
187+
for (auto& loopGets : loopGetStack) {
188+
loopGets.push_back(curr);
189+
}
190+
// current sets are our sets
191+
getSetses[curr] = currMapping[curr->index];
192+
locations[curr] = getCurrentPointer();
193+
}
194+
void LocalGraph::visitSetLocal(SetLocal* curr) {
195+
assert(currMapping.size() == numLocals);
196+
assert(curr->index < numLocals);
197+
// current sets are just this set
198+
currMapping[curr->index] = { curr }; // TODO optimize?
199+
locations[curr] = getCurrentPointer();
200+
}
201+
202+
// traversal
203+
204+
void LocalGraph::scan(LocalGraph* self, Expression** currp) {
205+
if (auto* iff = (*currp)->dynCast<If>()) {
206+
// if needs special handling
207+
if (iff->ifFalse) {
208+
self->pushTask(LocalGraph::afterIfFalse, currp);
209+
self->pushTask(LocalGraph::scan, &iff->ifFalse);
210+
}
211+
self->pushTask(LocalGraph::afterIfTrue, currp);
212+
self->pushTask(LocalGraph::scan, &iff->ifTrue);
213+
self->pushTask(LocalGraph::afterIfCondition, currp);
214+
self->pushTask(LocalGraph::scan, &iff->condition);
215+
} else {
216+
PostWalker<LocalGraph>::scan(self, currp);
217+
}
218+
219+
// loops need pre-order visiting too
220+
if ((*currp)->is<Loop>()) {
221+
self->pushTask(LocalGraph::beforeLoop, currp);
222+
}
223+
}
224+
225+
// helpers
226+
227+
void LocalGraph::setUnreachable(Mapping& mapping) {
228+
mapping.resize(numLocals); // may have been emptied by a move
229+
mapping[0].clear();
230+
}
231+
232+
bool LocalGraph::isUnreachable(Mapping& mapping) {
233+
// we must have some set for each index, if only the zero init, so empty means we emptied it for unreachable code
234+
return mapping[0].empty();
235+
}
236+
237+
// merges a bunch of infos into one.
238+
// if we need phis, writes them into the provided vector. the caller should
239+
// ensure those are placed in the right location
240+
LocalGraph::Mapping& LocalGraph::merge(std::vector<Mapping>& mappings) {
241+
assert(mappings.size() > 0);
242+
auto& out = mappings[0];
243+
if (mappings.size() == 1) {
244+
return out;
245+
}
246+
// merge into the first
247+
for (Index j = 1; j < mappings.size(); j++) {
248+
auto& other = mappings[j];
249+
for (Index i = 0; i < numLocals; i++) {
250+
auto& outSets = out[i];
251+
for (auto* set : other[i]) {
252+
outSets.insert(set);
253+
}
254+
}
255+
}
256+
return out;
257+
}
258+
259+
} // namespace wasm
260+

src/ast/find_all.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright 2017 WebAssembly Community Group participants
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#ifndef wasm_ast_find_all_h
18+
#define wasm_ast_find_all_h
19+
20+
#include <wasm-traversal.h>
21+
22+
namespace wasm {
23+
24+
// Find all instances of a certain node type
25+
26+
template<typename T>
27+
struct FindAll {
28+
std::vector<T*> list;
29+
30+
FindAll(Expression* ast) {
31+
struct Finder : public PostWalker<Finder, UnifiedExpressionVisitor<Finder>> {
32+
std::vector<T*>* list;
33+
void visitExpression(Expression* curr) {
34+
if (curr->is<T>()) {
35+
(*list).push_back(curr->cast<T>());
36+
}
37+
}
38+
};
39+
Finder finder;
40+
finder.list = &list;
41+
finder.walk(ast);
42+
}
43+
};
44+
45+
} // namespace wasm
46+
47+
#endif // wasm_ast_find_all_h
48+

0 commit comments

Comments
 (0)