Skip to content

Commit 859cfd2

Browse files
Merge pull request #352 from Infinoid/schedule-parser
Fix command-line schedule parsing
2 parents cb4731d + a79f5ee commit 859cfd2

File tree

4 files changed

+337
-145
lines changed

4 files changed

+337
-145
lines changed

include/taco/parser/schedule_parser.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#ifndef TACO_SCHEDULE_PARSER_H
#define TACO_SCHEDULE_PARSER_H

#include <string>
#include <vector>

namespace taco {
namespace parser {

/// Parse a command-line schedule string into its directives.
///
/// A string of the form
///   "reorder(i,j),precompute(D(i,j)*E(j,k),j,j_pre)"
/// is parsed into string vectors of the form
///   [ [ "reorder", "i", "j" ], [ "precompute", "D(i,j)*E(j,k)", "j", "j_pre" ] ]
///
/// @param argValue  the raw schedule string
/// @return one inner vector per directive; its first element is the directive
///         name and the remaining elements are the directive's parameters
std::vector<std::vector<std::string>> ScheduleParser(const std::string argValue);

/// Serialize the result of a parse into a single printable string
/// (for debugging).
///
/// @param parsed  a value as returned by ScheduleParser()
/// @return a textual rendering of the nested vectors
std::string serializeParsedSchedule(std::vector<std::vector<std::string>> parsed);

}}

#endif //TACO_SCHEDULE_PARSER_H

src/parser/schedule_parser.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#include <string>
2+
#include <vector>
3+
#include <iostream>
4+
5+
#include "taco/parser/lexer.h"
6+
#include "taco/parser/schedule_parser.h"
7+
#include "taco/error.h"
8+
9+
using std::vector;
10+
using std::string;
11+
using std::cout;
12+
using std::endl;
13+
14+
namespace taco{
15+
namespace parser{
16+
17+
/// Parses command line schedule directives (`-s <directive>`).
18+
/// Example: "precompute(B(i,j),j,jpre),reorder(j,i)" is parsed as:
19+
/// [ [ "precompute", "B(i,j)", "j", "jpre" ],
20+
/// [ "reorder", "j", "i" ] ]
21+
/// The first element of each inner vector is the function name.
22+
/// Inner parens are preserved. All whitespace is removed.
23+
vector<vector<string>> ScheduleParser(const string argValue) {
24+
int parenthesesCnt;
25+
vector<vector<string>> parsed;
26+
vector<string> current_schedule;
27+
string current_element;
28+
parser::Lexer lexer(argValue);
29+
parser::Token tok;
30+
parenthesesCnt = 0;
31+
for(tok = lexer.getToken(); tok != parser::Token::eot; tok = lexer.getToken()) {
32+
switch(tok) {
33+
case parser::Token::lparen:
34+
if(parenthesesCnt == 0) {
35+
// The first opening paren separates the name of the scheduler directive from its first parameter
36+
current_schedule.push_back(current_element);
37+
current_element = "";
38+
}
39+
else {
40+
// pass inner parens through to the scheduler
41+
current_element += lexer.tokenString(tok);
42+
}
43+
parenthesesCnt++;
44+
break;
45+
case parser::Token::rparen:
46+
taco_uassert(parenthesesCnt > 0) << "mismatched parentheses (too many right-parens, negative nesting level) in schedule expression '" << argValue << "'";
47+
if(parenthesesCnt > 1)
48+
current_element += lexer.tokenString(tok);
49+
parenthesesCnt--;
50+
break;
51+
case parser::Token::comma:
52+
if(parenthesesCnt == 0) {
53+
// new schedule directive
54+
current_schedule.push_back(current_element);
55+
parsed.push_back(current_schedule);
56+
current_schedule.clear();
57+
current_element = "";
58+
} else if(parenthesesCnt == 1) {
59+
// new parameter to schedule directive
60+
current_schedule.push_back(current_element);
61+
current_element = "";
62+
} else {
63+
// probably multiple indexes inside of an IndexExpr; pass it through
64+
current_element += lexer.tokenString(tok);
65+
break;
66+
}
67+
break;
68+
// things where .getIdentifier() makes sense
69+
case parser::Token::identifier:
70+
case parser::Token::int_scalar:
71+
case parser::Token::uint_scalar:
72+
case parser::Token::float_scalar:
73+
case parser::Token::complex_scalar:
74+
current_element += lexer.getIdentifier();
75+
break;
76+
// .tokenstring() works for the remaining cases
77+
default:
78+
current_element += lexer.tokenString(tok);
79+
break;
80+
}
81+
}
82+
taco_uassert(parenthesesCnt == 0) << "imbalanced parentheses (too few right-parens) in schedule expression '" << argValue << "'";
83+
if(current_element.length() > 0)
84+
current_schedule.push_back(current_element);
85+
if(current_schedule.size() > 0)
86+
parsed.push_back(current_schedule);
87+
return parsed;
88+
}
89+
90+
/// Serializes the result of ScheduleParser() into one printable string
/// (intended for debugging output).
///
/// Format: "[ " + one "[ 'elem', 'elem', ], " group per directive + "]".
/// Every element and every directive is followed by ", ", including the last
/// one, so [ [ "reorder", "i", "j" ] ] becomes "[ [ 'reorder', 'i', 'j', ], ]"
/// and an empty parse becomes "[ ]".
///
/// @param parsed  nested string vectors, one inner vector per directive
/// @return the textual rendering described above
std::string serializeParsedSchedule(std::vector<std::vector<std::string>> parsed) {
  // Built by appending to a std::string so that only <string> and <vector>
  // are needed; elements are visited by const reference to avoid copies.
  std::string out = "[ ";
  for(const std::vector<std::string>& schedule : parsed) {
    out += "[ ";
    for(const std::string& element : schedule) {
      out += "'";
      out += element;
      out += "', ";
    }
    out += "], ";
  }
  out += "]";
  return out;
}
103+
}}

