Skip to content

Commit 9ac6b24

Browse files
authored
Order generators with respect to their dependencies (#2458)
fix #2416
1 parent f766700 commit 9ac6b24

File tree

15 files changed

+185
-10
lines changed

15 files changed

+185
-10
lines changed

src/ast2ram/seminaive/ClauseTranslator.cpp

Lines changed: 86 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -774,9 +774,6 @@ void ClauseTranslator::indexAggregatorBody(const ast::Aggregator& agg) {
774774
}
775775

776776
void ClauseTranslator::indexAggregators(const ast::Clause& clause) {
777-
// Add each aggregator as an internal generator
778-
visit(clause, [&](const ast::Aggregator& agg) { indexGenerator(agg); });
779-
780777
// Index aggregator bodies
781778
visit(clause, [&](const ast::Aggregator& agg) { indexAggregatorBody(agg); });
782779

@@ -791,13 +788,6 @@ void ClauseTranslator::indexAggregators(const ast::Clause& clause) {
791788
}
792789

793790
void ClauseTranslator::indexMultiResultFunctors(const ast::Clause& clause) {
794-
// Add each multi-result functor as an internal generator
795-
visit(clause, [&](const ast::IntrinsicFunctor& func) {
796-
if (ast::analysis::FunctorAnalysis::isMultiResult(func)) {
797-
indexGenerator(func);
798-
}
799-
});
800-
801791
// Add multi-result functor value introductions
802792
visit(clause, [&](const ast::BinaryConstraint& bc) {
803793
if (!isEqConstraint(bc.getBaseOperator())) return;
@@ -809,8 +799,94 @@ void ClauseTranslator::indexMultiResultFunctors(const ast::Clause& clause) {
809799
});
810800
}
811801

802+
void ClauseTranslator::indexGenerators(const ast::Clause& clause) {
803+
// generators must be indexed in topological order according to
804+
// dependencies between them.
805+
// see issue #2416
806+
807+
std::map<std::string, const ast::Argument*> varGenerator;
808+
809+
visit(clause, [&](const ast::BinaryConstraint& bc) {
810+
if (!isEqConstraint(bc.getBaseOperator())) return;
811+
const auto* lhs = as<ast::Variable>(bc.getLHS());
812+
const ast::Argument* rhs = as<ast::IntrinsicFunctor>(bc.getRHS());
813+
if (rhs == nullptr) {
814+
rhs = as<ast::Aggregator>(bc.getRHS());
815+
} else {
816+
if (!ast::analysis::FunctorAnalysis::isMultiResult(*as<ast::IntrinsicFunctor>(rhs))) return;
817+
}
818+
if (lhs == nullptr || rhs == nullptr) return;
819+
varGenerator[lhs->getName()] = rhs;
820+
});
821+
822+
// all the generators in the clause
823+
std::vector<const ast::Argument*> generators;
824+
825+
// 'predecessor' mapping from a generator to the generators that must
826+
// evaluate before.
827+
std::multimap<const ast::Argument*, const ast::Argument*> dependencies;
828+
829+
// harvest generators and dependencies
830+
visit(clause, [&](const ast::Argument& arg) {
831+
if (const ast::IntrinsicFunctor* func = as<ast::IntrinsicFunctor>(arg)) {
832+
if (ast::analysis::FunctorAnalysis::isMultiResult(*func)) {
833+
generators.emplace_back(func);
834+
visit(func, [&](const ast::Variable& use) {
835+
if (varGenerator.count(use.getName()) > 0) {
836+
dependencies.emplace(func, varGenerator.at(use.getName()));
837+
}
838+
});
839+
}
840+
} else if (const ast::Aggregator* agg = as<ast::Aggregator>(arg)) {
841+
generators.emplace_back(agg);
842+
visit(agg, [&](const ast::Variable& use) {
843+
if (varGenerator.count(use.getName()) > 0) {
844+
dependencies.emplace(agg, varGenerator.at(use.getName()));
845+
}
846+
});
847+
}
848+
});
849+
850+
// the set of already indexed generators
851+
std::set<const ast::Argument*> indexed;
852+
// the recursion stack to detect a cycle in the depth-first traversal
853+
std::set<const ast::Argument*> recStack;
854+
855+
// recursive depth-first traversal, perform a post-order indexing of genertors.
856+
const std::function<void(const ast::Argument*)> dfs = [&](const ast::Argument* reached) {
857+
if (indexed.count(reached)) {
858+
return;
859+
}
860+
861+
if (!recStack.emplace(reached).second) {
862+
// cycle detected
863+
fatal("cyclic dependency");
864+
}
865+
866+
auto range = dependencies.equal_range(reached);
867+
for (auto it = range.first; it != range.second; ++it) {
868+
if (it->second == reached) {
869+
continue; // ignore self-dependency
870+
}
871+
dfs(it->second);
872+
}
873+
874+
// index this generator
875+
indexGenerator(*reached);
876+
877+
indexed.insert(reached);
878+
recStack.erase(reached);
879+
};
880+
881+
// topological sorting by depth-first search
882+
for (const ast::Argument* root : generators) {
883+
dfs(root);
884+
}
885+
}
886+
812887
void ClauseTranslator::indexClause(const ast::Clause& clause) {
813888
indexAtoms(clause);
889+
indexGenerators(clause);
814890
indexAggregators(clause);
815891
indexMultiResultFunctors(clause);
816892
}

src/ast2ram/seminaive/ClauseTranslator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class ClauseTranslator : public ast2ram::ClauseTranslator {
8787

8888
/** Indexing */
8989
void indexClause(const ast::Clause& clause);
90+
void indexGenerators(const ast::Clause& clause);
9091
virtual void indexAtoms(const ast::Clause& clause);
9192
void indexAggregators(const ast::Clause& clause);
9293
void indexMultiResultFunctors(const ast::Clause& clause);

tests/semantic/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,5 @@ positive_output_stdout_test(output_stdout)
272272
positive_test(iteration_counter)
273273
positive_test(issue1896)
274274
positive_test(comp_params)
275+
positive_test(issue2416)
276+
positive_test(agg_range)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
.decl a(x: number, y: number, z: number)
2+
a(1, 2, 0).
3+
a(3, 4, 1).
4+
5+
6+
.decl b(item: number)
7+
b(1).
8+
b(2).
9+
b(3).
10+
b(4).
11+
12+
.decl c(x: number, y: number)
13+
.output c()
14+
c(x, y) :-
15+
mr = count : { a(_, y, _) },
16+
x = range(0, mr),
17+
b(y).
18+
19+
.decl d(x: number, y: number)
20+
.output d()
21+
d(x,y) :-
22+
y = range(z, z+3),
23+
z = v*4,
24+
v = x,
25+
x = range(0, 4).
26+
27+
.decl e(x: number, y: number)
28+
.output e()
29+
e(x,y) :-
30+
y = count : { a(_, x, _) },
31+
x = range(0, 5).
32+
33+
.decl f(m: number, n: number, x:number)
34+
.output f()
35+
f(m,n,x) :-
36+
x = range(m,n),
37+
m = range(1,3),
38+
n = range(4,6).
39+
40+
.decl g(x:number)
41+
.output g()
42+
g(x) :-
43+
x = count : { a(_, _, x) },
44+
x = range(0,3).

tests/semantic/agg_range/agg_range.err

Whitespace-only changes.

tests/semantic/agg_range/agg_range.out

Whitespace-only changes.

tests/semantic/agg_range/c.csv

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
0 2
2+
0 4

tests/semantic/agg_range/d.csv

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
0 0
2+
0 1
3+
0 2
4+
1 4
5+
1 5
6+
1 6
7+
2 8
8+
2 9
9+
2 10
10+
3 12
11+
3 13
12+
3 14

tests/semantic/agg_range/e.csv

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
0 0
2+
1 0
3+
2 1
4+
3 0
5+
4 1

tests/semantic/agg_range/f.csv

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
1 4 1
2+
1 4 2
3+
1 4 3
4+
1 5 1
5+
1 5 2
6+
1 5 3
7+
1 5 4
8+
2 4 2
9+
2 4 3
10+
2 5 2
11+
2 5 3
12+
2 5 4

0 commit comments

Comments
 (0)