Skip to content

Commit 94dcdef

Browse files
refactor(pumpkin-core): Explicitly register predicates in NogoodPropagator (#332)
1 parent 0f63882 commit 94dcdef

File tree

11 files changed

+85
-120
lines changed

11 files changed

+85
-120
lines changed

pumpkin-crates/core/src/engine/cp/propagation/constructor.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,12 @@ impl PropagatorConstructorContext<'_> {
9797
/// Returns the [`PredicateId`] used by the solver to track the predicate.
9898
#[allow(unused, reason = "will become public API")]
9999
pub(crate) fn register_predicate(&mut self, predicate: Predicate) -> PredicateId {
100-
self.state
101-
.notification_engine
102-
.watch_predicate(predicate, self.propagator_id)
100+
self.state.notification_engine.watch_predicate(
101+
predicate,
102+
self.propagator_id,
103+
&mut self.state.trailed_values,
104+
&self.state.assignments,
105+
)
103106
}
104107

105108
/// Subscribes the propagator to the given [`DomainEvents`] when they are undone during

pumpkin-crates/core/src/engine/cp/propagation/contexts/propagation_context.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use crate::engine::TrailedValues;
77
use crate::engine::notifications::NotificationEngine;
88
use crate::engine::notifications::PredicateIdAssignments;
99
use crate::engine::predicates::predicate::Predicate;
10+
#[cfg(doc)]
11+
use crate::engine::propagation::Propagator;
1012
use crate::engine::propagation::PropagatorId;
1113
use crate::engine::reason::Reason;
1214
use crate::engine::reason::ReasonStore;
@@ -97,6 +99,19 @@ impl<'a> PropagationContextMut<'a> {
9799
self.notification_engine.get_id(predicate)
98100
}
99101

102+
/// Register the propagator to be enqueued when the provided [`Predicate`] becomes true.
103+
///
104+
/// Returns the [`PredicateId`] assigned to the provided predicate, which will be provided
105+
/// to [`Propagator::notify_predicate_satisfied`].
106+
pub(crate) fn register_predicate(&mut self, predicate: Predicate) -> PredicateId {
107+
self.notification_engine.watch_predicate(
108+
predicate,
109+
self.propagator_id,
110+
self.trailed_values,
111+
self.assignments,
112+
)
113+
}
114+
100115
/// Apply a reification literal to all the explanations that are passed to the context.
101116
pub(crate) fn with_reification(&mut self, reification_literal: Literal) {
102117
pumpkin_assert_simple!(

pumpkin-crates/core/src/engine/cp/propagation/propagator.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,10 @@ pub(crate) trait Propagator: Downcast + DynClone {
125125

126126
/// Called when a [`PredicateId`] has been satisfied.
127127
///
128-
/// By default, the propagator does nothing when this method is called.
129-
fn notify_predicate_id_satisfied(&mut self, _predicate_id: PredicateId) {}
128+
/// By default, the propagator will be enqueued.
129+
fn notify_predicate_id_satisfied(&mut self, _predicate_id: PredicateId) -> EnqueueDecision {
130+
EnqueueDecision::Enqueue
131+
}
130132

131133
/// Called each time the [`ConstraintSatisfactionSolver`] backtracks, the propagator can then
132134
/// update its internal data structures given the new variable domains.

pumpkin-crates/core/src/engine/cp/propagation/store.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,6 @@ impl PropagatorStore {
9898
None
9999
}
100100
}
101-
102-
#[cfg(test)]
103-
pub(crate) fn keys(&self) -> impl Iterator<Item = PropagatorId> + '_ {
104-
self.propagators.keys()
105-
}
106101
}
107102

108103
impl Index<PropagatorId> for PropagatorStore {

pumpkin-crates/core/src/engine/cp/propagator_queue.rs

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@ use std::cmp::Reverse;
22
use std::collections::BinaryHeap;
33
use std::collections::VecDeque;
44

5-
use crate::containers::HashSet;
5+
use crate::containers::KeyedVec;
66
use crate::engine::cp::propagation::PropagatorId;
77
use crate::pumpkin_assert_moderate;
88

99
#[derive(Debug, Clone)]
1010
pub(crate) struct PropagatorQueue {
1111
queues: Vec<VecDeque<PropagatorId>>,
12-
present_propagators: HashSet<PropagatorId>,
12+
is_enqueued: KeyedVec<PropagatorId, bool>,
13+
num_enqueued: usize,
1314
present_priorities: BinaryHeap<Reverse<u32>>,
1415
}
1516

@@ -23,29 +24,28 @@ impl PropagatorQueue {
2324
pub(crate) fn new(num_priority_levels: u32) -> PropagatorQueue {
2425
PropagatorQueue {
2526
queues: vec![VecDeque::new(); num_priority_levels as usize],
26-
present_propagators: HashSet::default(),
27+
is_enqueued: KeyedVec::default(),
28+
num_enqueued: 0,
2729
present_priorities: BinaryHeap::new(),
2830
}
2931
}
3032

3133
pub(crate) fn is_empty(&self) -> bool {
32-
self.present_propagators.is_empty()
33-
}
34-
35-
#[cfg(test)]
36-
pub(crate) fn is_propagator_present(&self, propagator_id: PropagatorId) -> bool {
37-
self.present_propagators.contains(&propagator_id)
34+
self.num_enqueued == 0
3835
}
3936

4037
pub(crate) fn enqueue_propagator(&mut self, propagator_id: PropagatorId, priority: u32) {
4138
pumpkin_assert_moderate!((priority as usize) < self.queues.len());
4239

4340
if !self.is_propagator_enqueued(propagator_id) {
41+
self.is_enqueued.accomodate(propagator_id, false);
42+
self.is_enqueued[propagator_id] = true;
43+
self.num_enqueued += 1;
44+
4445
if self.queues[priority as usize].is_empty() {
4546
self.present_priorities.push(Reverse(priority));
4647
}
4748
self.queues[priority as usize].push_back(propagator_id);
48-
let _ = self.present_propagators.insert(propagator_id);
4949
}
5050
}
5151

@@ -59,13 +59,15 @@ impl PropagatorQueue {
5959

6060
let next_propagator_id = self.queues[top_priority].pop_front();
6161

62-
next_propagator_id.iter().for_each(|next_propagator_id| {
63-
let _ = self.present_propagators.remove(next_propagator_id);
62+
if let Some(propagator_id) = next_propagator_id {
63+
self.is_enqueued[propagator_id] = false;
6464

6565
if self.queues[top_priority].is_empty() {
6666
let _ = self.present_priorities.pop();
6767
}
68-
});
68+
}
69+
70+
self.num_enqueued -= 1;
6971

7072
next_propagator_id
7173
}
@@ -76,11 +78,19 @@ impl PropagatorQueue {
7678
pumpkin_assert_moderate!(!self.queues[priority].is_empty());
7779
self.queues[priority].clear();
7880
}
79-
self.present_propagators.clear();
81+
82+
for is_propagator_enqueued in self.is_enqueued.iter_mut() {
83+
*is_propagator_enqueued = false;
84+
}
85+
8086
self.present_priorities.clear();
87+
self.num_enqueued = 0;
8188
}
8289

83-
fn is_propagator_enqueued(&self, propagator_id: PropagatorId) -> bool {
84-
self.present_propagators.contains(&propagator_id)
90+
pub(crate) fn is_propagator_enqueued(&self, propagator_id: PropagatorId) -> bool {
91+
self.is_enqueued
92+
.get(propagator_id)
93+
.copied()
94+
.unwrap_or_default()
8595
}
8696
}

pumpkin-crates/core/src/engine/cp/test_solver.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ impl TestSolver {
110110
&mut self.state.propagators,
111111
&mut propagator_queue,
112112
);
113-
if propagator_queue.is_propagator_present(propagator) {
113+
if propagator_queue.is_propagator_enqueued(propagator) {
114114
EnqueueDecision::Enqueue
115115
} else {
116116
EnqueueDecision::Skip
@@ -138,7 +138,7 @@ impl TestSolver {
138138
&mut self.state.propagators,
139139
&mut propagator_queue,
140140
);
141-
if propagator_queue.is_propagator_present(propagator) {
141+
if propagator_queue.is_propagator_enqueued(propagator) {
142142
EnqueueDecision::Enqueue
143143
} else {
144144
EnqueueDecision::Skip
@@ -166,7 +166,7 @@ impl TestSolver {
166166
&mut self.state.propagators,
167167
&mut propagator_queue,
168168
);
169-
if propagator_queue.is_propagator_present(propagator) {
169+
if propagator_queue.is_propagator_enqueued(propagator) {
170170
EnqueueDecision::Enqueue
171171
} else {
172172
EnqueueDecision::Skip

pumpkin-crates/core/src/engine/notifications/mod.rs

Lines changed: 14 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,8 @@ use crate::engine::propagation::PropagatorId;
2424
use crate::engine::propagation::contexts::PropagationContextWithTrailedValues;
2525
use crate::engine::propagation::store::PropagatorStore;
2626
use crate::predicates::Predicate;
27-
use crate::propagators::nogoods::NogoodPropagator;
2827
use crate::pumpkin_assert_extreme;
2928
use crate::pumpkin_assert_simple;
30-
use crate::state::PropagatorHandle;
3129
use crate::variables::DomainId;
3230

3331
#[derive(Debug, Clone)]
@@ -142,13 +140,18 @@ impl NotificationEngine {
142140
&mut self,
143141
predicate: Predicate,
144142
propagator_id: PropagatorId,
143+
trailed_values: &mut TrailedValues,
144+
assignments: &Assignments,
145145
) -> PredicateId {
146146
let predicate_id = self.get_id(predicate);
147147

148148
self.watch_list_predicate_id
149149
.accomodate(predicate_id, vec![]);
150150
self.watch_list_predicate_id[predicate_id].push(propagator_id);
151151

152+
self.predicate_notifier
153+
.track_predicate(predicate_id, trailed_values, assignments);
154+
152155
predicate_id
153156
}
154157

@@ -264,7 +267,6 @@ impl NotificationEngine {
264267
assignments: &mut Assignments,
265268
trailed_values: &mut TrailedValues,
266269
propagators: &mut PropagatorStore,
267-
nogood_propagator_handle: PropagatorHandle<NogoodPropagator>,
268270
propagator_queue: &mut PropagatorQueue,
269271
) {
270272
// We first take the events because otherwise we get mutability issues when calling methods
@@ -275,17 +277,6 @@ impl NotificationEngine {
275277
// First we notify the predicate_notifier that a domain has been updated
276278
self.predicate_notifier
277279
.on_update(trailed_values, assignments, event, domain);
278-
// Special case: the nogood propagator is notified about each event.
279-
Self::notify_nogood_propagator(
280-
nogood_propagator_handle,
281-
&mut self.predicate_notifier.predicate_id_assignments,
282-
event,
283-
domain,
284-
propagators,
285-
propagator_queue,
286-
assignments,
287-
trailed_values,
288-
);
289280
// Now notify other propagators subscribed to this event.
290281
#[allow(clippy::unnecessary_to_owned, reason = "Not unnecessary?")]
291282
for propagator_var in self
@@ -310,7 +301,7 @@ impl NotificationEngine {
310301
self.events = events;
311302

312303
// Then we notify the propagators that a predicate has been satisfied.
313-
self.notify_predicate_id_satisfied(nogood_propagator_handle, propagators);
304+
self.notify_predicate_id_satisfied(propagators, propagator_queue);
314305

315306
self.last_notified_trail_index = assignments.num_trail_entries();
316307
}
@@ -348,50 +339,25 @@ impl NotificationEngine {
348339
/// Notifies the propagator that certain [`Predicate`]s have been satisfied.
349340
fn notify_predicate_id_satisfied(
350341
&mut self,
351-
nogood_propagator_handle: PropagatorHandle<NogoodPropagator>,
352342
propagators: &mut PropagatorStore,
343+
propagator_queue: &mut PropagatorQueue,
353344
) {
354345
for predicate_id in self.predicate_notifier.drain_satisfied_predicates() {
355346
if let Some(watch_list) = self.watch_list_predicate_id.get(predicate_id) {
356347
let propagators_to_notify = watch_list.iter().copied();
357348

358349
for propagator_id in propagators_to_notify {
359-
propagators[propagator_id].notify_predicate_id_satisfied(predicate_id);
350+
let propagator = &mut propagators[propagator_id];
351+
let enqueue_decision = propagator.notify_predicate_id_satisfied(predicate_id);
352+
353+
if enqueue_decision == EnqueueDecision::Enqueue {
354+
propagator_queue.enqueue_propagator(propagator_id, propagator.priority());
355+
}
360356
}
361357
}
362-
363-
propagators[nogood_propagator_handle.propagator_id()]
364-
.notify_predicate_id_satisfied(predicate_id);
365358
}
366359
}
367360

368-
#[allow(clippy::too_many_arguments, reason = "to be refactored later")]
369-
fn notify_nogood_propagator(
370-
nogood_propagator_id: PropagatorHandle<NogoodPropagator>,
371-
predicate_id_assignments: &mut PredicateIdAssignments,
372-
event: DomainEvent,
373-
domain: DomainId,
374-
propagators: &mut PropagatorStore,
375-
propagator_queue: &mut PropagatorQueue,
376-
assignments: &mut Assignments,
377-
trailed_values: &mut TrailedValues,
378-
) {
379-
// The nogood propagator is implicitly subscribed to every domain event for every variable.
380-
// For this reason, its local id matches the domain id.
381-
// This is special only for the nogood propagator.
382-
let local_id = LocalId::from(domain.id());
383-
Self::notify_propagator(
384-
predicate_id_assignments,
385-
nogood_propagator_id.propagator_id(),
386-
local_id,
387-
event,
388-
propagators,
389-
propagator_queue,
390-
assignments,
391-
trailed_values,
392-
);
393-
}
394-
395361
#[allow(clippy::too_many_arguments, reason = "Should be refactored")]
396362
fn notify_propagator(
397363
predicate_id_assignments: &mut PredicateIdAssignments,
@@ -425,16 +391,6 @@ impl NotificationEngine {
425391
let _ = self.events.drain();
426392
}
427393

428-
pub(crate) fn track_predicate(
429-
&mut self,
430-
predicate: PredicateId,
431-
trailed_values: &mut TrailedValues,
432-
assignments: &Assignments,
433-
) {
434-
self.predicate_notifier
435-
.track_predicate(predicate, trailed_values, assignments)
436-
}
437-
438394
#[cfg(test)]
439395
pub(crate) fn drain_backtrack_domain_events(
440396
&mut self,
@@ -459,12 +415,6 @@ impl NotificationEngine {
459415
propagators: &mut PropagatorStore,
460416
propagator_queue: &mut PropagatorQueue,
461417
) {
462-
// There may be a nogood propagator in the store. In that case we need to always
463-
// notify it.
464-
let nogood_propagator_handle = propagators
465-
.keys()
466-
.find_map(|id| propagators.as_propagator_handle::<NogoodPropagator>(id));
467-
468418
// Collect so that we can pass the assignments to the methods within the loop
469419
for (event, domain) in self.events.drain().collect::<Vec<_>>() {
470420
// First we notify the predicate_notifier that a domain has been updated
@@ -493,10 +443,7 @@ impl NotificationEngine {
493443
}
494444
}
495445

496-
if let Some(handle) = nogood_propagator_handle {
497-
// Then we notify the propagators that a predicate has been satisfied.
498-
self.notify_predicate_id_satisfied(handle, propagators);
499-
}
446+
self.notify_predicate_id_satisfied(propagators, propagator_queue);
500447

501448
self.last_notified_trail_index = assignments.num_trail_entries();
502449
}

pumpkin-crates/core/src/engine/state.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,6 @@ impl State {
573573
&mut self.assignments,
574574
&mut self.trailed_values,
575575
&mut self.propagators,
576-
PropagatorHandle::new(PropagatorId(0)),
577576
&mut self.propagator_queue,
578577
);
579578
pumpkin_assert_extreme!(
@@ -628,7 +627,6 @@ impl State {
628627
&mut self.assignments,
629628
&mut self.trailed_values,
630629
&mut self.propagators,
631-
PropagatorHandle::new(PropagatorId(0)),
632630
&mut self.propagator_queue,
633631
);
634632

0 commit comments

Comments
 (0)