test/tests-schedule-parser.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#include <iostream>
2+
#include <taco/parser/schedule_parser.h>
3+
#include "test.h"
4+
5+
using std::cout;
6+
using std::endl;
7+
using std::string;
8+
using std::vector;
9+
using namespace taco::parser;
10+
11+
void assert_string_vectors_equal(vector<string> a, vector<string> b) {
12+
ASSERT_EQ(a.size(), b.size()) << "Vectors are of unequal lengths: " << a.size() << " != " << b.size();
13+
for(size_t i = 0; i < a.size(); i++) {
14+
EXPECT_EQ(a[i], b[i]) << "a[" << i << "] != b[" << i << "]: \"" << a[i] << "\" != \"" << b[i] << "\"";
15+
}
16+
}
17+
18+
void assert_string_vector_vectors_equal(vector<vector<string>> a, vector<vector<string>> b) {
19+
ASSERT_EQ(a.size(), b.size()) << "Vector-vectors are of unequal lengths: " << a.size() << " != " << b.size();
20+
for(size_t i = 0; i < a.size(); i++) {
21+
assert_string_vectors_equal(a[i], b[i]);
22+
}
23+
}
24+
25+
// Table-driven check of ScheduleParser() on well-formed input: each case
// pairs an input string with the exact nested-vector result expected.
TEST(schedule_parser, normal_operation) {
  struct {
    string str;
    vector<vector<string>> result;
  } cases[] = {
    // basic parsing
    { "i,j,k",                   { { "i" }, { "j" }, { "k" } } },
    { "i(j,k)",                  { { "i", "j", "k" } } },
    { "i(j,k),l(m,n)",           { { "i", "j", "k" }, { "l", "m", "n" } } },
    { "i(j,k),l(m(n,o),p)",      { { "i", "j", "k" }, { "l", "m(n,o)", "p" } } },
    { "i(j,k),l(m(n(o(p))),q)",  { { "i", "j", "k" }, { "l", "m(n(o(p)))", "q" } } },

    // whitespace
    { "i,j, k",                  { { "i" }, { "j" }, { "k" } } },
    { "i(j, k)",                 { { "i", "j", "k" } } },
    { "i(j,k), l(m,n)",          { { "i", "j", "k" }, { "l", "m", "n" } } },
    { "i(j,k),l(m(n, o),p)",     { { "i", "j", "k" }, { "l", "m(n,o)", "p" } } },
    { "i(j,k),l(m(n(o(p))), q)", { { "i", "j", "k" }, { "l", "m(n(o(p)))", "q" } } },

    // empty slots
    { "",                        { } },
    { ",j,k",                    { { "" }, { "j" }, { "k" } } },
    { "i(,k)",                   { { "i", "", "k" } } },
    { "(j,k)",                   { { "", "j", "k" } } },
    { "i(j,),,l(m,n)",           { { "i", "j", "" }, { "" }, { "l", "m", "n" } } },

    // real scheduling directives
    { "split(i,i0,i1,16)",           { { "split", "i", "i0", "i1", "16" } } },
    { "precompute(A(i,j)*x(j),i,i)", { { "precompute", "A(i,j)*x(j)", "i", "i" } } },
    { "split(i,i0,i1,16),precompute(A(i,j)*x(j),i,i)",
      { { "split", "i", "i0", "i1", "16" },
        { "precompute", "A(i,j)*x(j)", "i", "i" } } },
  };
  // Iterate by const reference: each case owns a string and nested vectors.
  for(const auto& test : cases) {
    // Attach the input string to any failure raised while checking this case,
    // since the comparison helpers only report indices and values.
    SCOPED_TRACE("schedule string: \"" + test.str + "\"");
    auto actual = ScheduleParser(test.str);
    cout << "string \"" << test.str << "\"" << " parsed as: " << serializeParsedSchedule(actual) << endl;
    assert_string_vector_vectors_equal(test.result, actual);
  }
}
64+
65+
// Table-driven check of ScheduleParser() on malformed input: each case pairs
// an input with unbalanced parentheses with a substring that must appear in
// the message of the taco::TacoException the parser raises.
TEST(schedule_parser, error_reporting) {
  struct {
    string str;
    string assertion;
  } cases[] = {
    { "i,j,k(",  "too few right-parens"  },
    { "i(j,k",   "too few right-parens"  },
    { "i,j,k)",  "too many right-parens" },
    { "i,j,k)(", "too many right-parens" },
  };
  for(const auto& test : cases) {
    try {
      ScheduleParser(test.str);
      // Should have thrown before getting here. Record a non-fatal failure
      // naming the input so the remaining cases are still exercised.
      ADD_FAILURE() << "no exception thrown for schedule string \"" << test.str
                    << "\" (expected message containing \"" << test.assertion << "\")";
    } catch (const taco::TacoException &e) {
      string message = e.what();
      EXPECT_TRUE(message.find(test.assertion) != string::npos)
        << "substring \"" << test.assertion << "\" not found in exception message \"" << message << "\"";
    }
  }
}

0 commit comments

Comments
 (0)