Skip to content

Commit 3cfa26c

Browse files
feat: implement predicate and assumption solving in python wrapper (#150)
Assumption solving in the Rust layer currently performs semantic minimisation of the core. E.g. if the assumptions contain `[y <= 1]` and `[y != 0]`, and the domain of `y` starts at 0, then the core may contain `[y == 1]` rather than the original predicates in the assumptions.
1 parent 2462937 commit 3cfa26c

File tree

5 files changed

+173
-50
lines changed

5 files changed

+173
-50
lines changed

pumpkin-py/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@ macro_rules! submodule {
2828
fn pumpkin_py(python: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
2929
m.add_class::<variables::IntExpression>()?;
3030
m.add_class::<variables::BoolExpression>()?;
31-
m.add_class::<model::Comparator>()?;
31+
m.add_class::<variables::Comparator>()?;
32+
m.add_class::<variables::Predicate>()?;
3233
m.add_class::<model::Model>()?;
3334
m.add_class::<result::SatisfactionResult>()?;
35+
m.add_class::<result::SatisfactionUnderAssumptionsResult>()?;
3436
m.add_class::<result::Solution>()?;
3537

3638
submodule!(constraints, python, m);

pumpkin-py/src/model.rs

Lines changed: 69 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ use crate::optimisation::Direction;
2121
use crate::optimisation::OptimisationResult;
2222
use crate::optimisation::Optimiser;
2323
use crate::result::SatisfactionResult;
24+
use crate::result::SatisfactionUnderAssumptionsResult;
2425
use crate::result::Solution;
2526
use crate::variables::BoolExpression;
2627
use crate::variables::BoolVariable;
2728
use crate::variables::IntExpression;
2829
use crate::variables::IntVariable;
30+
use crate::variables::Predicate;
2931
use crate::variables::VariableMap;
3032

3133
#[pyclass]
@@ -36,15 +38,6 @@ pub struct Model {
3638
constraints: Vec<ModelConstraint>,
3739
}
3840

39-
#[pyclass(eq, eq_int)]
40-
#[derive(Clone, Copy, PartialEq, Eq)]
41-
pub enum Comparator {
42-
NotEqual,
43-
Equal,
44-
LessThanOrEqual,
45-
GreaterThanOrEqual,
46-
}
47-
4841
#[pymethods]
4942
impl Model {
5043
#[new]
@@ -121,23 +114,13 @@ impl Model {
121114
}
122115
}
123116

124-
#[pyo3(signature = (integer, comparator, value, name=None))]
125-
fn predicate_as_boolean(
126-
&mut self,
127-
integer: IntExpression,
128-
comparator: Comparator,
129-
value: i32,
130-
name: Option<&str>,
131-
) -> BoolExpression {
117+
#[pyo3(signature = (predicate, name=None))]
118+
fn predicate_as_boolean(&mut self, predicate: Predicate, name: Option<&str>) -> BoolExpression {
132119
self.boolean_variables
133120
.push(ModelBoolVar {
134121
name: name.map(|n| n.to_owned()),
135122
integer_equivalent: None,
136-
predicate: Some(Predicate {
137-
integer,
138-
comparator,
139-
value,
140-
}),
123+
predicate: Some(predicate),
141124
})
142125
.into()
143126
}
@@ -191,6 +174,70 @@ impl Model {
191174
}
192175
}
193176

177+
#[pyo3(signature = (assumptions))]
178+
fn satisfy_under_assumptions(
179+
&self,
180+
assumptions: Vec<Predicate>,
181+
) -> SatisfactionUnderAssumptionsResult {
182+
let solver_setup = self.create_solver(None);
183+
184+
let Ok((mut solver, variable_map)) = solver_setup else {
185+
return SatisfactionUnderAssumptionsResult::Unsatisfiable();
186+
};
187+
188+
let mut brancher = solver.default_brancher();
189+
190+
let solver_assumptions = assumptions
191+
.iter()
192+
.map(|pred| pred.to_solver_predicate(&variable_map))
193+
.collect::<Vec<_>>();
194+
195+
// Maarten: I do not understand why it is necessary, but we have to create a local variable
196+
// here that is the result of the `match` statement. Otherwise the compiler
197+
// complains that `solver` and `brancher` potentially do not live long enough.
198+
//
199+
// Ideally this would not be necessary, but perhaps it is unavoidable with the setup we
200+
// currently have. Either way, we take the suggestion by the compiler.
201+
let result = match solver.satisfy_under_assumptions(&mut brancher, &mut Indefinite, &solver_assumptions) {
202+
pumpkin_solver::results::SatisfactionResultUnderAssumptions::Satisfiable(solution) => {
203+
SatisfactionUnderAssumptionsResult::Satisfiable(Solution {
204+
solver_solution: solution,
205+
variable_map,
206+
})
207+
}
208+
pumpkin_solver::results::SatisfactionResultUnderAssumptions::UnsatisfiableUnderAssumptions(mut result) => {
209+
// Maarten: For now we assume that the core _must_ consist of the predicates that
210+
// were the input to the solve call. In general this is not the case, e.g. when
211+
// the assumptions can be semantically minized (the assumptions [y <= 1],
212+
// [y >= 0] and [y != 0] will be compressed to [y == 1] which would end up in
213+
// the core).
214+
//
215+
// In the future, perhaps we should make the distinction between predicates and
216+
// literals in the python wrapper as well. For now, this is the simplest way
217+
// forward. I expect that the situation above almost never happens in practice.
218+
let core = result
219+
.extract_core()
220+
.iter()
221+
.map(|predicate| assumptions
222+
.iter()
223+
.find(|pred| pred.to_solver_predicate(&variable_map) == *predicate)
224+
.copied()
225+
.expect("predicates in core must be part of the assumptions"))
226+
.collect();
227+
228+
SatisfactionUnderAssumptionsResult::UnsatisfiableUnderAssumptions(core)
229+
}
230+
pumpkin_solver::results::SatisfactionResultUnderAssumptions::Unsatisfiable => {
231+
SatisfactionUnderAssumptionsResult::Unsatisfiable()
232+
}
233+
pumpkin_solver::results::SatisfactionResultUnderAssumptions::Unknown => {
234+
SatisfactionUnderAssumptionsResult::Unknown()
235+
}
236+
};
237+
238+
result
239+
}
240+
194241
#[pyo3(signature = (objective, optimiser=Optimiser::LinearSatUnsat, direction=Direction::Minimise, proof=None))]
195242
fn optimise(
196243
&self,
@@ -411,26 +458,3 @@ impl ModelBoolVar {
411458
Ok(literal)
412459
}
413460
}
414-
415-
struct Predicate {
416-
integer: IntExpression,
417-
comparator: Comparator,
418-
value: i32,
419-
}
420-
421-
impl Predicate {
422-
/// Convert the predicate in the model domain to a predicate in the solver domain.
423-
fn to_solver_predicate(
424-
&self,
425-
variable_map: &VariableMap,
426-
) -> pumpkin_solver::predicates::Predicate {
427-
let affine_view = self.integer.to_affine_view(variable_map);
428-
429-
match self.comparator {
430-
Comparator::NotEqual => predicate![affine_view != self.value],
431-
Comparator::Equal => predicate![affine_view == self.value],
432-
Comparator::LessThanOrEqual => predicate![affine_view <= self.value],
433-
Comparator::GreaterThanOrEqual => predicate![affine_view >= self.value],
434-
}
435-
}
436-
}

pumpkin-py/src/result.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use pyo3::prelude::*;
33

44
use crate::variables::BoolExpression;
55
use crate::variables::IntExpression;
6+
use crate::variables::Predicate;
67
use crate::variables::VariableMap;
78

89
#[pyclass]
@@ -13,6 +14,15 @@ pub enum SatisfactionResult {
1314
Unknown(),
1415
}
1516

17+
#[pyclass]
18+
#[allow(clippy::large_enum_variant)]
19+
pub enum SatisfactionUnderAssumptionsResult {
20+
Satisfiable(Solution),
21+
UnsatisfiableUnderAssumptions(Vec<Predicate>),
22+
Unsatisfiable(),
23+
Unknown(),
24+
}
25+
1626
#[pyclass]
1727
#[derive(Clone)]
1828
pub struct Solution {
@@ -32,3 +42,7 @@ impl Solution {
3242
.get_literal_value(variable.to_literal(&self.variable_map))
3343
}
3444
}
45+
46+
#[pyclass]
47+
#[derive(Clone)]
48+
pub struct CoreExtractor {}

pumpkin-py/src/variables.rs

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
use pumpkin_solver::containers::KeyedVec;
22
use pumpkin_solver::containers::StorageKey;
3+
use pumpkin_solver::predicate;
34
use pumpkin_solver::variables::AffineView;
45
use pumpkin_solver::variables::DomainId;
56
use pumpkin_solver::variables::Literal;
67
use pumpkin_solver::variables::TransformableVariable;
78
use pyo3::prelude::*;
89

9-
#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)]
10+
#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
1011
pub struct IntVariable(usize);
1112

1213
impl StorageKey for IntVariable {
@@ -19,8 +20,8 @@ impl StorageKey for IntVariable {
1920
}
2021
}
2122

22-
#[pyclass]
23-
#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)]
23+
#[pyclass(eq, hash, frozen)]
24+
#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
2425
pub struct IntExpression {
2526
pub variable: IntVariable,
2627
pub offset: i32,
@@ -83,6 +84,52 @@ impl IntExpression {
8384
}
8485
}
8586

87+
#[pyclass(eq, eq_int, hash, frozen)]
88+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
89+
pub enum Comparator {
90+
NotEqual,
91+
Equal,
92+
LessThanOrEqual,
93+
GreaterThanOrEqual,
94+
}
95+
96+
#[pyclass(eq, get_all, hash, frozen)]
97+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
98+
pub struct Predicate {
99+
pub variable: IntExpression,
100+
pub comparator: Comparator,
101+
pub value: i32,
102+
}
103+
104+
#[pymethods]
105+
impl Predicate {
106+
#[new]
107+
fn new(variable: IntExpression, comparator: Comparator, value: i32) -> Self {
108+
Self {
109+
variable,
110+
comparator,
111+
value,
112+
}
113+
}
114+
}
115+
116+
impl Predicate {
117+
/// Convert the predicate in the model domain to a predicate in the solver domain.
118+
pub(crate) fn to_solver_predicate(
119+
self,
120+
variable_map: &VariableMap,
121+
) -> pumpkin_solver::predicates::Predicate {
122+
let affine_view = self.variable.to_affine_view(variable_map);
123+
124+
match self.comparator {
125+
Comparator::NotEqual => predicate![affine_view != self.value],
126+
Comparator::Equal => predicate![affine_view == self.value],
127+
Comparator::LessThanOrEqual => predicate![affine_view <= self.value],
128+
Comparator::GreaterThanOrEqual => predicate![affine_view >= self.value],
129+
}
130+
}
131+
}
132+
86133
#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)]
87134
pub struct BoolVariable(usize);
88135

@@ -96,7 +143,7 @@ impl StorageKey for BoolVariable {
96143
}
97144
}
98145

99-
#[pyclass]
146+
#[pyclass(eq)]
100147
#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)]
101148
pub struct BoolExpression(BoolVariable, bool);
102149

pumpkin-py/tests/test_assumptions.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from pumpkin_py import Comparator, Model, Predicate, SatisfactionUnderAssumptionsResult
2+
from pumpkin_py.constraints import LessThanOrEquals
3+
4+
5+
def test_assumptions_are_respected():
6+
model = Model()
7+
8+
x = model.new_integer_variable(1, 5, name="x")
9+
10+
assumption = Predicate(x, Comparator.LessThanOrEqual, 3)
11+
12+
result = model.satisfy_under_assumptions([assumption])
13+
assert isinstance(result, SatisfactionUnderAssumptionsResult.Satisfiable)
14+
15+
solution = result._0
16+
x_value = solution.int_value(x)
17+
assert x_value <= 3
18+
19+
20+
def test_core_extraction():
21+
model = Model()
22+
23+
x = model.new_integer_variable(1, 5, name="x")
24+
y = model.new_integer_variable(1, 5, name="x")
25+
26+
x_ge_3 = Predicate(x, Comparator.GreaterThanOrEqual, 3)
27+
y_ge_3 = Predicate(y, Comparator.GreaterThanOrEqual, 3)
28+
29+
model.add_constraint(LessThanOrEquals([x, y], 5))
30+
31+
result = model.satisfy_under_assumptions([x_ge_3, y_ge_3])
32+
assert isinstance(result, SatisfactionUnderAssumptionsResult.UnsatisfiableUnderAssumptions)
33+
34+
core = set(result._0)
35+
assert set([x_ge_3, y_ge_3]) == core
36+

0 commit comments

Comments
 (0)