diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index d4cf61f35829..16604c9e83b1 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -1010,12 +1010,17 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { /** `tree ne null` (might need a cast to be type correct) */ def testNotNull(using Context): Tree = { - val receiver = if (tree.tpe.isBottomType) - // If the receiver is of type `Nothing` or `Null`, add an ascription so that the selection - // succeeds: e.g. `null.ne(null)` doesn't type, but `(null: AnyRef).ne(null)` does. - Typed(tree, TypeTree(defn.AnyRefType)) - else tree.ensureConforms(defn.ObjectType) - receiver.select(defn.Object_ne).appliedTo(nullLiteral).withSpan(tree.span) + // If the receiver is of type `Nothing` or `Null`, add an ascription or cast + // so that the selection succeeds. + // e.g. `null.ne(null)` doesn't type, but `(null: AnyRef).ne(null)` does. + val receiver = + if tree.tpe.isBottomType then + if ctx.explicitNulls then tree.cast(defn.AnyRefType) + else Typed(tree, TypeTree(defn.AnyRefType)) + else tree.ensureConforms(defn.ObjectType) + // also need to cast the null literal to AnyRef in explicit nulls + val nullLit = if ctx.explicitNulls then nullLiteral.cast(defn.AnyRefType) else nullLiteral + receiver.select(defn.Object_ne).appliedTo(nullLit).withSpan(tree.span) } /** If inititializer tree is `_`, the default value of its type, diff --git a/compiler/src/dotty/tools/dotc/transform/Erasure.scala b/compiler/src/dotty/tools/dotc/transform/Erasure.scala index 241e18e67880..dd511f996f46 100644 --- a/compiler/src/dotty/tools/dotc/transform/Erasure.scala +++ b/compiler/src/dotty/tools/dotc/transform/Erasure.scala @@ -153,6 +153,7 @@ class Erasure extends Phase with DenotTransformer { override def checkPostCondition(tree: tpd.Tree)(using Context): Unit = { assertErased(tree) tree match { + case _: tpd.Import => assert(false, i"illegal tree: $tree") case res: tpd.This => assert(!ExplicitOuter.referencesOuter(ctx.owner.lexicallyEnclosingClass, res), i"Reference to $res from ${ctx.owner.showLocated}") @@ -1034,7 +1035,8 @@ object Erasure { typed(tree.arg, pt) override def typedStats(stats: List[untpd.Tree], exprOwner: Symbol)(using Context): (List[Tree], Context) = { - val stats0 = addRetainedInlineBodies(stats)(using preErasureCtx) + // discard Imports first, since Bridges will use tree's symbol + val stats0 = addRetainedInlineBodies(stats.filter(!_.isInstanceOf[untpd.Import]))(using preErasureCtx) val stats1 = if (takesBridges(ctx.owner)) new Bridges(ctx.owner.asClass, erasurePhase).add(stats0) else stats0 @@ -1042,6 +1044,12 @@ object Erasure { (stats2.filterConserve(!_.isEmpty), finalCtx) } + /** Finally drops all (language-) imports in erasure. + * Since some of the language imports change the subtyping, + * we cannot check the trees before erasure. + */ + override def typedImport(tree: untpd.Import)(using Context) = EmptyTree + override def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree = trace(i"adapting ${tree.showSummary()}: ${tree.tpe} to $pt", show = true) { if ctx.phase != erasurePhase && ctx.phase != erasurePhase.next then diff --git a/compiler/src/dotty/tools/dotc/transform/MixinOps.scala b/compiler/src/dotty/tools/dotc/transform/MixinOps.scala index 787a0d5be1df..fa1c09806893 100644 --- a/compiler/src/dotty/tools/dotc/transform/MixinOps.scala +++ b/compiler/src/dotty/tools/dotc/transform/MixinOps.scala @@ -6,6 +6,7 @@ import Symbols._, Types._, Contexts._, DenotTransformers._, Flags._ import util.Spans._ import SymUtils._ import StdNames._, NameOps._ +import typer.Nullables class MixinOps(cls: ClassSymbol, thisPhase: DenotTransformer)(using Context) { import ast.tpd._ @@ -80,13 +81,20 @@ class MixinOps(cls: ClassSymbol, thisPhase: DenotTransformer)(using Context) { prefss => val (targs, vargss) = splitArgs(prefss) val tapp = superRef(target).appliedToTypeTrees(targs) - vargss match + val rhs = vargss match case Nil | List(Nil) => // Overriding is somewhat loose about `()T` vs `=> T`, so just pick // whichever makes sense for `target` tapp.ensureApplied case _ => tapp.appliedToArgss(vargss) + if ctx.explicitNulls && target.is(JavaDefined) && !ctx.phase.erasedTypes then + // We may forward to a super Java member in resolveSuper phase. + // Since this is still before erasure, the type can be nullable + // and causes error during checking. So we need to enable + // unsafe-nulls to construct the rhs. + Block(Nullables.importUnsafeNulls :: Nil, rhs) + else rhs private def competingMethodsIterator(meth: Symbol): Iterator[Symbol] = cls.baseClasses.iterator diff --git a/compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala b/compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala index 0ab852398049..568512207fde 100644 --- a/compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala +++ b/compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala @@ -20,7 +20,6 @@ import Decorators.* * The phase also replaces all expressions that appear in an erased context by * default values. This is necessary so that subsequent checking phases such * as IsInstanceOfChecker don't give false negatives. - * Finally, the phase drops (language-) imports. */ class PruneErasedDefs extends MiniPhase with SymTransformer { thisTransform => import tpd._ @@ -56,18 +55,10 @@ class PruneErasedDefs extends MiniPhase with SymTransformer { thisTransform => checkErasedInExperimental(tree.symbol) tree - override def transformOther(tree: Tree)(using Context): Tree = tree match - case tree: Import => EmptyTree - case _ => tree - def checkErasedInExperimental(sym: Symbol)(using Context): Unit = // Make an exception for Scala 2 experimental macros to allow dual Scala 2/3 macros under non experimental mode if sym.is(Erased, butNot = Macro) && sym != defn.Compiletime_erasedValue && !sym.isInExperimentalScope then Feature.checkExperimentalFeature("erased", sym.sourcePos) - - override def checkPostCondition(tree: Tree)(using Context): Unit = tree match - case _: tpd.Import => assert(false, i"illegal tree: $tree") - case _ => } object PruneErasedDefs { diff --git a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala index faccb79f3c9a..d5a17c2248f9 100644 --- a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala +++ b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala @@ -122,7 +122,8 @@ class SyntheticMembers(thisPhase: DenotTransformer) { def nameRef: Tree = if isJavaEnumValue then - Select(This(clazz), nme.name).ensureApplied + val name = Select(This(clazz), nme.name).ensureApplied + if ctx.explicitNulls then name.cast(defn.StringType) else name else identifierRef diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index 4f38af4d6198..695392c0ca60 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -134,7 +134,7 @@ class TreeChecker extends Phase with SymTransformer { val checkingCtx = ctx .fresh - .setMode(Mode.ImplicitsEnabled) + .addMode(Mode.ImplicitsEnabled) .setReporter(new ThrowingReporter(ctx.reporter)) val checker = inContext(ctx) { diff --git a/compiler/src/dotty/tools/dotc/typer/Nullables.scala b/compiler/src/dotty/tools/dotc/typer/Nullables.scala index 08ab428e7203..0d4d23243f46 100644 --- a/compiler/src/dotty/tools/dotc/typer/Nullables.scala +++ b/compiler/src/dotty/tools/dotc/typer/Nullables.scala @@ -20,6 +20,10 @@ import ast.Trees.mods object Nullables: import ast.tpd._ + def importUnsafeNulls(using Context): Import = Import( + ref(defn.LanguageModule), + List(untpd.ImportSelector(untpd.Ident(nme.unsafeNulls), EmptyTree, EmptyTree))) + inline def unsafeNullsEnabled(using Context): Boolean = ctx.explicitNulls && !ctx.mode.is(Mode.SafeNulls) diff --git a/compiler/src/dotty/tools/dotc/typer/ReTyper.scala b/compiler/src/dotty/tools/dotc/typer/ReTyper.scala index 385d71b570ca..ee5a156ca5c7 100644 --- a/compiler/src/dotty/tools/dotc/typer/ReTyper.scala +++ b/compiler/src/dotty/tools/dotc/typer/ReTyper.scala @@ -51,7 +51,7 @@ class ReTyper(nestingLevel: Int = 0) extends Typer(nestingLevel) with ReChecking override def typedSuper(tree: untpd.Super, pt: Type)(using Context): Tree = promote(tree) - override def typedImport(tree: untpd.Import, sym: Symbol)(using Context): Tree = + override def typedImport(tree: untpd.Import)(using Context): Tree = promote(tree) override def typedTyped(tree: untpd.Typed, pt: Type)(using Context): Tree = { diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 9a62ac480c30..860ece62efab 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1059,7 +1059,18 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val (stats1, exprCtx) = withoutMode(Mode.Pattern) { typedBlockStats(tree.stats) } - val expr1 = typedExpr(tree.expr, pt.dropIfProto)(using exprCtx) + var expr1 = typedExpr(tree.expr, pt.dropIfProto)(using exprCtx) + + // If unsafe nulls is enabled inside a block but not enabled outside + // and the type does not conform the expected type without unsafe nulls, + // we will cast the last expression to the expected type. + // See: tests/explicit-nulls/pos/unsafe-block.scala + if ctx.mode.is(Mode.SafeNulls) + && !exprCtx.mode.is(Mode.SafeNulls) + && pt.isValueType + && !inContext(exprCtx.addMode(Mode.SafeNulls))(expr1.tpe <:< pt) then + expr1 = expr1.cast(pt) + ensureNoLocalRefs( cpy.Block(tree)(stats1, expr1) .withType(expr1.tpe) @@ -2602,7 +2613,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer |The selector is not a member of an object or package.""") else typd(imp.expr, AnySelectionProto) - def typedImport(imp: untpd.Import, sym: Symbol)(using Context): Tree = + def typedImport(imp: untpd.Import)(using Context): Tree = + val sym = retrieveSym(imp) val expr1 = typedImportQualifier(imp, typedExpr(_, _)(using ctx.withOwner(sym))) checkLegalImportPath(expr1) val selectors1 = typedSelectors(imp.selectors) @@ -2868,7 +2880,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case tree: untpd.If => typedIf(tree, pt) case tree: untpd.Function => typedFunction(tree, pt) case tree: untpd.Closure => typedClosure(tree, pt) - case tree: untpd.Import => typedImport(tree, retrieveSym(tree)) + case tree: untpd.Import => typedImport(tree) case tree: untpd.Export => typedExport(tree) case tree: untpd.Match => typedMatch(tree, pt) case tree: untpd.Return => typedReturn(tree) diff --git a/tests/explicit-nulls/pos/enums.scala b/tests/explicit-nulls/pos/enums.scala new file mode 100644 index 000000000000..6d323331fa34 --- /dev/null +++ b/tests/explicit-nulls/pos/enums.scala @@ -0,0 +1,12 @@ +enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMessageID]: + + case NoExplanationID // errorNumber: -1 + case EmptyCatchOrFinallyBlockID extends ErrorMessageID(isActive = false) // errorNumber: 0 + + def errorNumber = ordinal - 1 + +enum Color(val rgb: Int): + case Red extends Color(0xFF0000) + case Green extends Color(0x00FF00) + case Blue extends Color(0x0000FF) + diff --git a/tests/explicit-nulls/pos/test-not-null.scala b/tests/explicit-nulls/pos/test-not-null.scala new file mode 100644 index 000000000000..3cf17916dd40 --- /dev/null +++ b/tests/explicit-nulls/pos/test-not-null.scala @@ -0,0 +1,5 @@ +// testNotNull can be inserted during PatternMatcher +def f(xs: List[String]) = + xs.zipWithIndex.collect { + case (arg, idx) => idx + } \ No newline at end of file diff --git a/tests/explicit-nulls/pos/unsafe-block.scala b/tests/explicit-nulls/pos/unsafe-block.scala new file mode 100644 index 000000000000..1b103cf33269 --- /dev/null +++ b/tests/explicit-nulls/pos/unsafe-block.scala @@ -0,0 +1,67 @@ +def trim(x: String | Null): String = + import scala.language.unsafeNulls + // The type of `x.trim()` is `String | Null`. + // Although `String | Null` conforms the expected type `String`, + // we still need to cast the expression to the expected type here, + // because outside the scope we don't have `unsafeNulls` anymore. + x.trim() + +class TestDefs: + + def f1: String | Null = null + def f2: Array[String | Null] | Null = null + def f3: Array[String] | Null = null + + def h1a: String = + import scala.language.unsafeNulls + f1 + + def h1b: String | Null = + import scala.language.unsafeNulls + f1 + + def h2a: Array[String] = + import scala.language.unsafeNulls + f2 + + def h2b: Array[String | Null] = + import scala.language.unsafeNulls + f2 + + def h3a: Array[String] = + import scala.language.unsafeNulls + f3 + + def h3b: Array[String | Null] = + import scala.language.unsafeNulls + f3 + +class TestVals: + + val f1: String | Null = null + val f2: Array[String | Null] | Null = null + val f3: Array[String] | Null = null + + val h1a: String = + import scala.language.unsafeNulls + f1 + + val h1b: String | Null = + import scala.language.unsafeNulls + f1 + + val h2a: Array[String] = + import scala.language.unsafeNulls + f2 + + val h2b: Array[String | Null] = + import scala.language.unsafeNulls + f2 + + val h3a: Array[String] = + import scala.language.unsafeNulls + f3 + + val h3b: Array[String | Null] = + import scala.language.unsafeNulls + f3 \ No newline at end of file diff --git a/tests/explicit-nulls/pos/unsafe-chain.scala b/tests/explicit-nulls/pos/unsafe-chain.scala index 76c80d0c53fe..0ba52602d2dd 100644 --- a/tests/explicit-nulls/pos/unsafe-chain.scala +++ b/tests/explicit-nulls/pos/unsafe-chain.scala @@ -1,10 +1,21 @@ import java.nio.file.FileSystems import java.util.ArrayList -def directorySeparator: String = - import scala.language.unsafeNulls - FileSystems.getDefault().getSeparator() +class A: + + def directorySeparator: String = + import scala.language.unsafeNulls + FileSystems.getDefault().getSeparator() + + def getFirstOfFirst(xs: ArrayList[ArrayList[ArrayList[String]]]): String = + import scala.language.unsafeNulls + xs.get(0).get(0).get(0) -def getFirstOfFirst(xs: ArrayList[ArrayList[ArrayList[String]]]): String = +class B: import scala.language.unsafeNulls - xs.get(0).get(0).get(0) \ No newline at end of file + + def directorySeparator: String = + FileSystems.getDefault().getSeparator() + + def getFirstOfFirst(xs: ArrayList[ArrayList[ArrayList[String]]]): String = + xs.get(0).get(0).get(0) \ No newline at end of file