@@ -263,7 +263,11 @@ object Types {
263
263
/** True if this type is an instance of the given `cls` or an instance of
264
264
* a non-bottom subclass of `cls`.
265
265
*/
266
- final def derivesFrom (cls : Symbol , afterErasure : Boolean = false )(using Context ): Boolean = {
266
+ final def derivesFrom (cls : Symbol , isErased : Boolean = false )(using Context ): Boolean = {
267
+ def isLowerBottomType (tp : Type ) =
268
+ (if isErased then tp.isBottomTypeAfterErasure else tp.isBottomType)
269
+ && (tp.hasClassSymbol(defn.NothingClass )
270
+ || cls != defn.NothingClass && ! cls.isValueClass)
267
271
def loop (tp : Type ): Boolean = tp match {
268
272
case tp : TypeRef =>
269
273
val sym = tp.symbol
@@ -280,10 +284,6 @@ object Types {
280
284
// If the type is `T | Null` or `T | Nothing`, the class is != Nothing,
281
285
// and `T` derivesFrom the class, then the OrType derivesFrom the class.
282
286
// Otherwise, we need to check both sides derivesFrom the class.
283
- def isLowerBottomType (tp : Type ) =
284
- (if afterErasure then t.isBottomTypeAfterErasure else t.isBottomType)
285
- && (tp.hasClassSymbol(defn.NothingClass )
286
- || cls != defn.NothingClass && ! cls.isValueClass)
287
287
if isLowerBottomType(tp.tp1) then
288
288
loop(tp.tp2)
289
289
else if isLowerBottomType(tp.tp2) then
@@ -467,28 +467,45 @@ object Types {
467
467
* instance, or NoSymbol if none exists (either because this type is not a
468
468
* value type, or because superclasses are ambiguous).
469
469
*/
470
- final def classSymbol (using Context ): Symbol = this match {
471
- case tp : TypeRef =>
472
- val sym = tp.symbol
473
- if (sym.isClass) sym else tp.superType.classSymbol
474
- case tp : TypeProxy =>
475
- tp.underlying.classSymbol
476
- case tp : ClassInfo =>
477
- tp.cls
478
- case AndType (l, r) =>
479
- val lsym = l.classSymbol
480
- val rsym = r.classSymbol
481
- if (lsym isSubClass rsym) lsym
482
- else if (rsym isSubClass lsym) rsym
483
- else NoSymbol
484
- case tp : OrType =>
485
- tp.join.classSymbol
486
- case _ : JavaArrayType =>
487
- defn.ArrayClass
488
- case _ =>
489
- NoSymbol
490
- }
470
+ final def classSymbol (using Context ): Symbol = classSymbolWith(false )
471
+ final def classSymbolAfterErasure (using Context ): Symbol = classSymbolWith(true )
472
+
473
+ final private def classSymbolWith (isErased : Boolean )(using Context ): Symbol = {
474
+ def loop (tp: Type ): Symbol = tp match {
475
+ case tp : TypeRef =>
476
+ val sym = tp.symbol
477
+ if (sym.isClass) sym else loop(tp.superType)
478
+ case tp : TypeProxy =>
479
+ loop(tp.underlying)
480
+ case tp : ClassInfo =>
481
+ tp.cls
482
+ case AndType (l, r) =>
483
+ val lsym = loop(l)
484
+ val rsym = loop(r)
485
+ if (lsym isSubClass rsym) lsym
486
+ else if (rsym isSubClass lsym) rsym
487
+ else NoSymbol
488
+ case tp : OrType =>
489
+ if tp.tp1.hasClassSymbol(defn.NothingClass ) then
490
+ loop(tp.tp2)
491
+ else if tp.tp2.hasClassSymbol(defn.NothingClass ) then
492
+ loop(tp.tp1)
493
+ else
494
+ val tp1Null = tp.tp1.hasClassSymbol(defn.NullClass )
495
+ val tp2Null = tp.tp2.hasClassSymbol(defn.NullClass )
496
+ if isErased && (tp1Null || tp2Null) then
497
+ val otherSide = if tp1Null then loop(tp.tp2) else loop(tp.tp1)
498
+ if otherSide.isValueClass then defn.AnyClass else otherSide
499
+ else
500
+ loop(tp.join)
501
+ case _ : JavaArrayType =>
502
+ defn.ArrayClass
503
+ case _ =>
504
+ NoSymbol
505
+ }
491
506
507
+ loop(this )
508
+ }
492
509
/** The least (wrt <:<) set of symbols satisfying the `include` predicate of which this type is a subtype
493
510
*/
494
511
final def parentSymbols (include : Symbol => Boolean )(using Context ): List [Symbol ] = this match {
0 commit comments