Skip to content

Commit ad7037c

Browse files
Merge pull request WebAssembly#5 from MPurscheUnity/fix_remove_function_pass
Fix remove function pass
2 parents 76bb393 + 0ee0894 commit ad7037c

File tree

8 files changed

+422
-229
lines changed

8 files changed

+422
-229
lines changed

src/passes/LogExecution.cpp

Lines changed: 66 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@
3030

3131
#include "asmjs/shared-constants.h"
3232
#include "shared-constants.h"
33+
#include <map>
3334
#include <pass.h>
3435
#include <wasm-builder.h>
3536
#include <wasm.h>
36-
#include <map>
3737

3838
namespace wasm {
3939

@@ -44,17 +44,24 @@ struct LogExecution : public WalkerPass<PostWalker<LogExecution>> {
4444

4545
Index nextFreeIndex = 0;
4646

47-
// Tries to convert a string to a function index. Returns (Index)-1 on failure.
48-
Index stringToIndex(const char *s) {
49-
for(const char *q = s; *q; ++q)
50-
if (!isdigit(*q))
51-
return (Index)-1;
52-
return std::stoi(s);
47+
// Tries to convert a string to a function index. Returns (Index)-1 on
48+
// failure.
49+
Index stringToIndex(const char* s) {
50+
for (const char* q = s; *q; ++q) {
51+
if (!isdigit(*q)) {
52+
return (Index)-1;
53+
}
54+
}
55+
return std::stoi(s);
5356
}
5457

55-
void visitLoop(Loop* curr) { curr->body = makeLogCall(curr->body, nextFreeIndex++); }
58+
void visitLoop(Loop* curr) {
59+
curr->body = makeLogCall(curr->body, nextFreeIndex++);
60+
}
5661

57-
void visitReturn(Return* curr) { replaceCurrent(makeLogCall(curr, nextFreeIndex++)); }
62+
void visitReturn(Return* curr) {
63+
replaceCurrent(makeLogCall(curr, nextFreeIndex++));
64+
}
5865

5966
void visitFunction(Function* curr) {
6067
if (curr->imported()) {
@@ -68,63 +75,74 @@ struct LogExecution : public WalkerPass<PostWalker<LogExecution>> {
6875
}
6976

7077
if (functionOrdinals.find(curr) == functionOrdinals.end()) {
71-
Fatal() << "LogExecution: Internal mismatch in mapping functions to their ordinals for logging execution!";
78+
Fatal() << "LogExecution: Internal mismatch in mapping functions to "
79+
"their ordinals for logging execution!";
7280
}
7381

7482
curr->body = makeLogCall(curr->body, functionOrdinals.find(curr)->second);
7583
}
7684

7785
void doWalkModule(Module* curr) {
78-
// Add the import
79-
auto import =
80-
Builder::makeFunction(LOGGER, Signature(Type::i32, Type::none), {});
86+
// Add the import
87+
auto import =
88+
Builder::makeFunction(LOGGER, Signature(Type::i32, Type::none), {});
89+
90+
// Import the log function from import "env" if the module
91+
// imports other functions from that name.
92+
for (auto& func : curr->functions) {
93+
if (func->imported() && func->module == ENV) {
94+
import->module = func->module;
95+
break;
96+
}
97+
}
8198

82-
// Import the log function from import "env" if the module
83-
// imports other functions from that name.
99+
// If not, then pick the import name of the first function we find.
100+
if (!import->module) {
84101
for (auto& func : curr->functions) {
85-
if (func->imported() && func->module == ENV) {
102+
if (func->imported()) {
86103
import->module = func->module;
87104
break;
88105
}
89106
}
107+
}
90108

91-
// If not, then pick the import name of the first function we find.
92-
if (!import->module) {
93-
for (auto& func : curr->functions) {
94-
if (func->imported()) {
95-
import->module = func->module;
96-
break;
97-
}
98-
}
99-
}
100-
101-
import->base = LOGGER;
102-
curr->addFunction(std::move(import));
109+
import->base = LOGGER;
110+
curr->addFunction(std::move(import));
103111

104-
// Reserve all function indices up front for the function names. This is
105-
// so that the logged ordinal numbers will match up with the function ordinals.
106-
int idx = 0;
107-
for (auto& func : curr->functions) {
108-
if (func->imported()) ++idx;
109-
}
112+
// Reserve all function indices up front for the function names. This is
113+
// so that the logged ordinal numbers will match up with the function
114+
// ordinals.
115+
Index idx = 0;
116+
for (auto& func : curr->functions) {
117+
if (func->imported()) {
118+
++idx;
119+
}
120+
}
110121

111-
for (auto& func : curr->functions) {
112-
if (func->imported()) continue;
122+
for (auto& func : curr->functions) {
123+
if (func->imported()) {
124+
continue;
125+
}
113126

114-
Index currentFunctionIndex = (Index)stringToIndex(func->name.toString().c_str());
115-
if (currentFunctionIndex != (Index)-1) {
116-
if (currentFunctionIndex != idx)
117-
std::cerr << "Functions are not in ordinal order! currentFunctionIndex=" << currentFunctionIndex << ", vs idx=" << idx << std::endl;
127+
Index currentFunctionIndex =
128+
(Index)stringToIndex(func->name.toString().c_str());
129+
if (currentFunctionIndex != (Index)-1) {
130+
if (currentFunctionIndex != idx) {
131+
std::cerr
132+
<< "Functions are not in ordinal order! currentFunctionIndex="
133+
<< currentFunctionIndex << ", vs idx=" << idx << std::endl;
118134
}
119-
else
120-
currentFunctionIndex = idx;
121-
functionOrdinals[func.get()] = idx;
122-
std::cerr << "Function " << func->name << " has ordinal " << idx << std::endl;
123-
nextFreeIndex = std::max(nextFreeIndex, currentFunctionIndex + 1);
124-
++idx;
125-
}
135+
} else {
136+
currentFunctionIndex = idx;
137+
}
138+
functionOrdinals[func.get()] = idx;
139+
std::cerr << "Function " << func->name << " has ordinal " << idx
140+
<< std::endl;
141+
nextFreeIndex = std::max(nextFreeIndex, currentFunctionIndex + 1);
142+
++idx;
143+
}
126144

127-
PostWalker<LogExecution>::doWalkModule(curr);
145+
PostWalker<LogExecution>::doWalkModule(curr);
128146
}
129147

130148
private:

src/passes/RemoveFunctions.cpp

Lines changed: 102 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -22,83 +22,119 @@
2222
#include <cctype>
2323

2424
#include "pass.h"
25+
#include "support/file.h"
2526
#include "wasm-builder.h"
2627
#include "wasm.h"
27-
#include "support/file.h"
2828

2929
namespace wasm {
3030

31-
static std::vector<Name> parseFunctionList(const IString &functionList, Module *module) {
32-
std::vector<Name> functions;
33-
std::string input = functionList.toString().c_str();
34-
35-
// If --remove-functions=* is passed, remove everything possible. (track this as an empty function list)
36-
if (functionList == "*") {
37-
return functions;
38-
}
39-
40-
// Read function list from a file if prefixed with '@'
41-
if (functionList.startsWith(IString("@"))) {
42-
input = read_file<std::string>(input.substr(1), Flags::Text);
43-
}
44-
45-
// Split string to a string list, delimited by ; and \n
46-
size_t begin = 0;
47-
for(size_t end = 1; end <= input.length(); ++end) {
48-
if (input[end] == ';' || input[end] == '\n' || end == input.length()) {
49-
// Trim \r and whitespace
50-
size_t trimEnd = end;
51-
while(trimEnd > 0 && input[trimEnd-1] <= 32) --trimEnd;
52-
size_t trimBegin = begin;
53-
while(trimBegin < input.length() && input[trimBegin] <= 32) ++trimBegin;
54-
if (trimBegin < trimEnd) {
55-
std::string name = input.substr(trimBegin, trimEnd - trimBegin);
56-
if (std::isdigit(name[0])) {
57-
Index i = std::stoi(name);
58-
if (i >= module->functions.size()) {
59-
Fatal() << "Out of bounds function index " << i << "! (module has only " << module->functions.size() << " functions)";
60-
}
61-
// Assumes imports are at the beginning
62-
functions.push_back(module->functions[i]->name);
63-
} else {
64-
functions.push_back(name);
65-
}
66-
}
67-
begin = end + 1;
31+
static std::vector<Name> parseFunctionList(const IString& functionList,
32+
Module* module) {
33+
std::vector<Name> functions;
34+
std::string input = functionList.toString().c_str();
35+
36+
// If --remove-functions=* is passed, remove everything possible. (track this
37+
// as an empty function list)
38+
if (functionList == "*") {
39+
return functions;
40+
}
41+
42+
// Read function list from a file if prefixed with '@'
43+
if (functionList.startsWith(IString("@"))) {
44+
input = read_file<std::string>(input.substr(1), Flags::Text);
45+
}
46+
47+
// Split string to a string list, delimited by ; and \n
48+
size_t begin = 0;
49+
for (size_t end = 1; end <= input.length(); ++end) {
50+
if (input[end] == ';' || input[end] == '\n' || end == input.length()) {
51+
// Trim \r and whitespace
52+
size_t trimEnd = end;
53+
while (trimEnd > 0 && input[trimEnd - 1] <= 32) {
54+
--trimEnd;
55+
}
56+
size_t trimBegin = begin;
57+
while (trimBegin < input.length() && input[trimBegin] <= 32) {
58+
++trimBegin;
6859
}
69-
}
70-
if (functions.empty()) {
71-
Fatal() << "Unable to parse argument --remove-functions=" << functionList;
72-
}
73-
return functions;
60+
if (trimBegin < trimEnd) {
61+
std::string name = input.substr(trimBegin, trimEnd - trimBegin);
62+
if (std::isdigit(name[0])) {
63+
Index i = std::stoi(name);
64+
if (i >= module->functions.size()) {
65+
Fatal() << "Out of bounds function index " << i
66+
<< "! (module has only " << module->functions.size()
67+
<< " functions)";
68+
}
69+
// Assumes imports are at the beginning
70+
functions.push_back(module->functions[i]->name);
71+
} else {
72+
functions.push_back(name);
73+
}
74+
}
75+
begin = end + 1;
76+
}
77+
}
78+
if (functions.empty()) {
79+
Fatal() << "Unable to parse argument --remove-functions=" << functionList;
80+
}
81+
return functions;
7482
}
7583

76-
static void remove(PassRunner* runner, Module* module, std::vector<Name> functionsToRemove) {
84+
static void remove(PassRunner* runner,
85+
Module* module,
86+
std::vector<Name> functionsToRemove) {
7787

7888
Builder builder(*module);
7989

8090
for (auto& func : module->functions) {
81-
if (!func->imported() && (functionsToRemove.empty() || std::find(functionsToRemove.begin(), functionsToRemove.end(), func->name) != functionsToRemove.end())) {
82-
const Type returns = func->getResults();
83-
if (returns == Type::none) {
84-
std::cerr << "removing void function " << func->name << "\n";
85-
func->vars.clear();
86-
func->body = builder.makeReturn();
87-
}
88-
else if (returns == Type::i32 || returns == Type::i64 || returns == Type::f32 || returns == Type::f64) {
89-
std::cerr << "removing i32/i64/f32/f64 function " << func->name << "\n";
90-
func->vars.clear();
91-
func->body = builder.makeConst(Literal(int32_t(0)));
92-
}
93-
else {
94-
std::cerr << "unable to remove function " << func->name << "since it returns a " << returns << "\n";
95-
}
96-
}
91+
if (!func->imported() &&
92+
(functionsToRemove.empty() ||
93+
std::find(functionsToRemove.begin(),
94+
functionsToRemove.end(),
95+
func->name) != functionsToRemove.end())) {
96+
const Type returns = func->getResults();
97+
98+
if (returns == Type::unreachable) {
99+
std::cerr << "removing unreachable function " << func->name << "\n";
100+
func->vars.clear();
101+
func->body = builder.makeUnreachable();
102+
} else if (returns == Type::none) {
103+
std::cerr << "removing void function " << func->name << "\n";
104+
func->vars.clear();
105+
func->body = builder.makeReturn();
106+
} else if (returns == Type::i32) {
107+
std::cerr << "removing i32 function " << func->name << "\n";
108+
func->vars.clear();
109+
func->body = builder.makeConst(Literal(int32_t(0)));
110+
} else if (returns == Type::i64) {
111+
std::cerr << "removing i64 function " << func->name << "\n";
112+
func->vars.clear();
113+
func->body = builder.makeConst(Literal(int64_t(0)));
114+
} else if (returns == Type::f32) {
115+
std::cerr << "removing f32 function " << func->name << "\n";
116+
func->vars.clear();
117+
func->body = builder.makeConst(Literal(float(0.0f)));
118+
} else if (returns == Type::f64) {
119+
std::cerr << "removing f64 function " << func->name << "\n";
120+
func->vars.clear();
121+
func->body = builder.makeConst(Literal(double(0.0)));
122+
} else if (returns == Type::v128) {
123+
std::cerr << "removing v128 function " << func->name << "\n";
124+
func->vars.clear();
125+
std::array<uint8_t, 16> bytes;
126+
bytes.fill(0);
127+
func->body = builder.makeConst(Literal(bytes.data()));
128+
} else {
129+
std::cerr << "unable to remove function " << func->name
130+
<< "since it returns a " << returns << "\n";
131+
}
132+
}
97133
}
98134

99135
// Remove unneeded things.
100136
PassRunner postRunner(runner);
101-
// postRunner.add("inlining-optimizing");
137+
// postRunner.add("inlining-optimizing");
102138
postRunner.add("remove-unused-module-elements");
103139
postRunner.setIsNested(true);
104140
postRunner.run();
@@ -108,7 +144,11 @@ struct RemoveFunctions : public Pass {
108144
void run(Module* module) override {
109145
Name name = getPassRunner()->options.getArgument(
110146
"remove-functions",
111-
"RemoveFunctions usage: wasm-opt --remove-functions=FUNCTION_NAME"); // todo: multiple functions via --remove-functions=name1;name2;index3;... or [email protected]
147+
"RemoveFunctions usage: wasm-opt "
148+
"--remove-functions=FUNCTION_NAME"); // todo: multiple functions via
149+
// --remove-functions=name1;name2;index3;...
150+
// or
151+
112152
std::vector<Name> functionsToRemove = parseFunctionList(name, module);
113153
remove(getPassRunner(), module, functionsToRemove);
114154
}

0 commit comments

Comments
 (0)