Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 62 additions & 26 deletions src/machine.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::*;
use std::result;

type Result = result::Result<(), ()>;

#[derive(Default)]
struct Machine {
Expand Down Expand Up @@ -31,13 +34,21 @@ enum ENodeOrReg<L> {
}

#[inline(always)]
fn for_each_matching_node<L, D>(eclass: &EClass<L, D>, node: &L, mut f: impl FnMut(&L))
fn for_each_matching_node<L, D>(
eclass: &EClass<L, D>,
node: &L,
mut f: impl FnMut(&L) -> Result,
) -> Result
where
L: Language,
{
#[allow(enum_intrinsics_non_enums)]
if eclass.nodes.len() < 50 {
eclass.nodes.iter().filter(|n| node.matches(n)).for_each(f)
eclass
.nodes
.iter()
.filter(|n| node.matches(n))
.try_for_each(f)
} else {
debug_assert!(node.all(|id| id == Id::from(0)));
debug_assert!(eclass.nodes.windows(2).all(|w| w[0] < w[1]));
Expand All @@ -50,7 +61,7 @@ where
break;
}
}
let matching = eclass.nodes[start..]
let mut matching = eclass.nodes[start..]
.iter()
.take_while(|&n| std::mem::discriminant(n) == discrim)
.filter(|n| node.matches(n));
Expand All @@ -68,7 +79,7 @@ where
.collect::<HashSet<_>>(),
eclass.nodes
);
matching.for_each(&mut f);
matching.try_for_each(&mut f)
}
}

