diff --git a/include/taco/parser/schedule_parser.h b/include/taco/parser/schedule_parser.h new file mode 100644 index 000000000..72277dcbd --- /dev/null +++ b/include/taco/parser/schedule_parser.h @@ -0,0 +1,20 @@ +#ifndef TACO_SCHEDULE_PARSER_H +#define TACO_SCHEDULE_PARSER_H + +#include +#include + +namespace taco { +namespace parser { + +// parse a string of the form: "reorder(i,j),precompute(D(i,j)*E(j,k),j,j_pre)" +// into string vectors of the form: +// [ [ "reorder", "i", "j" ], [ "precompute", "D(i,j)*E(j,k)", "j", "j_pre" ] ] +std::vector> ScheduleParser(const std::string); + +// serialize the result of a parse (for debugging) +std::string serializeParsedSchedule(std::vector>); + +}} + +#endif //TACO_SCHEDULE_PARSER_H diff --git a/src/parser/schedule_parser.cpp b/src/parser/schedule_parser.cpp new file mode 100644 index 000000000..4858db58a --- /dev/null +++ b/src/parser/schedule_parser.cpp @@ -0,0 +1,103 @@ +#include +#include +#include + +#include "taco/parser/lexer.h" +#include "taco/parser/schedule_parser.h" +#include "taco/error.h" + +using std::vector; +using std::string; +using std::cout; +using std::endl; + +namespace taco{ +namespace parser{ + +/// Parses command line schedule directives (`-s `). +/// Example: "precompute(B(i,j),j,jpre),reorder(j,i)" is parsed as: +/// [ [ "precompute", "B(i,j)", "j", "jpre" ], +/// [ "reorder", "j", "i" ] ] +/// The first element of each inner vector is the function name. +/// Inner parens are preserved. All whitespace is removed. 
+vector> ScheduleParser(const string argValue) { + /* parenthesesCnt is the current paren nesting depth: 0 = outside any directive, 1 = inside the parameter list of a directive, >1 = inside a single parameter */ int parenthesesCnt; + vector> parsed; + vector current_schedule; + string current_element; + parser::Lexer lexer(argValue); + parser::Token tok; + parenthesesCnt = 0; + for(tok = lexer.getToken(); tok != parser::Token::eot; tok = lexer.getToken()) { + switch(tok) { + case parser::Token::lparen: + if(parenthesesCnt == 0) { + // The first opening paren separates the name of the scheduler directive from its first parameter + current_schedule.push_back(current_element); + current_element = ""; + } + else { + // pass inner parens through to the scheduler + current_element += lexer.tokenString(tok); + } + parenthesesCnt++; + break; + case parser::Token::rparen: /* the paren that closes the directive itself (depth 1) is dropped; deeper ones are kept in the parameter text */ + taco_uassert(parenthesesCnt > 0) << "mismatched parentheses (too many right-parens, negative nesting level) in schedule expression '" << argValue << "'"; + if(parenthesesCnt > 1) + current_element += lexer.tokenString(tok); + parenthesesCnt--; + break; + case parser::Token::comma: + if(parenthesesCnt == 0) { + // new schedule directive + current_schedule.push_back(current_element); + parsed.push_back(current_schedule); + current_schedule.clear(); + current_element = ""; + } else if(parenthesesCnt == 1) { + // new parameter to schedule directive + current_schedule.push_back(current_element); + current_element = ""; + } else { + // probably multiple indexes inside of an IndexExpr; pass it through + current_element += lexer.tokenString(tok); + break; /* NOTE(review): redundant, the break after this if/else chain already leaves the switch */ + } + break; + // things where .getIdentifier() makes sense + case parser::Token::identifier: + case parser::Token::int_scalar: + case parser::Token::uint_scalar: + case parser::Token::float_scalar: + case parser::Token::complex_scalar: + current_element += lexer.getIdentifier(); + break; + // .tokenstring() works for the remaining cases + default: + current_element += lexer.tokenString(tok); + break; + } + } + taco_uassert(parenthesesCnt == 0) << "imbalanced parentheses (too few right-parens) in schedule expression '" 
<< argValue << "'"; + /* flush what is left after the last token; an empty trailing element is not recorded */ if(current_element.length() > 0) + current_schedule.push_back(current_element); + if(current_schedule.size() > 0) + parsed.push_back(current_schedule); + return parsed; +} + +/* Debug rendering of a parse result: every element is wrapped in single quotes and followed by a comma and a space, and every inner list is closed the same way, so the output carries trailing separators. */ string serializeParsedSchedule(vector> parsed) { + std::stringstream ss; + ss << "[ "; + for(vector current_schedule : parsed) { + ss << "[ "; + for(string element : current_schedule) { + ss << "'" << element << "', "; + } + ss << "], "; + } + ss << "]"; + return ss.str(); +} +}} diff --git a/test/tests-schedule-parser.cpp b/test/tests-schedule-parser.cpp new file mode 100644 index 000000000..9fc6310a8 --- /dev/null +++ b/test/tests-schedule-parser.cpp @@ -0,0 +1,86 @@ +#include +#include +#include "test.h" + +using std::cout; +using std::endl; +using std::string; +using std::vector; +using namespace taco::parser; + +/* checks the sizes with ASSERT_EQ, then each pair of elements with EXPECT_EQ */ void assert_string_vectors_equal(vector a, vector b) { + ASSERT_EQ(a.size(), b.size()) << "Vectors are of unequal lengths: " << a.size() << " != " << b.size(); + for(size_t i = 0; i < a.size(); i++) { + EXPECT_EQ(a[i], b[i]) << "a[" << i << "] != b[" << i << "]: \"" << a[i] << "\" != \"" << b[i] << "\""; + } +} + +/* same check one level up: sizes first, then each inner vector via assert_string_vectors_equal */ void assert_string_vector_vectors_equal(vector> a, vector> b) { + ASSERT_EQ(a.size(), b.size()) << "Vector-vectors are of unequal lengths: " << a.size() << " != " << b.size(); + for(size_t i = 0; i < a.size(); i++) { + assert_string_vectors_equal(a[i], b[i]); + } +} + +/* table of input strings and the nested vectors ScheduleParser is expected to return for them */ TEST(schedule_parser, normal_operation) { + struct { + string str; + vector> result; + } cases[] = { + // basic parsing + { "i,j,k", { { "i" }, { "j" }, { "k" } } }, + { "i(j,k)", { { "i", "j", "k" } } }, + { "i(j,k),l(m,n)", { { "i", "j", "k" }, { "l", "m", "n" } } }, + { "i(j,k),l(m(n,o),p)", { { "i", "j", "k" }, { "l", "m(n,o)", "p" } } }, + { "i(j,k),l(m(n(o(p))),q)", { { "i", "j", "k" }, { "l", "m(n(o(p)))", "q" } } }, + + // whitespace + { "i,j, k", { { "i" }, { "j" }, { "k" } } }, + { "i(j, k)", { { "i", "j", "k" } } }, + { "i(j,k), l(m,n)", { { "i", "j", "k" }, { "l", "m", "n" 
} } }, + { "i(j,k),l(m(n, o),p)", { { "i", "j", "k" }, { "l", "m(n,o)", "p" } } }, + { "i(j,k),l(m(n(o(p))), q)", { { "i", "j", "k" }, { "l", "m(n(o(p)))", "q" } } }, + + // empty slots + { "", { } }, + { ",j,k", { { "" }, { "j" }, { "k" } } }, + { "i(,k)", { { "i", "", "k" } } }, + { "(j,k)", { { "", "j", "k" } } }, + { "i(j,),,l(m,n)", { { "i", "j", "" }, { "" }, { "l", "m", "n" } } }, + + // real scheduling directives + { "split(i,i0,i1,16)", { { "split", "i", "i0", "i1", "16" } } }, + { "precompute(A(i,j)*x(j),i,i)", { { "precompute", "A(i,j)*x(j)", "i", "i" } } }, + { "split(i,i0,i1,16),precompute(A(i,j)*x(j),i,i)", + { { "split", "i", "i0", "i1", "16" }, + { "precompute", "A(i,j)*x(j)", "i", "i" } } }, + }; + /* parse each input, print the serialized result, then compare it with the expected nested vectors */ for(auto test : cases) { + auto actual = ScheduleParser(test.str); + cout << "string \"" << test.str << "\"" << " parsed as: " << serializeParsedSchedule(actual) << endl; + assert_string_vector_vectors_equal(test.result, actual); + } +} + +/* every input here is expected to make ScheduleParser throw a taco::TacoException whose message contains the listed substring */ TEST(schedule_parser, error_reporting) { + struct { + string str; + string assertion; + } cases[] = { + { "i,j,k(", "too few right-parens" }, + { "i(j,k", "too few right-parens" }, + { "i,j,k)", "too many right-parens" }, + { "i,j,k)(", "too many right-parens" }, + }; + for(auto test : cases) { + try { + auto actual = ScheduleParser(test.str); + // should throw an exception before getting here + ASSERT_TRUE(false); + } catch (taco::TacoException &e) { + string message = e.what(); + EXPECT_TRUE(message.find(test.assertion) != string::npos) + << "substring \"" << test.assertion << "\" not found in exception message \"" << message << "\""; + } + } +} diff --git a/tools/taco.cpp b/tools/taco.cpp index f99fe1bbe..fcc654e08 100644 --- a/tools/taco.cpp +++ b/tools/taco.cpp @@ -9,7 +9,9 @@ #include "taco.h" #include "taco/error.h" +#include "taco/parser/lexer.h" #include "taco/parser/parser.h" +#include "taco/parser/schedule_parser.h" #include "taco/storage/storage.h" #include "taco/ir/ir.h" #include "taco/ir/ir_printer.h" @@ 
-210,31 +212,31 @@ static void printCommandLine(ostream& os, int argc, char* argv[]) { } } -static bool setSchedulingCommands(istream& in, parser::Parser& parser, IndexStmt& stmt) { +static bool setSchedulingCommands(vector> scheduleCommands, parser::Parser& parser, IndexStmt& stmt) { auto findVar = [&stmt](string name) { - ProvenanceGraph graph(stmt); + ProvenanceGraph graph(stmt); for (auto v : graph.getAllIndexVars()) { if (v.getName() == name) { return v; } } - throw "Index variable not defined in statement."; + taco_uassert(0) << "Index variable '" << name << "' not defined in statement " << stmt; + abort(); // to silence a warning: control reaches end of non-void function }; - bool isGPU = false; + bool isGPU = false; - while (true) { - string command; - in >> command; + for(vector scheduleCommand : scheduleCommands) { + string command = scheduleCommand[0]; + scheduleCommand.erase(scheduleCommand.begin()); if (command == "pos") { - string i, ipos; - in >> i; - in >> ipos; - - string tensor; - in >> tensor; + taco_uassert(scheduleCommand.size() == 3) << "'pos' scheduling directive takes 3 parameters: pos(i, ipos, tensor)"; + string i, ipos, tensor; + i = scheduleCommand[0]; + ipos = scheduleCommand[1]; + tensor = scheduleCommand[2]; for (auto a : getArgumentAccesses(stmt)) { if (a.getTensorVar().getName() == tensor) { @@ -245,77 +247,79 @@ static bool setSchedulingCommands(istream& in, parser::Parser& parser, IndexStmt } } else if (command == "fuse") { - string i, j, f; - in >> i; - in >> j; - in >> f; + taco_uassert(scheduleCommand.size() == 3) << "'fuse' scheduling directive takes 3 parameters: fuse(i, j, f)"; + string i, j, f; + i = scheduleCommand[0]; + j = scheduleCommand[1]; + f = scheduleCommand[2]; - IndexVar fused(f); - stmt = stmt.fuse(findVar(i), findVar(j), fused); + IndexVar fused(f); + stmt = stmt.fuse(findVar(i), findVar(j), fused); } else if (command == "split") { - string i, i1, i2; - in >> i; - in >> i1; - in >> i2; - - size_t splitFactor; - 
in >> splitFactor; + taco_uassert(scheduleCommand.size() == 4) << "'split' scheduling directive takes 4 parameters: split(i, i1, i2, splitFactor)"; + string i, i1, i2; + size_t splitFactor; + i = scheduleCommand[0]; + i1 = scheduleCommand[1]; + i2 = scheduleCommand[2]; + taco_uassert(sscanf(scheduleCommand[3].c_str(), "%zu", &splitFactor) == 1) << "failed to parse fourth parameter to `split` directive as a size_t"; IndexVar split1(i1); IndexVar split2(i2); stmt = stmt.split(findVar(i), split1, split2, splitFactor); // } else if (command == "divide") { - // string i, i1, i2; - // in >> i; - // in >> i1; - // in >> i2; + // string i, i1, i2; + // in >> i; + // in >> i1; + // in >> i2; - // size_t divideFactor; - // in >> divideFactor; + // size_t divideFactor; + // in >> divideFactor; // IndexVar divide1(i1); // IndexVar divide2(i2); // stmt = stmt.divide(findVar(i), divide1, divide2, divideFactor); } else if (command == "precompute") { - string exprStr, i, iw; - in >> exprStr; - in >> i; - in >> iw; + string exprStr, i, iw; + taco_uassert(scheduleCommand.size() == 3) << "'precompute' scheduling directive takes 3 parameters: precompute(expr, i, iw)"; + exprStr = scheduleCommand[0]; + i = scheduleCommand[1]; + iw = scheduleCommand[2]; IndexVar orig = findVar(i); - IndexVar pre; + IndexVar pre; try { pre = findVar(iw); - } catch (const char* e) { + } catch (TacoException &e) { pre = IndexVar(iw); } struct GetExpr : public IndexNotationVisitor { using IndexNotationVisitor::visit; - - string exprStr; - IndexExpr expr; + + string exprStr; + IndexExpr expr; void setExprStr(string input) { - exprStr = input; - exprStr.erase(remove(exprStr.begin(), exprStr.end(), ' '), exprStr.end()); + exprStr = input; + exprStr.erase(remove(exprStr.begin(), exprStr.end(), ' '), exprStr.end()); } string toString(IndexExpr e) { - stringstream tempStream; - tempStream << e; + stringstream tempStream; + tempStream << e; string tempStr = tempStream.str(); tempStr.erase(remove(tempStr.begin(), 
tempStr.end(), ' '), tempStr.end()); return tempStr; } - + void visit(const AccessNode* node) { - IndexExpr currentExpr(node); + IndexExpr currentExpr(node); if (toString(currentExpr) == exprStr) { - expr = currentExpr; + expr = currentExpr; } else { IndexNotationVisitor::visit(node); @@ -323,9 +327,9 @@ static bool setSchedulingCommands(istream& in, parser::Parser& parser, IndexStmt } void visit(const UnaryExprNode* node) { - IndexExpr currentExpr(node); + IndexExpr currentExpr(node); if (toString(currentExpr) == exprStr) { - expr = currentExpr; + expr = currentExpr; } else { IndexNotationVisitor::visit(node); @@ -335,7 +339,7 @@ static bool setSchedulingCommands(istream& in, parser::Parser& parser, IndexStmt void visit(const BinaryExprNode* node) { IndexExpr currentExpr(node); if (toString(currentExpr) == exprStr) { - expr = currentExpr; + expr = currentExpr; } else { IndexNotationVisitor::visit(node); @@ -344,14 +348,14 @@ static bool setSchedulingCommands(istream& in, parser::Parser& parser, IndexStmt }; GetExpr visitor; - visitor.setExprStr(exprStr); + visitor.setExprStr(exprStr); stmt.accept(&visitor); - Dimension dim; + Dimension dim; auto domains = stmt.getIndexVarDomains(); auto it = domains.find(orig); if (it != domains.end()) { - dim = it->second; + dim = it->second; } else { dim = Dimension(orig); } @@ -360,109 +364,105 @@ static bool setSchedulingCommands(istream& in, parser::Parser& parser, IndexStmt stmt = stmt.precompute(visitor.expr, orig, pre, workspace); } else if (command == "reorder") { - string line; - getline(in, line); - stringstream temp; - temp << line; - - vector reorderedVars; - string var; - while (temp >> var) { + taco_uassert(scheduleCommand.size() > 1) << "'reorder' scheduling directive needs at least 2 parameters: reorder(outermost, ..., innermost)"; + + vector reorderedVars; + for (string var : scheduleCommand) { reorderedVars.push_back(findVar(var)); } stmt = stmt.reorder(reorderedVars); } else if (command == "bound") { - string 
i, i1; - in >> i; - in >> i1; - + /* bound(i, i1, bound, type): scheduleCommand[0] through [3] are all read below, so exactly 4 parameters are required */ taco_uassert(scheduleCommand.size() == 4) << "'bound' scheduling directive takes 4 parameters: bound(i, i1, bound, type)"; + string i, i1, type; size_t bound; - in >> bound; - - string type; - in >> type; - - BoundType bound_type; - if (type == "MinExact") { - bound_type = BoundType::MinExact; - } else if (type == "MinConstraint") { - bound_type = BoundType::MinConstraint; + i = scheduleCommand[0]; + i1 = scheduleCommand[1]; + taco_uassert(sscanf(scheduleCommand[2].c_str(), "%zu", &bound) == 1) << "failed to parse third parameter to `bound` directive as a size_t"; + type = scheduleCommand[3]; + + BoundType bound_type; + if (type == "MinExact") { + bound_type = BoundType::MinExact; + } else if (type == "MinConstraint") { + bound_type = BoundType::MinConstraint; } else if (type == "MaxExact") { - bound_type = BoundType::MaxExact; + bound_type = BoundType::MaxExact; } else if (type == "MaxConstraint") { - bound_type = BoundType::MaxConstraint; + bound_type = BoundType::MaxConstraint; } else { taco_uerror << "Bound type not defined."; - goto end; + goto end; } IndexVar bound1(i1); stmt = stmt.bound(findVar(i), bound1, bound, bound_type); } else if (command == "unroll") { - string i; - in >> i; - - size_t unrollFactor; - in >> unrollFactor; + taco_uassert(scheduleCommand.size() == 2) << "'unroll' scheduling directive takes 2 parameters: unroll(i, unrollFactor)"; + string i; + size_t unrollFactor; + i = scheduleCommand[0]; + taco_uassert(sscanf(scheduleCommand[1].c_str(), "%zu", &unrollFactor) == 1) << "failed to parse second parameter to `unroll` directive as a size_t"; stmt = stmt.unroll(findVar(i), unrollFactor); - + } else if (command == "parallelize") { - string i, unit, strategy; - in >> i; - in >> unit; - in >> strategy; - - ParallelUnit parallel_unit; - if (unit == "NotParallel") { - parallel_unit = ParallelUnit::NotParallel; + string i, unit, strategy; + taco_uassert(scheduleCommand.size() == 3) << "'parallelize' scheduling 
directive takes 3 parameters: parallelize(i, unit, strategy)"; + i = scheduleCommand[0]; + unit = scheduleCommand[1]; + strategy = scheduleCommand[2]; + + ParallelUnit parallel_unit; + if (unit == "NotParallel") { + parallel_unit = ParallelUnit::NotParallel; } else if (unit == "GPUBlock") { parallel_unit = ParallelUnit::GPUBlock; - isGPU = true; + isGPU = true; } else if (unit == "GPUWarp") { parallel_unit = ParallelUnit::GPUWarp; - isGPU = true; + isGPU = true; } else if (unit == "GPUThread") { parallel_unit = ParallelUnit::GPUThread; - isGPU = true; + isGPU = true; } else if (unit == "CPUThread") { - parallel_unit = ParallelUnit::CPUThread; + parallel_unit = ParallelUnit::CPUThread; } else if (unit == "CPUVector") { parallel_unit = ParallelUnit::CPUVector; } else { taco_uerror << "Parallel hardware not defined."; - goto end; + goto end; } - OutputRaceStrategy output_race_strategy; + OutputRaceStrategy output_race_strategy; if (strategy == "IgnoreRaces") { - output_race_strategy = OutputRaceStrategy::IgnoreRaces; + output_race_strategy = OutputRaceStrategy::IgnoreRaces; } else if (strategy == "NoRaces") { - output_race_strategy = OutputRaceStrategy::NoRaces; - } else if (strategy == "Atomics") { - output_race_strategy = OutputRaceStrategy::Atomics; + output_race_strategy = OutputRaceStrategy::NoRaces; + } else if (strategy == "Atomics") { + output_race_strategy = OutputRaceStrategy::Atomics; } else if (strategy == "Temporary") { output_race_strategy = OutputRaceStrategy::Temporary; } else if (strategy == "ParallelReduction") { output_race_strategy = OutputRaceStrategy::ParallelReduction; - } else { - taco_uerror << "Race strategy not defined."; - goto end; + } else { + taco_uerror << "Race strategy not defined."; + goto end; } stmt = stmt.parallelize(findVar(i), parallel_unit, output_race_strategy); } else { - break; + taco_uerror << "Unknown scheduling function \"" << command << "\""; + break; } - end:; + end:; } - return isGPU; + return isGPU; } int main(int 
argc, char* argv[]) { @@ -492,7 +492,7 @@ int main(int argc, char* argv[]) { bool readKernels = false; bool cuda = false; - bool setSchedule = false; + bool setSchedule = false; ParallelSchedule sched = ParallelSchedule::Static; int chunkSize = 0; @@ -501,7 +501,7 @@ int main(int argc, char* argv[]) { taco::util::TimeResults compileTime; taco::util::TimeResults assembleTime; - + int repeat = 1; taco::util::TimeResults timevalue; @@ -523,7 +523,7 @@ int main(int argc, char* argv[]) { vector kernelFilenames; - vector scheduleCommands; + vector> scheduleCommands; for (int i = 1; i < argc; i++) { string arg = argv[i]; @@ -812,25 +812,13 @@ int main(int argc, char* argv[]) { else if ("-print-kernels" == argName) { printKernels = true; } - else if ("-s" == argName) { - setSchedule = true; - int parenthesesCnt = 0; + else if ("-s" == argName) { + setSchedule = true; + vector> parsed = parser::ScheduleParser(argValue); - std::replace_if(argValue.begin(), argValue.end(), [&parenthesesCnt](char c) { - if (c == '(') { - if (parenthesesCnt++ == 0) { // '(' for a call - return true; - } - } else if (c == ',') { - return parenthesesCnt <= 1; - } else if (c == ')') { - if (--parenthesesCnt == 0) { // ')' for a call - return true; - } - } - return false; - }, ' '); - scheduleCommands.push_back(argValue); + taco_uassert(parsed.size() > 0) << "-s parameter got no scheduling directives?"; + for(vector directive : parsed) + scheduleCommands.push_back(directive); } else if ("-prefix" == argName) { prefix = argValue; @@ -858,7 +846,7 @@ int main(int argc, char* argv[]) { for (auto& tensorNames : inputFilenames) { string name = tensorNames.first; string filename = tensorNames.second; - + if (util::contains(dataTypes, name) && dataTypes.at(name) != Float64) { return reportError("Loaded tensors can only be type double", 7); } @@ -927,12 +915,7 @@ int main(int argc, char* argv[]) { stmt = reorderLoopsTopologically(stmt); if (setSchedule) { - stringstream scheduleStream; - for (string 
command : scheduleCommands) { - scheduleStream << command << endl; - } - - cuda |= setSchedulingCommands(scheduleStream, parser, stmt); + cuda |= setSchedulingCommands(scheduleCommands, parser, stmt); } else { stmt = insertTemporaries(stmt); @@ -970,7 +953,7 @@ int main(int argc, char* argv[]) { module->addFunction(evaluate); module->compile(); , "Compile: ", compileTime); - + void* compute = module->getFuncPtr(prefix+"compute"); void* assemble = module->getFuncPtr(prefix+"assemble"); void* evaluate = module->getFuncPtr(prefix+"evaluate"); @@ -1041,7 +1024,7 @@ int main(int argc, char* argv[]) { evaluate = lower(stmt, prefix+"evaluate", true, true); } - string packComment = + string packComment = "/*\n" " * The `pack` functions convert coordinate and value arrays in COO format,\n" " * with nonzeros sorted lexicographically by their coordinates, to the\n" @@ -1052,9 +1035,9 @@ int main(int argc, char* argv[]) { " *\n" " * For both, the `_COO_pos` arrays contain two elements, where the first is 0\n" " * and the second is the number of nonzeros in the tensor.\n" - " */"; - - vector packs; + " */"; + + vector packs; for (auto a : getArgumentAccesses(stmt)) { TensorVar tensor = a.getTensorVar(); if (tensor.getOrder() == 0) { @@ -1152,7 +1135,7 @@ int main(int argc, char* argv[]) { if (unpack.defined()) { cout << endl << packComment << endl; } - + for (auto pack : packs) { codegen->compile(pack, false); cout << endl << endl; @@ -1186,7 +1169,7 @@ int main(int argc, char* argv[]) { << "," << timevalue.stdev << "," << timevalue.median << endl; filestream.close(); } - + if (writeCompute) { std::ofstream filestream; filestream.open(writeComputeFilename, @@ -1221,7 +1204,7 @@ int main(int argc, char* argv[]) { std::shared_ptr codegenFile = ir::CodeGen::init_default(filestream, ir::CodeGen::ImplementationGen); bool hasPrinted = false; - + if (compute.defined() ) { codegenFile->compile(compute, !hasPrinted); hasPrinted = true;