Skip to content

Commit add182b

Browse files
committed
Sort implicits with a proper comparison function
Before this commit, we sorted implicits using `prefer` which relied on `compareOwner`, but compareOwner does not induce a total ordering: just because `compareOwner(x, y) == 0` does not mean that `compareOwner(x, z)` and `compareOwner(y, z)` have the same sign, this violates the contract of `java.util.Comparator#compare` and lead to an IllegalArgumentException sometimes being thrown (although I wasn't able to reproduce that, see #12479) This commit fixes by this by replacing the usage of `compareOwner` by a new `compareBaseClassesLength` which does induce a total ordering while still hopefully approximating `compareOwner` well enough for our purposes. We also replace `prefer` which returned a Boolean by `compareEligibles` which is directly usable as an Ordering we can pass to `sorted`, this is more efficient than using `sortBy(prefer)` because the latter might end up calling `prefer` twice for a single comparison. Fixes #12479 (I hope).
1 parent d2dd083 commit add182b

File tree

2 files changed

+79
-25
lines changed

2 files changed

+79
-25
lines changed

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1806,6 +1806,11 @@ object SymDenotations {
18061806
def baseClasses(implicit onBehalf: BaseData, ctx: Context): List[ClassSymbol] =
18071807
baseData._1
18081808

1809+
/** Like `baseClasses.length` but more efficient. */
1810+
def baseClassesLength(using BaseData, Context): Int =
1811+
// `+ 1` because the baseClassSet does not include the current class unlike baseClasses
1812+
baseClassSet.classIds.length + 1
1813+
18091814
/** A bitset that contains the superId's of all base classes */
18101815
private def baseClassSet(implicit onBehalf: BaseData, ctx: Context): BaseClassSet =
18111816
baseData._2

compiler/src/dotty/tools/dotc/typer/Implicits.scala

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,53 +1247,102 @@ trait Implicits:
12471247
|Consider using the scala.util.NotGiven class to implement similar functionality.""",
12481248
ctx.source.atSpan(span))
12491249

1250-
/** A relation that influences the order in which implicits are tried.
1250+
/** Compare the length of the baseClasses of two symbols (except for objects,
1251+
* where we use the length of the companion class instead if it's bigger).
1252+
*
1253+
* This relation is meant to approximate `Applications#compareOwner` while also
1254+
* inducing a total ordering: `compareOwner` returns `0` for unrelated symbols
1255+
* and therefore only induces a partial ordering, meaning it cannot be used
1256+
* as a sorting function (see `java.util.Comparator#compare`).
1257+
*/
1258+
def compareBaseClassesLength(sym1: Symbol, sym2: Symbol): Int =
1259+
def len(sym: Symbol) =
1260+
if sym.is(ModuleClass) && sym.companionClass.exists then
1261+
Math.max(sym.asClass.baseClassesLength, sym.companionClass.asClass.baseClassesLength)
1262+
else if sym.isClass then
1263+
sym.asClass.baseClassesLength
1264+
else
1265+
0
1266+
len(sym1) - len(sym2)
1267+
1268+
/** A relation that influences the order in which eligible implicits are tried.
1269+
*
12511270
* We prefer (in order of importance)
12521271
* 1. more deeply nested definitions
12531272
* 2. definitions with fewer implicit parameters
1254-
* 3. definitions in subclasses
1273+
* 3. definitions whose owner has more parents (see `compareBaseClassesLength`)
12551274
* The reason for (2) is that we want to fail fast if the search type
12561275
* is underconstrained. So we look for "small" goals first, because that
12571276
* will give an ambiguity quickly.
12581277
*/
1259-
def prefer(cand1: Candidate, cand2: Candidate): Boolean =
1260-
val level1 = cand1.level
1261-
val level2 = cand2.level
1262-
if level1 > level2 then return true
1263-
if level1 < level2 then return false
1264-
val sym1 = cand1.ref.symbol
1265-
val sym2 = cand2.ref.symbol
1278+
def compareEligibles(e1: Candidate, e2: Candidate): Int =
1279+
if e1 eq e2 then return 0
1280+
val cmpLevel = e1.level - e2.level
1281+
if cmpLevel != 0 then return -cmpLevel // 1.
1282+
val sym1 = e1.ref.symbol
1283+
val sym2 = e2.ref.symbol
12661284
val arity1 = sym1.info.firstParamTypes.length
12671285
val arity2 = sym2.info.firstParamTypes.length
1268-
if arity1 < arity2 then return true
1269-
if arity1 > arity2 then return false
1270-
compareOwner(sym1.owner, sym2.owner) == 1
1286+
val cmpArity = arity1 - arity2
1287+
if cmpArity != 0 then return cmpArity // 2.
1288+
val cmpBcs = compareBaseClassesLength(sym1.owner, sym2.owner)
1289+
-cmpBcs // 3.
12711290

1272-
/** Sort list of implicit references according to `prefer`.
1291+
/** Sort list of implicit references according to `compareEligibles`.
12731292
* This is just an optimization that aims at reducing the average
12741293
* number of candidates to be tested.
12751294
*/
1276-
def sort(eligible: List[Candidate]) = eligible match {
1295+
def sort(eligible: List[Candidate]) = eligible match
12771296
case Nil => eligible
12781297
case e1 :: Nil => eligible
12791298
case e1 :: e2 :: Nil =>
1280-
if (prefer(e2, e1)) e2 :: e1 :: Nil
1299+
if compareEligibles(e2, e1) < 0 then e2 :: e1 :: Nil
12811300
else eligible
12821301
case _ =>
1283-
try eligible.sortWith(prefer)
1302+
try eligible.sorted(using (a, b) => compareEligibles(a, b))
12841303
catch case ex: IllegalArgumentException =>
1285-
// diagnostic to see what went wrong
1304+
// Check if we violated the contract of java.util.Comparator#compare
12861305
for
1287-
e1 <- eligible
1288-
e2 <- eligible
1289-
if prefer(e1, e2)
1290-
e3 <- eligible
1291-
if prefer(e2, e3) && !prefer(e1, e3)
1306+
x <- eligible
1307+
y <- eligible
1308+
cmpXY = Integer.signum(compareEligibles(x, y))
1309+
cmpYX = Integer.signum(compareEligibles(y, x))
1310+
z <- eligible
1311+
cmpXZ = Integer.signum(compareEligibles(x, z))
1312+
cmpYZ = Integer.signum(compareEligibles(y, z))
12921313
do
1293-
val es = List(e1, e2, e3)
1294-
println(i"transitivity violated for $es%, %\n ${es.map(_.implicitRef.underlyingRef.symbol.showLocated)}%, %")
1314+
def reportViolation(msg: String): Unit =
1315+
Console.err.println(s"Internal error: comparison function violated ${msg.stripMargin}")
1316+
def showCandidate(c: Candidate): String =
1317+
s"$c (${c.ref.symbol.showLocated})"
1318+
1319+
if cmpXY != -cmpYX then
1320+
reportViolation(
1321+
s"""signum(cmp(x, y)) == -signum(cmp(y, x)) given:
1322+
|x = ${showCandidate(x)}
1323+
|y = ${showCandidate(y)}
1324+
|cmpXY = $cmpXY
1325+
|cmpYX = $cmpYX""")
1326+
if cmpXY != 0 && cmpXY == cmpYZ && cmpXZ != cmpXY then
1327+
reportViolation(
1328+
s"""transitivity given:
1329+
|x = ${showCandidate(x)}
1330+
|y = ${showCandidate(y)}
1331+
|z = ${showCandidate(z)}
1332+
|cmpXY = $cmpXY
1333+
|cmpXZ = $cmpXZ
1334+
|cmpYZ = $cmpYZ""")
1335+
if cmpXY == 0 && cmpXZ != cmpYZ then
1336+
reportViolation(
1337+
s"""cmp(x, y) == 0 implies that signum(cmp(x, z)) == signum(cmp(y, z)) given:
1338+
|x = ${showCandidate(x)}
1339+
|y = ${showCandidate(y)}
1340+
|z = ${showCandidate(z)}
1341+
|cmpXY = $cmpXY
1342+
|cmpXZ = $cmpXZ
1343+
|cmpYZ = $cmpYZ""")
1344+
end for
12951345
throw ex
1296-
}
12971346

12981347
rank(sort(eligible), NoMatchingImplicitsFailure, Nil)
12991348
end searchImplicit

0 commit comments

Comments
 (0)