Expand All @@ -83,8 +94,9 @@ impl Machine {
egraph: &EGraph<L, N>,
instructions: &[Instruction<L>],
subst: &Subst,
yield_fn: &mut impl FnMut(&Self, &Subst),
) where
yield_fn: &mut impl FnMut(&Self, &Subst) -> Result,
) -> Result
where
L: Language,
N: Analysis<L>,
{
Expand All @@ -104,13 +116,13 @@ impl Machine {
for class in egraph.classes() {
self.reg.truncate(out.0 as usize);
self.reg.push(class.id);
self.run(egraph, remaining_instructions, subst, yield_fn)
self.run(egraph, remaining_instructions, subst, yield_fn)?
}
return;
return Ok(());
}
Instruction::Compare { i, j } => {
if egraph.find(self.reg(*i)) != egraph.find(self.reg(*j)) {
return;
return Ok(());
}
}
Instruction::Lookup { term, i } => {
Expand All @@ -121,7 +133,7 @@ impl Machine {
let look = |i| self.lookup[usize::from(i)];
match egraph.lookup(node.clone().map_children(look)) {
Some(id) => self.lookup.push(id),
None => return,
None => return Ok(()),
}
}
ENodeOrReg::Reg(r) => {
Expand All @@ -132,7 +144,7 @@ impl Machine {

let id = egraph.find(self.reg(*i));
if self.lookup.last().copied() != Some(id) {
return;
return Ok(());
}
}
}
Expand Down Expand Up @@ -334,27 +346,51 @@ impl<L: Language> Program<L> {
where
A: Analysis<L>,
{
let mut machine = Machine::default();
self.run_with_limit(egraph, eclass, usize::MAX)
}

pub fn run_with_limit<A>(
&self,
egraph: &EGraph<L, A>,
eclass: Id,
mut limit: usize,
) -> Vec<Subst>
where
A: Analysis<L>,
{
assert!(egraph.clean, "Tried to search a dirty e-graph!");

if limit == 0 {
return vec![];
}

let mut machine = Machine::default();
assert_eq!(machine.reg.len(), 0);
machine.reg.push(eclass);

let mut matches = Vec::new();
machine.run(
egraph,
&self.instructions,
&self.subst,
&mut |machine, subst| {
let subst_vec = subst
.vec
.iter()
// HACK we are reusing Ids here, this is bad
.map(|(v, reg_id)| (*v, machine.reg(Reg(usize::from(*reg_id) as u32))))
.collect();
matches.push(Subst { vec: subst_vec });
},
);
machine
.run(
egraph,
&self.instructions,
&self.subst,
&mut |machine, subst| {
let subst_vec = subst
.vec
.iter()
// HACK we are reusing Ids here, this is bad
.map(|(v, reg_id)| (*v, machine.reg(Reg(usize::from(*reg_id) as u32))))
.collect();
matches.push(Subst { vec: subst_vec });
limit -= 1;
if limit != 0 {
Ok(())
} else {
Err(())
}
},
)
.unwrap_or_default();

log::trace!("Ran program, found {:?}", matches);
matches
Expand Down
65 changes: 54 additions & 11 deletions src/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,28 +262,37 @@ impl<L: Language, A: Analysis<L>> Searcher<L, A> for Pattern<L> {
Some(&self.ast)
}

fn search(&self, egraph: &EGraph<L, A>) -> Vec<SearchMatches<L>> {
fn search_with_limit(&self, egraph: &EGraph<L, A>, limit: usize) -> Vec<SearchMatches<L>> {
match self.ast.as_ref().last().unwrap() {
ENodeOrVar::ENode(e) => {
#[allow(enum_intrinsics_non_enums)]
let key = std::mem::discriminant(e);
match egraph.classes_by_op.get(&key) {
None => vec![],
Some(ids) => ids
.iter()
.filter_map(|&id| self.search_eclass(egraph, id))
.collect(),
Some(ids) => rewrite::search_eclasses_with_limit(
self,
egraph,
ids.iter().cloned(),
limit,
),
}
}
ENodeOrVar::Var(_) => egraph
.classes()
.filter_map(|e| self.search_eclass(egraph, e.id))
.collect(),
ENodeOrVar::Var(_) => rewrite::search_eclasses_with_limit(
self,
egraph,
egraph.classes().map(|e| e.id),
limit,
),
}
}

fn search_eclass(&self, egraph: &EGraph<L, A>, eclass: Id) -> Option<SearchMatches<L>> {
let substs = self.program.run(egraph, eclass);
fn search_eclass_with_limit(
&self,
egraph: &EGraph<L, A>,
eclass: Id,
limit: usize,
) -> Option<SearchMatches<L>> {
let substs = self.program.run_with_limit(egraph, eclass, limit);
if substs.is_empty() {
None
} else {
Expand Down Expand Up @@ -467,4 +476,38 @@ mod tests {
assert_eq!(n_matches("(f ?x (g ?x))))"), 1);
assert_eq!(n_matches("(h ?x 0 0)"), 1);
}

#[test]
fn search_with_limit() {
    crate::init_logger();

    // Build an e-graph by running commutativity and associativity of `+`
    // over a right-nested sum of six literals.
    let init_expr = &"(+ 1 (+ 2 (+ 3 (+ 4 (+ 5 6)))))".parse().unwrap();
    let rules: Vec<Rewrite<_, ()>> = vec![
        rewrite!("comm"; "(+ ?x ?y)" => "(+ ?y ?x)"),
        rewrite!("assoc"; "(+ ?x (+ ?y ?z))" => "(+ (+ ?x ?y) ?z)"),
    ];
    let runner = Runner::default().with_expr(init_expr).run(&rules);
    let egraph = &runner.egraph;

    // Total number of substitutions across every matched e-class.
    let count_substs = |found: &Vec<SearchMatches<S>>| -> usize {
        found.iter().map(|per_class| per_class.substs.len()).sum()
    };

    let pat = &"(+ ?x (+ ?y ?z))".parse::<Pattern<S>>().unwrap();

    // Whole-graph search: the unlimited count is the baseline, and every
    // limited search must return exactly min(limit, baseline) substitutions.
    let graph_total = 2100;
    let unlimited = pat.search(egraph);
    assert_eq!(count_substs(&unlimited), graph_total);

    for limit in [1, 10, 100, 1000, 10000] {
        let limited = pat.search_with_limit(egraph, limit);
        assert_eq!(count_substs(&limited), usize::min(limit, graph_total));
    }

    // Single e-class search: same property, rooted at the class of the
    // original expression.
    let id = egraph.lookup_expr(init_expr).unwrap();
    let class_total = 540;
    let unlimited_class = pat.search_eclass(egraph, id).unwrap();
    assert_eq!(unlimited_class.substs.len(), class_total);

    for limit in [1, 10, 100, 1000] {
        let limited_class = pat.search_eclass_with_limit(egraph, id, limit).unwrap();
        assert_eq!(limited_class.substs.len(), usize::min(limit, class_total));
    }
}
}
74 changes: 73 additions & 1 deletion src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ impl<L: Language, N: Analysis<L>> Rewrite<L, N> {
self.searcher.search(egraph)
}

/// Call [`search_with_limit`] on the [`Searcher`].
///
/// Both `egraph` and `limit` are forwarded unchanged to this rewrite's
/// searcher, and whatever the searcher returns is returned as-is.
///
/// [`search_with_limit`]: Searcher::search_with_limit()
pub fn search_with_limit(&self, egraph: &EGraph<L, N>, limit: usize) -> Vec<SearchMatches<L>> {
self.searcher.search_with_limit(egraph, limit)
}

/// Call [`apply_matches`] on the [`Applier`].
///
/// [`apply_matches`]: Applier::apply_matches()
Expand Down Expand Up @@ -115,6 +122,37 @@ impl<L: Language, N: Analysis<L>> Rewrite<L, N> {
}
}

/// Searches the given e-classes in iteration order, collecting matches
/// until `limit` substitutions have been gathered in total.
///
/// Each e-class is searched with whatever budget is still left, so the
/// sum of `substs.len()` over the returned matches never exceeds `limit`.
/// E-classes for which the searcher reports `None` contribute nothing.
///
/// # Panics
///
/// Panics if the searcher returns more substitutions for an e-class than
/// the budget it was handed.
pub(crate) fn search_eclasses_with_limit<'a, I, S, L, N>(
    searcher: &'a S,
    egraph: &EGraph<L, N>,
    eclasses: I,
    mut limit: usize,
) -> Vec<SearchMatches<'a, L>>
where
    L: Language,
    N: Analysis<L>,
    S: Searcher<L, N> + ?Sized,
    I: IntoIterator<Item = Id>,
{
    let mut found = Vec::new();
    for eclass in eclasses {
        // Budget exhausted: no point in searching any further e-class.
        if limit == 0 {
            break;
        }
        if let Some(matches) = searcher.search_eclass_with_limit(egraph, eclass, limit) {
            let len = matches.substs.len();
            // The searcher must respect the budget; otherwise the
            // subtraction below would underflow.
            assert!(len <= limit);
            limit -= len;
            found.push(matches);
        }
    }
    found
}

/// The lefthand side of a [`Rewrite`].
///
/// A [`Searcher`] is something that can search the egraph and find
Expand All @@ -128,7 +166,34 @@ where
{
/// Search one eclass, returning None if no matches can be found.
/// This should not return a SearchMatches with no substs.
fn search_eclass(&self, egraph: &EGraph<L, N>, eclass: Id) -> Option<SearchMatches<L>>;
///
/// Implementation of [`Searcher`] should implement one of
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I appreciate the effort to not break the API, but this could be a little confusing because the "default" behavior is an infinite loop! I think we should require search_with_limit.

/// [`search_eclass`] or [`search_eclass_with_limit`].
///
/// [`search_eclass`]: Searcher::search_eclass
/// [`search_eclass_with_limit`]: Searcher::search_eclass_with_limit
fn search_eclass(&self, egraph: &EGraph<L, N>, eclass: Id) -> Option<SearchMatches<L>> {
// Default: run the limited search with the largest possible limit.
// NOTE: the default `search_eclass_with_limit` calls back into this
// method, so an implementor that overrides neither of the two will
// recurse without ever terminating.
self.search_eclass_with_limit(egraph, eclass, usize::MAX)
}

/// Similar to [`search_eclass`], but return at most `limit` many matches.
///
/// Returns `None` when there is nothing to report. That includes the
/// case `limit == 0`, so that this method, like [`search_eclass`], does
/// not hand back a `SearchMatches` with no substs.
///
/// The default implementation runs the full [`search_eclass`] and then
/// truncates the result, so it does no less work than an unlimited
/// search. Implementors that can stop early should override this method.
///
/// Implementations of [`Searcher`] should implement at least one of
/// [`search_eclass`] or [`search_eclass_with_limit`]. The two default
/// implementations call each other, so overriding neither leads to
/// unbounded recursion.
///
/// [`search_eclass`]: Searcher::search_eclass
/// [`search_eclass_with_limit`]: Searcher::search_eclass_with_limit
fn search_eclass_with_limit(
    &self,
    egraph: &EGraph<L, N>,
    eclass: Id,
    limit: usize,
) -> Option<SearchMatches<L>> {
    // A zero budget can never yield a match. Truncating to zero would
    // produce a `SearchMatches` with no substs, which callers are told
    // not to expect, so report "no match" without searching at all.
    if limit == 0 {
        return None;
    }
    let mut found = self.search_eclass(egraph, eclass)?;
    found.substs.truncate(limit);
    Some(found)
}

/// Search the whole [`EGraph`], returning a list of all the
/// [`SearchMatches`] where something was found.
Expand All @@ -142,6 +207,13 @@ where
.collect()
}

/// Similar to [`search`], but return at most `limit` many matches.
///
/// `limit` bounds the total number of substitutions summed over all of
/// the returned [`SearchMatches`], not the number of e-classes matched.
///
/// The default implementation hands the id of every class yielded by
/// `egraph.classes()` to `search_eclasses_with_limit`, which stops
/// visiting classes once the limit has been used up.
///
/// [`search`]: Searcher::search
fn search_with_limit(&self, egraph: &EGraph<L, N>, limit: usize) -> Vec<SearchMatches<L>> {
search_eclasses_with_limit(self, egraph, egraph.classes().map(|e| e.id), limit)
}

/// Returns the number of matches in the e-graph
fn n_matches(&self, egraph: &EGraph<L, N>) -> usize {
self.search(egraph).iter().map(|m| m.substs.len()).sum()
Expand Down
4 changes: 2 additions & 2 deletions src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -860,9 +860,9 @@ where
return vec![];
}

let matches = rewrite.search(egraph);
let total_len: usize = matches.iter().map(|m| m.substs.len()).sum();
let threshold = stats.match_limit << stats.times_banned;
let matches = rewrite.search_with_limit(egraph, threshold + 1);
let total_len: usize = matches.iter().map(|m| m.substs.len()).sum();
if total_len > threshold {
let ban_length = stats.ban_length << stats.times_banned;
stats.times_banned += 1;
Expand Down