Skip to content

Commit 867e641

Browse files
committed
[CSOptimizer] Mark compiler synthesized disjunctions as optimized
If a disjunction has favored choices, let's mark it as optimized with a high score to make sure that such disjunctions are prioritized since only disjunctions that could have their choices fovored independently from the optimization algorithm are compiler synthesized ones for things like IUO references, explicit coercions etc.
1 parent 15c773b commit 867e641

File tree

1 file changed

+60
-51
lines changed

1 file changed

+60
-51
lines changed

lib/Sema/CSOptimizer.cpp

+60-51
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ using namespace constraints;
3434

3535
namespace {
3636

37+
struct DisjunctionInfo {
38+
/// The score of the disjunction is the highest score from its choices.
39+
/// If the score is nullopt it means that the disjunction is not optimizable.
40+
std::optional<double> Score;
41+
/// The highest scoring choices that could be favored when disjunction
42+
/// is attempted.
43+
llvm::TinyPtrVector<Constraint *> FavoredChoices;
44+
45+
DisjunctionInfo() = default;
46+
DisjunctionInfo(double score, ArrayRef<Constraint *> favoredChoices = {})
47+
: Score(score), FavoredChoices(favoredChoices) {}
48+
};
49+
3750
// TODO: both `isIntegerType` and `isFloatType` should be available on Type
3851
// as `isStdlib{Integer, Float}Type`.
3952

@@ -246,16 +259,30 @@ static void findFavoredChoicesBasedOnArity(
246259
/// favored choices in the current context.
247260
static void determineBestChoicesInContext(
248261
ConstraintSystem &cs, SmallVectorImpl<Constraint *> &disjunctions,
249-
llvm::DenseMap<Constraint *,
250-
std::pair<double, llvm::TinyPtrVector<Constraint *>>>
251-
&favorings) {
262+
llvm::DenseMap<Constraint *, DisjunctionInfo> &result) {
252263
double bestOverallScore = 0.0;
253-
// Tops scores across all of the disjunctions.
254-
llvm::DenseMap<Constraint *, double> disjunctionScores;
255-
llvm::DenseMap<Constraint *, llvm::TinyPtrVector<Constraint *>>
256-
favoredChoicesPerDisjunction;
264+
265+
auto recordResult = [&bestOverallScore, &result](Constraint *disjunction,
266+
DisjunctionInfo &&info) {
267+
bestOverallScore = std::max(bestOverallScore, info.Score.value_or(0));
268+
result.try_emplace(disjunction, info);
269+
};
257270

258271
for (auto *disjunction : disjunctions) {
272+
// If this is a compiler synthesized disjunction, mark it as supported
273+
// and record all of the previously favored choices. Such disjunctions
274+
// include - explicit coercions, IUO references,injected implicit
275+
// initializers for CGFloat<->Double conversions and restrictions with
276+
// multiple choices.
277+
if (disjunction->countFavoredNestedConstraints() > 0) {
278+
DisjunctionInfo info(/*score=*/2.0);
279+
llvm::copy_if(disjunction->getNestedConstraints(),
280+
std::back_inserter(info.FavoredChoices),
281+
[](Constraint *choice) { return choice->isFavored(); });
282+
recordResult(disjunction, std::move(info));
283+
continue;
284+
}
285+
259286
auto applicableFn =
260287
getApplicableFnConstraint(cs.getConstraintGraph(), disjunction);
261288

@@ -282,14 +309,14 @@ static void determineBestChoicesInContext(
282309
// of `OverloadedDeclRef` calls were favored purely
283310
// based on arity of arguments and parameters matching.
284311
{
285-
findFavoredChoicesBasedOnArity(
286-
cs, disjunction, argumentList, [&](Constraint *choice) {
287-
favoredChoicesPerDisjunction[disjunction].push_back(choice);
288-
});
289-
290-
if (!favoredChoicesPerDisjunction[disjunction].empty()) {
291-
disjunctionScores[disjunction] = 0.01;
292-
bestOverallScore = std::max(bestOverallScore, 0.01);
312+
llvm::TinyPtrVector<Constraint *> favoredChoices;
313+
findFavoredChoicesBasedOnArity(cs, disjunction, argumentList,
314+
[&favoredChoices](Constraint *choice) {
315+
favoredChoices.push_back(choice);
316+
});
317+
318+
if (!favoredChoices.empty()) {
319+
recordResult(disjunction, {/*score=*/0.01, favoredChoices});
293320
continue;
294321
}
295322
}
@@ -894,17 +921,16 @@ static void determineBestChoicesInContext(
894921
<< " with score " << bestScore << "\n";
895922
}
896923

897-
// No matching overload choices to favor.
898-
if (bestScore == 0.0)
899-
continue;
900-
901924
bestOverallScore = std::max(bestOverallScore, bestScore);
902925

903-
disjunctionScores[disjunction] = bestScore;
926+
DisjunctionInfo info(/*score=*/bestScore);
927+
904928
for (const auto &choice : favoredChoices) {
905929
if (choice.second == bestScore)
906-
favoredChoicesPerDisjunction[disjunction].push_back(choice.first);
930+
info.FavoredChoices.push_back(choice.first);
907931
}
932+
933+
recordResult(disjunction, std::move(info));
908934
}
909935

910936
if (cs.isDebugMode() && bestOverallScore > 0) {
@@ -935,14 +961,15 @@ static void determineBestChoicesInContext(
935961
getLogger(/*extraIndent=*/4)
936962
<< "Best overall score = " << bestOverallScore << '\n';
937963

938-
for (const auto &entry : disjunctionScores) {
964+
for (auto *disjunction : disjunctions) {
965+
auto &entry = result[disjunction];
939966
getLogger(/*extraIndent=*/4)
940967
<< "[Disjunction '"
941-
<< entry.first->getNestedConstraints()[0]->getFirstType()->getString(
968+
<< disjunction->getNestedConstraints()[0]->getFirstType()->getString(
942969
PO)
943-
<< "' with score = " << entry.second << '\n';
970+
<< "' with score = " << entry.Score.value_or(0) << '\n';
944971

945-
for (const auto *choice : favoredChoicesPerDisjunction[entry.first]) {
972+
for (const auto *choice : entry.FavoredChoices) {
946973
auto &log = getLogger(/*extraIndent=*/6);
947974

948975
log << "- ";
@@ -955,16 +982,6 @@ static void determineBestChoicesInContext(
955982

956983
getLogger() << ")\n";
957984
}
958-
959-
if (bestOverallScore == 0)
960-
return;
961-
962-
for (auto &entry : disjunctionScores) {
963-
TinyPtrVector<Constraint *> favoredChoices;
964-
for (auto *choice : favoredChoicesPerDisjunction[entry.first])
965-
favoredChoices.push_back(choice);
966-
favorings[entry.first] = std::make_pair(entry.second, favoredChoices);
967-
}
968985
}
969986

970987
// Attempt to find a disjunction of bind constraints where all options
@@ -1036,9 +1053,7 @@ ConstraintSystem::selectDisjunction() {
10361053
if (auto *disjunction = selectBestBindingDisjunction(*this, disjunctions))
10371054
return std::make_pair(disjunction, llvm::TinyPtrVector<Constraint *>());
10381055

1039-
llvm::DenseMap<Constraint *,
1040-
std::pair</*bestScore=*/double, llvm::TinyPtrVector<Constraint *>>>
1041-
favorings;
1056+
llvm::DenseMap<Constraint *, DisjunctionInfo> favorings;
10421057
determineBestChoicesInContext(*this, disjunctions, favorings);
10431058

10441059
// Pick the disjunction with the smallest number of favored, then active
@@ -1052,23 +1067,16 @@ ConstraintSystem::selectDisjunction() {
10521067
auto &[firstScore, firstFavoredChoices] = favorings[first];
10531068
auto &[secondScore, secondFavoredChoices] = favorings[second];
10541069

1055-
bool isFirstSupported = isSupportedDisjunction(first);
1056-
bool isSecondSupported = isSupportedDisjunction(second);
1057-
10581070
// Rank based on scores only if both disjunctions are supported.
1059-
if (isFirstSupported && isSecondSupported) {
1071+
if (firstScore && secondScore) {
10601072
// If both disjunctions have the same score they should be ranked
10611073
// based on number of favored/active choices.
1062-
if (firstScore != secondScore)
1063-
return firstScore > secondScore;
1074+
if (*firstScore != *secondScore)
1075+
return *firstScore > *secondScore;
10641076
}
10651077

1066-
unsigned numFirstFavored = isFirstSupported
1067-
? firstFavoredChoices.size()
1068-
: first->countFavoredNestedConstraints();
1069-
unsigned numSecondFavored =
1070-
isSecondSupported ? secondFavoredChoices.size()
1071-
: second->countFavoredNestedConstraints();
1078+
unsigned numFirstFavored = firstFavoredChoices.size();
1079+
unsigned numSecondFavored = secondFavoredChoices.size();
10721080

10731081
if (numFirstFavored == numSecondFavored) {
10741082
if (firstActive != secondActive)
@@ -1082,7 +1090,8 @@ ConstraintSystem::selectDisjunction() {
10821090
});
10831091

10841092
if (bestDisjunction != disjunctions.end())
1085-
return std::make_pair(*bestDisjunction, favorings[*bestDisjunction].second);
1093+
return std::make_pair(*bestDisjunction,
1094+
favorings[*bestDisjunction].FavoredChoices);
10861095

10871096
return std::nullopt;
10881097
}

0 commit comments

Comments
 (0)