|
| 1 | +#include <iostream> |
| 2 | +#include <taco/parser/schedule_parser.h> |
| 3 | +#include "test.h" |
| 4 | + |
| 5 | +using std::cout; |
| 6 | +using std::endl; |
| 7 | +using std::string; |
| 8 | +using std::vector; |
| 9 | +using namespace taco::parser; |
| 10 | + |
| 11 | +void assert_string_vectors_equal(vector<string> a, vector<string> b) { |
| 12 | + ASSERT_EQ(a.size(), b.size()) << "Vectors are of unequal lengths: " << a.size() << " != " << b.size(); |
| 13 | + for(size_t i = 0; i < a.size(); i++) { |
| 14 | + EXPECT_EQ(a[i], b[i]) << "a[" << i << "] != b[" << i << "]: \"" << a[i] << "\" != \"" << b[i] << "\""; |
| 15 | + } |
| 16 | +} |
| 17 | + |
| 18 | +void assert_string_vector_vectors_equal(vector<vector<string>> a, vector<vector<string>> b) { |
| 19 | + ASSERT_EQ(a.size(), b.size()) << "Vector-vectors are of unequal lengths: " << a.size() << " != " << b.size(); |
| 20 | + for(size_t i = 0; i < a.size(); i++) { |
| 21 | + assert_string_vectors_equal(a[i], b[i]); |
| 22 | + } |
| 23 | +} |
| 24 | + |
| 25 | +TEST(schedule_parser, normal_operation) { |
| 26 | + struct { |
| 27 | + string str; |
| 28 | + vector<vector<string>> result; |
| 29 | + } cases[] = { |
| 30 | + // basic parsing |
| 31 | + { "i,j,k", { { "i" }, { "j" }, { "k" } } }, |
| 32 | + { "i(j,k)", { { "i", "j", "k" } } }, |
| 33 | + { "i(j,k),l(m,n)", { { "i", "j", "k" }, { "l", "m", "n" } } }, |
| 34 | + { "i(j,k),l(m(n,o),p)", { { "i", "j", "k" }, { "l", "m(n,o)", "p" } } }, |
| 35 | + { "i(j,k),l(m(n(o(p))),q)", { { "i", "j", "k" }, { "l", "m(n(o(p)))", "q" } } }, |
| 36 | + |
| 37 | + // whitespace |
| 38 | + { "i,j, k", { { "i" }, { "j" }, { "k" } } }, |
| 39 | + { "i(j, k)", { { "i", "j", "k" } } }, |
| 40 | + { "i(j,k), l(m,n)", { { "i", "j", "k" }, { "l", "m", "n" } } }, |
| 41 | + { "i(j,k),l(m(n, o),p)", { { "i", "j", "k" }, { "l", "m(n,o)", "p" } } }, |
| 42 | + { "i(j,k),l(m(n(o(p))), q)", { { "i", "j", "k" }, { "l", "m(n(o(p)))", "q" } } }, |
| 43 | + |
| 44 | + // empty slots |
| 45 | + { "", { } }, |
| 46 | + { ",j,k", { { "" }, { "j" }, { "k" } } }, |
| 47 | + { "i(,k)", { { "i", "", "k" } } }, |
| 48 | + { "(j,k)", { { "", "j", "k" } } }, |
| 49 | + { "i(j,),,l(m,n)", { { "i", "j", "" }, { "" }, { "l", "m", "n" } } }, |
| 50 | + |
| 51 | + // real scheduling directives |
| 52 | + { "split(i,i0,i1,16)", { { "split", "i", "i0", "i1", "16" } } }, |
| 53 | + { "precompute(A(i,j)*x(j),i,i)", { { "precompute", "A(i,j)*x(j)", "i", "i" } } }, |
| 54 | + { "split(i,i0,i1,16),precompute(A(i,j)*x(j),i,i)", |
| 55 | + { { "split", "i", "i0", "i1", "16" }, |
| 56 | + { "precompute", "A(i,j)*x(j)", "i", "i" } } }, |
| 57 | + }; |
| 58 | + for(auto test : cases) { |
| 59 | + auto actual = ScheduleParser(test.str); |
| 60 | + cout << "string \"" << test.str << "\"" << " parsed as: " << serializeParsedSchedule(actual) << endl; |
| 61 | + assert_string_vector_vectors_equal(test.result, actual); |
| 62 | + } |
| 63 | +} |
| 64 | + |
| 65 | +TEST(schedule_parser, error_reporting) { |
| 66 | + struct { |
| 67 | + string str; |
| 68 | + string assertion; |
| 69 | + } cases[] = { |
| 70 | + { "i,j,k(", "too few right-parens" }, |
| 71 | + { "i(j,k", "too few right-parens" }, |
| 72 | + { "i,j,k)", "too many right-parens" }, |
| 73 | + { "i,j,k)(", "too many right-parens" }, |
| 74 | + }; |
| 75 | + for(auto test : cases) { |
| 76 | + try { |
| 77 | + auto actual = ScheduleParser(test.str); |
| 78 | + // should throw an exception before getting here |
| 79 | + ASSERT_TRUE(false); |
| 80 | + } catch (taco::TacoException &e) { |
| 81 | + string message = e.what(); |
| 82 | + EXPECT_TRUE(message.find(test.assertion) != string::npos) |
| 83 | + << "substring \"" << test.assertion << "\" not found in exception message \"" << message << "\""; |
| 84 | + } |
| 85 | + } |
| 86 | +} |
0 commit comments