Skip to content

Commit db8b568

Browse files
authored
Skip eclasses that do not contain eclass with operator
Use an egraph-wide map to remember which eclasses contain a particular operator. PR #21
2 parents 80958e3 + e4d7155 commit db8b568

File tree

5 files changed

+40
-11
lines changed

5 files changed

+40
-11
lines changed

src/egraph.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ pub struct EGraph<L, M> {
127127
memo: IndexMap<ENode<L>, Id>,
128128
classes: UnionFind<Id, EClass<L, M>>,
129129
unions_since_rebuild: usize,
130+
pub(crate) classes_by_op: IndexMap<(L, usize), Vec<Id>>,
130131
}
131132

132133
// manual debug impl to avoid L: Language bound on EGraph defn
@@ -147,6 +148,7 @@ impl<L, M> Default for EGraph<L, M> {
147148
memo: IndexMap::default(),
148149
classes: UnionFind::default(),
149150
unions_since_rebuild: 0,
151+
classes_by_op: IndexMap::default(),
150152
}
151153
}
152154
}
@@ -327,6 +329,7 @@ impl<L: Language, M: Metadata<L>> EGraph<L, M> {
327329
#[cfg(feature = "parent-pointers")]
328330
parents: IndexSet::new(),
329331
};
332+
330333
M::modify(&mut class);
331334
let next_id = self.classes.make_set(class);
332335
trace!("Added {:4}: {:?}", next_id, enode);
@@ -406,9 +409,11 @@ impl<L: Language, M: Metadata<L>> EGraph<L, M> {
406409
}
407410

408411
fn rebuild_classes(&mut self) -> usize {
412+
let (find, mut_values) = self.classes.split();
413+
414+
self.classes_by_op.clear();
409415
let mut trimmed = 0;
410416

411-
let (find, mut_values) = self.classes.split();
412417
for class in mut_values {
413418
let old_len = class.len();
414419

@@ -422,6 +427,19 @@ impl<L: Language, M: Metadata<L>> EGraph<L, M> {
422427

423428
class.nodes.clear();
424429
class.nodes.extend(unique);
430+
431+
let unique_op: IndexSet<(&L, usize)> = class
432+
.nodes
433+
.iter()
434+
.map(|node| (&node.op, node.children.len()))
435+
.collect();
436+
437+
for op in unique_op {
438+
self.classes_by_op
439+
.entry((op.0.clone(), op.1))
440+
.and_modify(|ids| ids.push(class.id))
441+
.or_insert(vec![class.id]);
442+
}
425443
}
426444

427445
trimmed
@@ -467,11 +485,6 @@ impl<L: Language, M: Metadata<L>> EGraph<L, M> {
467485
/// ```
468486
#[cfg(not(feature = "parent-pointers"))]
469487
pub fn rebuild(&mut self) -> usize {
470-
if self.unions_since_rebuild == 0 {
471-
info!("Skipping rebuild!");
472-
return 0;
473-
}
474-
475488
self.unions_since_rebuild = 0;
476489

477490
let old_hc_size = self.memo.len();

src/pattern.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ use crate::{
4949
/// // variable in the Pattern
5050
/// let same_add: Pattern<Math> = "(+ ?a ?a)".parse().unwrap();
5151
///
52+
/// // Rebuild before searching
53+
/// egraph.rebuild();
54+
///
5255
/// // This is the search method from the Searcher trait
5356
/// let matches = same_add.search(&egraph);
5457
/// let matched_eclasses: Vec<Id> = matches.iter().map(|m| m.eclass).collect();
@@ -144,10 +147,19 @@ where
144147
M: Metadata<L>,
145148
{
146149
fn search(&self, egraph: &EGraph<L, M>) -> Vec<SearchMatches> {
147-
egraph
148-
.classes()
149-
.filter_map(|e| self.search_eclass(egraph, e.id))
150-
.collect()
150+
match &self.ast {
151+
PatternAst::ENode(e) => {
152+
let key = (e.op.clone(), e.children.len());
153+
let ids: &[Id] = egraph.classes_by_op.get(&key).map_or(&[], Vec::as_slice);
154+
ids.iter()
155+
.filter_map(|&id| self.search_eclass(egraph, id))
156+
.collect()
157+
}
158+
PatternAst::Var(_) => egraph
159+
.classes()
160+
.filter_map(|e| self.search_eclass(egraph, e.id))
161+
.collect(),
162+
}
151163
}
152164

153165
fn search_eclass(&self, egraph: &EGraph<L, M>, eclass: Id) -> Option<SearchMatches> {

src/rewrite.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,9 @@ mod tests {
472472
"fold_add"; "(+ ?a ?b)" => { Appender }
473473
);
474474

475+
egraph.rebuild();
475476
fold_add.run(&mut egraph);
477+
egraph.rebuild();
476478
assert_eq!(egraph.equivs(&start, &goal), vec![egraph.find(root)]);
477479
}
478480
}

src/run.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ where
305305
/// [`stop_reason`](#structfield.stop_reason) is guaranteeed to be
306306
/// set.
307307
pub fn run(mut self, rules: &[Rewrite<L, M>]) -> Self {
308+
self.egraph.rebuild();
308309
// TODO check that we haven't
309310
loop {
310311
if let Err(stop_reason) = self.run_one(rules) {

tests/prop.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ fn const_fold() {
149149
let start_expr = start.parse().unwrap();
150150
let end = "false";
151151
let end_expr = end.parse().unwrap();
152-
let (eg, _) = EGraph::from_expr(&start_expr);
152+
let (mut eg, _) = EGraph::from_expr(&start_expr);
153+
eg.rebuild();
153154
assert!(!eg.equivs(&start_expr, &end_expr).is_empty());
154155
}

0 commit comments

Comments
 (0)