Skip to content

Fix checking ctx to carry correct modes #15350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1010,12 +1010,17 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

/** `tree ne null` (might need a cast to be type correct) */
def testNotNull(using Context): Tree = {
val receiver = if (tree.tpe.isBottomType)
// If the receiver is of type `Nothing` or `Null`, add an ascription so that the selection
// succeeds: e.g. `null.ne(null)` doesn't type, but `(null: AnyRef).ne(null)` does.
Typed(tree, TypeTree(defn.AnyRefType))
else tree.ensureConforms(defn.ObjectType)
receiver.select(defn.Object_ne).appliedTo(nullLiteral).withSpan(tree.span)
// If the receiver is of type `Nothing` or `Null`, add an ascription or cast
// so that the selection succeeds.
// e.g. `null.ne(null)` doesn't type, but `(null: AnyRef).ne(null)` does.
val receiver =
if tree.tpe.isBottomType then
if ctx.explicitNulls then tree.cast(defn.AnyRefType)
else Typed(tree, TypeTree(defn.AnyRefType))
else tree.ensureConforms(defn.ObjectType)
// also need to cast the null literal to AnyRef in explicit nulls
val nullLit = if ctx.explicitNulls then nullLiteral.cast(defn.AnyRefType) else nullLiteral
receiver.select(defn.Object_ne).appliedTo(nullLit).withSpan(tree.span)
}

/** If inititializer tree is `_`, the default value of its type,
Expand Down
10 changes: 9 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/Erasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class Erasure extends Phase with DenotTransformer {
override def checkPostCondition(tree: tpd.Tree)(using Context): Unit = {
assertErased(tree)
tree match {
case _: tpd.Import => assert(false, i"illegal tree: $tree")
case res: tpd.This =>
assert(!ExplicitOuter.referencesOuter(ctx.owner.lexicallyEnclosingClass, res),
i"Reference to $res from ${ctx.owner.showLocated}")
Expand Down Expand Up @@ -1034,14 +1035,21 @@ object Erasure {
typed(tree.arg, pt)

override def typedStats(stats: List[untpd.Tree], exprOwner: Symbol)(using Context): (List[Tree], Context) = {
val stats0 = addRetainedInlineBodies(stats)(using preErasureCtx)
// discard Imports first, since Bridges will use tree's symbol
val stats0 = addRetainedInlineBodies(stats.filter(!_.isInstanceOf[untpd.Import]))(using preErasureCtx)
val stats1 =
if (takesBridges(ctx.owner)) new Bridges(ctx.owner.asClass, erasurePhase).add(stats0)
else stats0
val (stats2, finalCtx) = super.typedStats(stats1, exprOwner)
(stats2.filterConserve(!_.isEmpty), finalCtx)
}

/** Finally drops all (language-) imports in erasure.
* Since some of the language imports change the subtyping,
* we cannot check the trees before erasure.
*/
override def typedImport(tree: untpd.Import)(using Context) = EmptyTree

