Skip to content

Commit ea680ce

Browse files
committed
[CP-SAT] improve diffn clustering; more work on hints
1 parent c5ce41e commit ea680ce

28 files changed

+518
-231
lines changed

ortools/sat/BUILD.bazel

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,6 +1231,7 @@ cc_library(
12311231
":pb_constraint",
12321232
":sat_base",
12331233
":sat_parameters_cc_proto",
1234+
":synchronization",
12341235
":util",
12351236
"//ortools/base",
12361237
"//ortools/base:strong_vector",
@@ -2601,16 +2602,10 @@ cc_library(
26012602
":model",
26022603
":sat_base",
26032604
":util",
2604-
"//ortools/base",
26052605
"//ortools/base:stl_util",
26062606
"//ortools/base:strong_vector",
2607-
"//ortools/util:saturated_arithmetic",
2608-
"//ortools/util:sorted_interval_list",
26092607
"//ortools/util:strong_integers",
2610-
"//ortools/util:time_limit",
26112608
"@com_google_absl//absl/base:core_headers",
2612-
"@com_google_absl//absl/container:btree",
2613-
"@com_google_absl//absl/container:flat_hash_map",
26142609
"@com_google_absl//absl/log:check",
26152610
"@com_google_absl//absl/strings",
26162611
"@com_google_absl//absl/types:span",
@@ -3160,6 +3155,7 @@ cc_library(
31603155
":util",
31613156
"//ortools/util:saturated_arithmetic",
31623157
"//ortools/util:strong_integers",
3158+
"//ortools/util:time_limit",
31633159
"@com_google_absl//absl/container:flat_hash_set",
31643160
"@com_google_absl//absl/container:inlined_vector",
31653161
"@com_google_absl//absl/log",

ortools/sat/cp_model_presolve.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5882,14 +5882,14 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) {
58825882
indexed_intervals.push_back({x, IntegerValue(context_->StartMin(y)),
58835883
IntegerValue(context_->EndMax(y))});
58845884
}
5885-
std::vector<std::vector<int>> no_overlaps;
5886-
ConstructOverlappingSets(/*already_sorted=*/false, &indexed_intervals,
5887-
&no_overlaps);
5888-
for (const std::vector<int>& no_overlap : no_overlaps) {
5885+
CompactVectorVector<int> no_overlaps;
5886+
absl::c_sort(indexed_intervals, IndexedInterval::ComparatorByStart());
5887+
ConstructOverlappingSets(absl::MakeSpan(indexed_intervals), &no_overlaps);
5888+
for (int i = 0; i < no_overlaps.size(); ++i) {
58895889
ConstraintProto* new_ct = context_->working_model->add_constraints();
58905890
// Unfortunately, the Assign() method does not work in or-tools as the
58915891
// protobuf int32_t type is not the int type.
5892-
for (const int i : no_overlap) {
5892+
for (const int i : no_overlaps[i]) {
58935893
new_ct->mutable_no_overlap()->add_intervals(i);
58945894
}
58955895
}

ortools/sat/cp_model_search.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,13 @@ absl::flat_hash_map<std::string, SatParameters> GetNamedParameters(
704704
new_params.set_optimize_with_lb_tree_search(false);
705705
new_params.set_optimize_with_max_hs(false);
706706

707+
// Given that each workers work on a different part of the subtree, it might
708+
// not be a good idea to try to work on a global shared solution.
709+
//
710+
// TODO(user): Experiments more here, in particular we could follow it if
711+
// it falls into the current subtree.
712+
new_params.set_polarity_exploit_ls_hints(false);
713+
707714
strategies["shared_tree"] = new_params;
708715
}
709716

ortools/sat/cp_model_solver.cc

Lines changed: 22 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -682,12 +682,13 @@ void LogFinalStatistics(SharedClasses* shared) {
682682
shared->logger->FlushPendingThrottledLogs(/*ignore_rates=*/true);
683683
SOLVER_LOG(shared->logger, "");
684684

685-
shared->stat_tables.Display(shared->logger);
685+
shared->stat_tables->Display(shared->logger);
686686
shared->response->DisplayImprovementStatistics();
687687

688688
std::vector<std::vector<std::string>> table;
689689
table.push_back({"Solution repositories", "Added", "Queried", "Synchro"});
690690
table.push_back(shared->response->SolutionsRepository().TableLineStats());
691+
table.push_back(shared->ls_hints->TableLineStats());
691692
if (shared->lp_solutions != nullptr) {
692693
table.push_back(shared->lp_solutions->TableLineStats());
693694
}
@@ -914,35 +915,9 @@ class FullProblemSolver : public SubSolver {
914915
shared_->response->first_solution_solvers_should_stop());
915916
}
916917

917-
if (shared->response != nullptr) {
918-
local_model_.Register<SharedResponseManager>(shared->response);
919-
}
920-
921-
if (shared->lp_solutions != nullptr) {
922-
local_model_.Register<SharedLPSolutionRepository>(
923-
shared->lp_solutions.get());
924-
}
925-
926-
if (shared->incomplete_solutions != nullptr) {
927-
local_model_.Register<SharedIncompleteSolutionManager>(
928-
shared->incomplete_solutions.get());
929-
}
930-
931-
if (shared->bounds != nullptr) {
932-
local_model_.Register<SharedBoundsManager>(shared->bounds.get());
933-
}
934-
935-
if (shared->clauses != nullptr) {
936-
local_model_.Register<SharedClausesManager>(shared->clauses.get());
937-
}
938-
939-
if (local_parameters.use_shared_tree_search()) {
940-
local_model_.Register<SharedTreeManager>(shared->shared_tree_manager);
941-
}
942-
943918
// TODO(user): For now we do not count LNS statistics. We could easily
944919
// by registering the SharedStatistics class with LNS local model.
945-
local_model_.Register<SharedStatistics>(shared_->stats);
920+
shared_->RegisterSharedClassesInLocalModel(&local_model_);
946921

947922
// Setup the local logger, in multi-thread log_search_progress should be
948923
// false by default, but we might turn it on for debugging. It is on by
@@ -956,10 +931,10 @@ class FullProblemSolver : public SubSolver {
956931
CpSolverResponse response;
957932
shared_->response->FillSolveStatsInResponse(&local_model_, &response);
958933
shared_->response->AppendResponseToBeMerged(response);
959-
shared_->stat_tables.AddTimingStat(*this);
960-
shared_->stat_tables.AddLpStat(name(), &local_model_);
961-
shared_->stat_tables.AddSearchStat(name(), &local_model_);
962-
shared_->stat_tables.AddClausesStat(name(), &local_model_);
934+
shared_->stat_tables->AddTimingStat(*this);
935+
shared_->stat_tables->AddLpStat(name(), &local_model_);
936+
shared_->stat_tables->AddSearchStat(name(), &local_model_);
937+
shared_->stat_tables->AddClausesStat(name(), &local_model_);
963938
}
964939

965940
bool IsDone() override {
@@ -1104,30 +1079,11 @@ class FeasibilityPumpSolver : public SubSolver {
11041079
*(local_model_->GetOrCreate<SatParameters>()) = local_parameters;
11051080
shared_->time_limit->UpdateLocalLimit(
11061081
local_model_->GetOrCreate<TimeLimit>());
1107-
1108-
if (shared->response != nullptr) {
1109-
local_model_->Register<SharedResponseManager>(shared->response);
1110-
}
1111-
1112-
if (shared->lp_solutions != nullptr) {
1113-
local_model_->Register<SharedLPSolutionRepository>(
1114-
shared->lp_solutions.get());
1115-
}
1116-
1117-
if (shared->incomplete_solutions != nullptr) {
1118-
local_model_->Register<SharedIncompleteSolutionManager>(
1119-
shared->incomplete_solutions.get());
1120-
}
1121-
1122-
// Level zero variable bounds sharing.
1123-
if (shared_->bounds != nullptr) {
1124-
RegisterVariableBoundsLevelZeroImport(
1125-
shared_->model_proto, shared_->bounds.get(), local_model_.get());
1126-
}
1082+
shared_->RegisterSharedClassesInLocalModel(local_model_.get());
11271083
}
11281084

11291085
~FeasibilityPumpSolver() override {
1130-
shared_->stat_tables.AddTimingStat(*this);
1086+
shared_->stat_tables->AddTimingStat(*this);
11311087
}
11321088

11331089
bool IsDone() override { return shared_->SearchIsDone(); }
@@ -1216,8 +1172,8 @@ class LnsSolver : public SubSolver {
12161172
shared_(shared) {}
12171173

12181174
~LnsSolver() override {
1219-
shared_->stat_tables.AddTimingStat(*this);
1220-
shared_->stat_tables.AddLnsStat(
1175+
shared_->stat_tables->AddTimingStat(*this);
1176+
shared_->stat_tables->AddLnsStat(
12211177
name(),
12221178
/*num_fully_solved_calls=*/generator_->num_fully_solved_calls(),
12231179
/*num_calls=*/generator_->num_calls(),
@@ -1654,6 +1610,7 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) {
16541610
"synchronization_agent", [shared]() {
16551611
shared->response->Synchronize();
16561612
shared->response->MutableSolutionsRepository()->Synchronize();
1613+
shared->ls_hints->Synchronize();
16571614
if (shared->bounds != nullptr) {
16581615
shared->bounds->Synchronize();
16591616
}
@@ -1946,7 +1903,7 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) {
19461903

19471904
if (num_ls_default > 0) {
19481905
std::shared_ptr<SharedLsStates> states = std::make_shared<SharedLsStates>(
1949-
ls_name, params, &shared->stat_tables);
1906+
ls_name, params, shared->stat_tables);
19501907
for (int i = 0; i < num_ls_default; ++i) {
19511908
SatParameters local_params = params;
19521909
local_params.set_random_seed(
@@ -1956,14 +1913,15 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) {
19561913
std::make_unique<FeasibilityJumpSolver>(
19571914
ls_name, SubSolver::INCOMPLETE, get_linear_model(),
19581915
local_params, states, shared->time_limit, shared->response,
1959-
shared->bounds.get(), shared->stats, &shared->stat_tables));
1916+
shared->bounds.get(), shared->ls_hints, shared->stats,
1917+
shared->stat_tables));
19601918
}
19611919
}
19621920

19631921
if (num_ls_lin > 0) {
19641922
std::shared_ptr<SharedLsStates> lin_states =
19651923
std::make_shared<SharedLsStates>(lin_ls_name, params,
1966-
&shared->stat_tables);
1924+
shared->stat_tables);
19671925
for (int i = 0; i < num_ls_lin; ++i) {
19681926
SatParameters local_params = params;
19691927
local_params.set_random_seed(
@@ -1973,7 +1931,8 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) {
19731931
std::make_unique<FeasibilityJumpSolver>(
19741932
lin_ls_name, SubSolver::INCOMPLETE, get_linear_model(),
19751933
local_params, lin_states, shared->time_limit, shared->response,
1976-
shared->bounds.get(), shared->stats, &shared->stat_tables));
1934+
shared->bounds.get(), shared->ls_hints, shared->stats,
1935+
shared->stat_tables));
19771936
}
19781937
}
19791938
}
@@ -2011,13 +1970,13 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) {
20111970
if (local_params.feasibility_jump_linearization_level() == 0) {
20121971
if (fj_states == nullptr) {
20131972
fj_states = std::make_shared<SharedLsStates>(
2014-
local_params.name(), params, &shared->stat_tables);
1973+
local_params.name(), params, shared->stat_tables);
20151974
}
20161975
states = fj_states;
20171976
} else {
20181977
if (fj_lin_states == nullptr) {
20191978
fj_lin_states = std::make_shared<SharedLsStates>(
2020-
local_params.name(), params, &shared->stat_tables);
1979+
local_params.name(), params, shared->stat_tables);
20211980
}
20221981
states = fj_lin_states;
20231982
}
@@ -2026,8 +1985,8 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) {
20261985
std::make_unique<FeasibilityJumpSolver>(
20271986
local_params.name(), SubSolver::FIRST_SOLUTION,
20281987
get_linear_model(), local_params, states, shared->time_limit,
2029-
shared->response, shared->bounds.get(), shared->stats,
2030-
&shared->stat_tables));
1988+
shared->response, shared->bounds.get(), shared->ls_hints,
1989+
shared->stats, shared->stat_tables));
20311990
} else {
20321991
first_solution_full_subsolvers.push_back(
20331992
std::make_unique<FullProblemSolver>(

ortools/sat/cp_model_solver_helpers.cc

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1971,8 +1971,10 @@ SharedClasses::SharedClasses(const CpModelProto* proto, Model* global_model)
19711971
time_limit(global_model->GetOrCreate<ModelSharedTimeLimit>()),
19721972
logger(global_model->GetOrCreate<SolverLogger>()),
19731973
stats(global_model->GetOrCreate<SharedStatistics>()),
1974+
stat_tables(global_model->GetOrCreate<SharedStatTables>()),
19741975
response(global_model->GetOrCreate<SharedResponseManager>()),
1975-
shared_tree_manager(global_model->GetOrCreate<SharedTreeManager>()) {
1976+
shared_tree_manager(global_model->GetOrCreate<SharedTreeManager>()),
1977+
ls_hints(global_model->GetOrCreate<SharedLsSolutionRepository>()) {
19761978
const SatParameters& params = *global_model->GetOrCreate<SatParameters>();
19771979

19781980
if (params.share_level_zero_bounds()) {
@@ -2007,6 +2009,31 @@ SharedClasses::SharedClasses(const CpModelProto* proto, Model* global_model)
20072009
}
20082010
}
20092011

2012+
void SharedClasses::RegisterSharedClassesInLocalModel(Model* local_model) {
2013+
// Note that we do not register the logger which is not a shared class.
2014+
local_model->Register<SharedResponseManager>(response);
2015+
local_model->Register<SharedLsSolutionRepository>(ls_hints);
2016+
local_model->Register<SharedTreeManager>(shared_tree_manager);
2017+
local_model->Register<SharedStatistics>(stats);
2018+
local_model->Register<SharedStatTables>(stat_tables);
2019+
2020+
// TODO(user): Use parameters and not the presence/absence of these class
2021+
// to decide when to use them.
2022+
if (lp_solutions != nullptr) {
2023+
local_model->Register<SharedLPSolutionRepository>(lp_solutions.get());
2024+
}
2025+
if (incomplete_solutions != nullptr) {
2026+
local_model->Register<SharedIncompleteSolutionManager>(
2027+
incomplete_solutions.get());
2028+
}
2029+
if (bounds != nullptr) {
2030+
local_model->Register<SharedBoundsManager>(bounds.get());
2031+
}
2032+
if (clauses != nullptr) {
2033+
local_model->Register<SharedClausesManager>(clauses.get());
2034+
}
2035+
}
2036+
20102037
bool SharedClasses::SearchIsDone() {
20112038
if (response->ProblemIsSolved()) {
20122039
// This is for cases where the time limit is checked more often.

ortools/sat/cp_model_solver_helpers.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,20 @@ struct SharedClasses {
5656
ModelSharedTimeLimit* const time_limit;
5757
SolverLogger* const logger;
5858
SharedStatistics* const stats;
59+
SharedStatTables* const stat_tables;
5960
SharedResponseManager* const response;
6061
SharedTreeManager* const shared_tree_manager;
62+
SharedLsSolutionRepository* const ls_hints;
6163

6264
// These can be nullptr depending on the options.
6365
std::unique_ptr<SharedBoundsManager> bounds;
6466
std::unique_ptr<SharedLPSolutionRepository> lp_solutions;
6567
std::unique_ptr<SharedIncompleteSolutionManager> incomplete_solutions;
6668
std::unique_ptr<SharedClausesManager> clauses;
6769

68-
// For displaying summary at the end.
69-
SharedStatTables stat_tables;
70+
// call local_model->Register() on most of the class here, this allow to
71+
// more easily depends on one of the shared class deep within the solver.
72+
void RegisterSharedClassesInLocalModel(Model* local_model);
7073

7174
bool SearchIsDone();
7275
};

0 commit comments

Comments
 (0)