Skip to content

Sort implicits with a proper comparison function #12562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1806,6 +1806,11 @@ object SymDenotations {
def baseClasses(implicit onBehalf: BaseData, ctx: Context): List[ClassSymbol] =
baseData._1

/** Like `baseClasses.length` but more efficient. */
def baseClassesLength(using BaseData, Context): Int =
// `+ 1` because the baseClassSet does not include the current class unlike baseClasses
baseClassSet.classIds.length + 1

/** A bitset that contains the superId's of all base classes */
private def baseClassSet(implicit onBehalf: BaseData, ctx: Context): BaseClassSet =
baseData._2
Expand Down
113 changes: 86 additions & 27 deletions compiler/src/dotty/tools/dotc/typer/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1247,53 +1247,112 @@ trait Implicits:
|Consider using the scala.util.NotGiven class to implement similar functionality.""",
ctx.source.atSpan(span))

/** A relation that influences the order in which implicits are tried.
/** Compare the length of the baseClasses of two symbols (except for objects,
* where we use the length of the companion class instead if it's bigger).
*
* This relation is meant to approximate `Applications#compareOwner` while also
* inducing a total ordering: `compareOwner` returns `0` for unrelated symbols
* and therefore only induces a partial ordering, meaning it cannot be used
* as a sorting function (see `java.util.Comparator#compare`).
*/
def compareBaseClassesLength(sym1: Symbol, sym2: Symbol): Int =
def len(sym: Symbol) =
if sym.is(ModuleClass) && sym.companionClass.exists then
Math.max(sym.asClass.baseClassesLength, sym.companionClass.asClass.baseClassesLength)
else if sym.isClass then
sym.asClass.baseClassesLength
else
0
len(sym1) - len(sym2)

/** A relation that influences the order in which eligible implicits are tried.
*
* We prefer (in order of importance)
* 1. more deeply nested definitions
* 2. definitions with fewer implicit parameters
* 3. definitions in subclasses
* 3. definitions whose owner has more parents (see `compareBaseClassesLength`)
* The reason for (2) is that we want to fail fast if the search type
* is underconstrained. So we look for "small" goals first, because that
* will give an ambiguity quickly.
*/
def prefer(cand1: Candidate, cand2: Candidate): Boolean =
val level1 = cand1.level
val level2 = cand2.level
if level1 > level2 then return true
if level1 < level2 then return false
val sym1 = cand1.ref.symbol
val sym2 = cand2.ref.symbol
def compareEligibles(e1: Candidate, e2: Candidate): Int =
if e1 eq e2 then return 0
val cmpLevel = e1.level - e2.level
if cmpLevel != 0 then return -cmpLevel // 1.
val sym1 = e1.ref.symbol
val sym2 = e2.ref.symbol
val arity1 = sym1.info.firstParamTypes.length
val arity2 = sym2.info.firstParamTypes.length
if arity1 < arity2 then return true
if arity1 > arity2 then return false
compareOwner(sym1.owner, sym2.owner) == 1
val cmpArity = arity1 - arity2
if cmpArity != 0 then return cmpArity // 2.
val cmpBcs = compareBaseClassesLength(sym1.owner, sym2.owner)
-cmpBcs // 3.

/** Sort list of implicit references according to `prefer`.
/** Check if `ord` respects the contract of `Ordering`.
*
* More precisely, we check that its `compare` method respects the invariants listed
* in https://docs.oracle.com/javase/8/docs/api/java/util/Comparator.html#compare-T-T-
*/
def validateOrdering(ord: Ordering[Candidate]): Unit =
for
x <- eligible
y <- eligible
cmpXY = Integer.signum(ord.compare(x, y))
cmpYX = Integer.signum(ord.compare(y, x))
z <- eligible
cmpXZ = Integer.signum(ord.compare(x, z))
cmpYZ = Integer.signum(ord.compare(y, z))
do
def reportViolation(msg: String): Unit =
Console.err.println(s"Internal error: comparison function violated ${msg.stripMargin}")
def showCandidate(c: Candidate): String =
s"$c (${c.ref.symbol.showLocated})"

if cmpXY != -cmpYX then
reportViolation(
s"""signum(cmp(x, y)) == -signum(cmp(y, x)) given:
|x = ${showCandidate(x)}
|y = ${showCandidate(y)}
|cmpXY = $cmpXY
|cmpYX = $cmpYX""")
if cmpXY != 0 && cmpXY == cmpYZ && cmpXZ != cmpXY then
reportViolation(
s"""transitivity given:
|x = ${showCandidate(x)}
|y = ${showCandidate(y)}
|z = ${showCandidate(z)}
|cmpXY = $cmpXY
|cmpXZ = $cmpXZ
|cmpYZ = $cmpYZ""")
if cmpXY == 0 && cmpXZ != cmpYZ then
reportViolation(
s"""cmp(x, y) == 0 implies that signum(cmp(x, z)) == signum(cmp(y, z)) given:
|x = ${showCandidate(x)}
|y = ${showCandidate(y)}
|z = ${showCandidate(z)}
|cmpXY = $cmpXY
|cmpXZ = $cmpXZ
|cmpYZ = $cmpYZ""")
end validateOrdering

/** Sort list of implicit references according to `compareEligibles`.
* This is just an optimization that aims at reducing the average
* number of candidates to be tested.
*/
def sort(eligible: List[Candidate]) = eligible match {
def sort(eligible: List[Candidate]) = eligible match
case Nil => eligible
case e1 :: Nil => eligible
case e1 :: e2 :: Nil =>
if (prefer(e2, e1)) e2 :: e1 :: Nil
if compareEligibles(e2, e1) < 0 then e2 :: e1 :: Nil
else eligible
case _ =>
try eligible.sortWith(prefer)
val ord: Ordering[Candidate] = (a, b) => compareEligibles(a, b)
try eligible.sorted(using ord)
catch case ex: IllegalArgumentException =>
// diagnostic to see what went wrong
for
e1 <- eligible
e2 <- eligible
if prefer(e1, e2)
e3 <- eligible
if prefer(e2, e3) && !prefer(e1, e3)
do
val es = List(e1, e2, e3)
println(i"transitivity violated for $es%, %\n ${es.map(_.implicitRef.underlyingRef.symbol.showLocated)}%, %")
// This exception being thrown probably means that our comparison
// function is broken, check if that's the case
validateOrdering(ord)
throw ex
}

rank(sort(eligible), NoMatchingImplicitsFailure, Nil)
end searchImplicit
Expand Down