override def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree =
trace(i"adapting ${tree.showSummary()}: ${tree.tpe} to $pt", show = true) {
if ctx.phase != erasurePhase && ctx.phase != erasurePhase.next then
Expand Down
10 changes: 9 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/MixinOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Symbols._, Types._, Contexts._, DenotTransformers._, Flags._
import util.Spans._
import SymUtils._
import StdNames._, NameOps._
import typer.Nullables

class MixinOps(cls: ClassSymbol, thisPhase: DenotTransformer)(using Context) {
import ast.tpd._
Expand Down Expand Up @@ -80,13 +81,20 @@ class MixinOps(cls: ClassSymbol, thisPhase: DenotTransformer)(using Context) {
prefss =>
val (targs, vargss) = splitArgs(prefss)
val tapp = superRef(target).appliedToTypeTrees(targs)
vargss match
val rhs = vargss match
case Nil | List(Nil) =>
// Overriding is somewhat loose about `()T` vs `=> T`, so just pick
// whichever makes sense for `target`
tapp.ensureApplied
case _ =>
tapp.appliedToArgss(vargss)
if ctx.explicitNulls && target.is(JavaDefined) && !ctx.phase.erasedTypes then
// We may forward to a super Java member in resolveSuper phase.
// Since this is still before erasure, the type can be nullable
// and causes error during checking. So we need to enable
// unsafe-nulls to construct the rhs.
Block(Nullables.importUnsafeNulls :: Nil, rhs)
else rhs

private def competingMethodsIterator(meth: Symbol): Iterator[Symbol] =
cls.baseClasses.iterator
Expand Down
9 changes: 0 additions & 9 deletions compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import Decorators.*
* The phase also replaces all expressions that appear in an erased context by
* default values. This is necessary so that subsequent checking phases such
* as IsInstanceOfChecker don't give false negatives.
* Finally, the phase drops (language-) imports.
*/
class PruneErasedDefs extends MiniPhase with SymTransformer { thisTransform =>
import tpd._
Expand Down Expand Up @@ -56,18 +55,10 @@ class PruneErasedDefs extends MiniPhase with SymTransformer { thisTransform =>
checkErasedInExperimental(tree.symbol)
tree

override def transformOther(tree: Tree)(using Context): Tree = tree match
case tree: Import => EmptyTree
case _ => tree

def checkErasedInExperimental(sym: Symbol)(using Context): Unit =
// Make an exception for Scala 2 experimental macros to allow dual Scala 2/3 macros under non experimental mode
if sym.is(Erased, butNot = Macro) && sym != defn.Compiletime_erasedValue && !sym.isInExperimentalScope then
Feature.checkExperimentalFeature("erased", sym.sourcePos)

override def checkPostCondition(tree: Tree)(using Context): Unit = tree match
case _: tpd.Import => assert(false, i"illegal tree: $tree")
case _ =>
}

object PruneErasedDefs {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ class SyntheticMembers(thisPhase: DenotTransformer) {

def nameRef: Tree =
if isJavaEnumValue then
Select(This(clazz), nme.name).ensureApplied
val name = Select(This(clazz), nme.name).ensureApplied
if ctx.explicitNulls then name.cast(defn.StringType) else name
else
identifierRef

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class TreeChecker extends Phase with SymTransformer {

val checkingCtx = ctx
.fresh
.setMode(Mode.ImplicitsEnabled)
.addMode(Mode.ImplicitsEnabled)
.setReporter(new ThrowingReporter(ctx.reporter))

val checker = inContext(ctx) {
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ import ast.Trees.mods
object Nullables:
import ast.tpd._

def importUnsafeNulls(using Context): Import = Import(
ref(defn.LanguageModule),
List(untpd.ImportSelector(untpd.Ident(nme.unsafeNulls), EmptyTree, EmptyTree)))

inline def unsafeNullsEnabled(using Context): Boolean =
ctx.explicitNulls && !ctx.mode.is(Mode.SafeNulls)

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/ReTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class ReTyper(nestingLevel: Int = 0) extends Typer(nestingLevel) with ReChecking
override def typedSuper(tree: untpd.Super, pt: Type)(using Context): Tree =
promote(tree)

override def typedImport(tree: untpd.Import, sym: Symbol)(using Context): Tree =
override def typedImport(tree: untpd.Import)(using Context): Tree =
promote(tree)

override def typedTyped(tree: untpd.Typed, pt: Type)(using Context): Tree = {
Expand Down
18 changes: 15 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,18 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val (stats1, exprCtx) = withoutMode(Mode.Pattern) {
typedBlockStats(tree.stats)
}
val expr1 = typedExpr(tree.expr, pt.dropIfProto)(using exprCtx)
var expr1 = typedExpr(tree.expr, pt.dropIfProto)(using exprCtx)

// If unsafe nulls is enabled inside a block but not enabled outside
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment does not correspond to the code. exprCtx is the context inside the block valid for the final expression. ctx is the context outside the block. The cast is performed if unsafeNulls is disabled inside the block but enabled around it. Is this what's intended? Either the comment or the code has to be fixed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the comment is correct: !exprCtx.mode.is(Mode.SafeNulls) indicates SafeNulls is disabled inside the block, which means unsafeNulls is enabled here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The confusion is that the mode is SafeNulls which is a negation of unsafeNulls. I think the comment correctly corresponds to the code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, my bad. I misread the code.

// and the type does not conform the expected type without unsafe nulls,
// we will cast the last expression to the expected type.
// See: tests/explicit-nulls/pos/unsafe-block.scala
if ctx.mode.is(Mode.SafeNulls)
&& !exprCtx.mode.is(Mode.SafeNulls)
&& pt.isValueType
&& !inContext(exprCtx.addMode(Mode.SafeNulls))(expr1.tpe <:< pt) then
expr1 = expr1.cast(pt)

ensureNoLocalRefs(
cpy.Block(tree)(stats1, expr1)
.withType(expr1.tpe)
Expand Down Expand Up @@ -2602,7 +2613,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
|The selector is not a member of an object or package.""")
else typd(imp.expr, AnySelectionProto)

def typedImport(imp: untpd.Import, sym: Symbol)(using Context): Tree =
def typedImport(imp: untpd.Import)(using Context): Tree =
val sym = retrieveSym(imp)
val expr1 = typedImportQualifier(imp, typedExpr(_, _)(using ctx.withOwner(sym)))
checkLegalImportPath(expr1)
val selectors1 = typedSelectors(imp.selectors)
Expand Down Expand Up @@ -2868,7 +2880,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case tree: untpd.If => typedIf(tree, pt)
case tree: untpd.Function => typedFunction(tree, pt)
case tree: untpd.Closure => typedClosure(tree, pt)
case tree: untpd.Import => typedImport(tree, retrieveSym(tree))
case tree: untpd.Import => typedImport(tree)
case tree: untpd.Export => typedExport(tree)
case tree: untpd.Match => typedMatch(tree, pt)
case tree: untpd.Return => typedReturn(tree)
Expand Down
12 changes: 12 additions & 0 deletions tests/explicit-nulls/pos/enums.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMessageID]:

case NoExplanationID // errorNumber: -1
case EmptyCatchOrFinallyBlockID extends ErrorMessageID(isActive = false) // errorNumber: 0

def errorNumber = ordinal - 1

enum Color(val rgb: Int):
case Red extends Color(0xFF0000)
case Green extends Color(0x00FF00)
case Blue extends Color(0x0000FF)

5 changes: 5 additions & 0 deletions tests/explicit-nulls/pos/test-not-null.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// testNotNull can be inserted during PatternMatcher
def f(xs: List[String]) =
xs.zipWithIndex.collect {
case (arg, idx) => idx
}
67 changes: 67 additions & 0 deletions tests/explicit-nulls/pos/unsafe-block.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
def trim(x: String | Null): String =
import scala.language.unsafeNulls
// The type of `x.trim()` is `String | Null`.
// Although `String | Null` conforms the expected type `String`,
// we still need to cast the expression to the expected type here,
// because outside the scope we don't have `unsafeNulls` anymore.
x.trim()

class TestDefs:

def f1: String | Null = null
def f2: Array[String | Null] | Null = null
def f3: Array[String] | Null = null

def h1a: String =
import scala.language.unsafeNulls
f1

def h1b: String | Null =
import scala.language.unsafeNulls
f1

def h2a: Array[String] =
import scala.language.unsafeNulls
f2

def h2b: Array[String | Null] =
import scala.language.unsafeNulls
f2

def h3a: Array[String] =
import scala.language.unsafeNulls
f3

def h3b: Array[String | Null] =
import scala.language.unsafeNulls
f3

class TestVals:

val f1: String | Null = null
val f2: Array[String | Null] | Null = null
val f3: Array[String] | Null = null

val h1a: String =
import scala.language.unsafeNulls
f1

val h1b: String | Null =
import scala.language.unsafeNulls
f1

val h2a: Array[String] =
import scala.language.unsafeNulls
f2

val h2b: Array[String | Null] =
import scala.language.unsafeNulls
f2

val h3a: Array[String] =
import scala.language.unsafeNulls
f3

val h3b: Array[String | Null] =
import scala.language.unsafeNulls
f3
21 changes: 16 additions & 5 deletions tests/explicit-nulls/pos/unsafe-chain.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
import java.nio.file.FileSystems
import java.util.ArrayList

def directorySeparator: String =
import scala.language.unsafeNulls
FileSystems.getDefault().getSeparator()
class A:

def directorySeparator: String =
import scala.language.unsafeNulls
FileSystems.getDefault().getSeparator()

def getFirstOfFirst(xs: ArrayList[ArrayList[ArrayList[String]]]): String =
import scala.language.unsafeNulls
xs.get(0).get(0).get(0)

def getFirstOfFirst(xs: ArrayList[ArrayList[ArrayList[String]]]): String =
class B:
import scala.language.unsafeNulls
xs.get(0).get(0).get(0)

def directorySeparator: String =
FileSystems.getDefault().getSeparator()

def getFirstOfFirst(xs: ArrayList[ArrayList[ArrayList[String]]]): String =
xs.get(0).get(0).get(0)