diff --git a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala index c76dee520331..9d6c3020d406 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala @@ -6,15 +6,18 @@ import core.* import Contexts.* import Symbols.* import Types.* +import Denotations.Denotation import StdNames.* +import Names.TermName import NameKinds.OuterSelectName import NameKinds.SuperAccessorName import ast.tpd.* -import util.SourcePosition +import util.{ SourcePosition, NoSourcePosition } import config.Printers.init as printer import reporting.StoreReporter import reporting.trace as log +import typer.Applications.* import Errors.* import Trace.* @@ -249,7 +252,7 @@ object Objects: val joinedTrace = data.pendingTraces.slice(index + 1, data.checkingObjects.size).foldLeft(pendingTrace) { (a, acc) => acc ++ a } val callTrace = Trace.buildStacktrace(joinedTrace, "Calling trace:\n") val cycle = data.checkingObjects.slice(index, data.checkingObjects.size) - val pos = clazz.defTree + val pos = clazz.defTree.sourcePos.focus report.warning("Cyclic initialization: " + cycle.map(_.klass.show).mkString(" -> ") + " -> " + clazz.show + ". " + callTrace, pos) end if data.checkingObjects(index) @@ -834,11 +837,10 @@ object Objects: /** Handle local variable definition, `val x = e` or `var x = e`. * - * @param ref The value for `this` where the variable is defined. * @param sym The symbol of the variable. * @param value The value of the initializer. */ - def initLocal(ref: Ref, sym: Symbol, value: Value): Contextual[Unit] = log("initialize local " + sym.show + " with " + value.show, printer) { + def initLocal(sym: Symbol, value: Value): Contextual[Unit] = log("initialize local " + sym.show + " with " + value.show, printer) { if sym.is(Flags.Mutable) then val addr = Heap.localVarAddr(summon[Regions.Data], sym, State.currentObject) Env.setLocalVar(sym, addr) @@ -870,9 +872,6 @@ object Objects: case _ => report.warning("[Internal error] Variable not found " + sym.show + "\nenv = " + env.show + ". Calling trace:\n" + Trace.show, Trace.position) Bottom - else if sym.isPatternBound then - // TODO: handle patterns - Cold else given Env.Data = env // Assume forward reference check is doing a good job @@ -1113,11 +1112,9 @@ object Objects: else eval(arg, thisV, klass) - case Match(selector, cases) => - eval(selector, thisV, klass) - // TODO: handle pattern match properly - report.warning("[initChecker] Pattern match is skipped. Trace:\n" + Trace.show, expr) - Bottom + case Match(scrutinee, cases) => + val scrutineeValue = eval(scrutinee, thisV, klass) + patternMatch(scrutineeValue, cases, thisV, klass) case Return(expr, from) => Returns.handle(from.symbol, eval(expr, thisV, klass)) @@ -1151,7 +1148,7 @@ object Objects: // local val definition val rhs = eval(vdef.rhs, thisV, klass) val sym = vdef.symbol - initLocal(thisV.asInstanceOf[Ref], vdef.symbol, rhs) + initLocal(vdef.symbol, rhs) Bottom case ddef : DefDef => @@ -1173,6 +1170,196 @@ object Objects: Bottom } + /** Evaluate the cases against the scrutinee value. + * + * It returns the scrutinee in most cases. The main effect of the function is for its side effects of adding bindings + * to the environment. + * + * See https://docs.scala-lang.org/scala3/reference/changed-features/pattern-matching.html + * + * @param scrutinee The abstract value of the scrutinee. + * @param cases The cases to match. + * @param thisV The value for `C.this` where `C` is represented by `klass`. + * @param klass The enclosing class where the type `tp` is located. + */ + def patternMatch(scrutinee: Value, cases: List[CaseDef], thisV: Value, klass: ClassSymbol): Contextual[Value] = + // expected member types for `unapplySeq` + def lengthType = ExprType(defn.IntType) + def lengthCompareType = MethodType(List(defn.IntType), defn.IntType) + def applyType(elemTp: Type) = MethodType(List(defn.IntType), elemTp) + def dropType(elemTp: Type) = MethodType(List(defn.IntType), defn.CollectionSeqType.appliedTo(elemTp)) + def toSeqType(elemTp: Type) = ExprType(defn.CollectionSeqType.appliedTo(elemTp)) + + def getMemberMethod(receiver: Type, name: TermName, tp: Type): Denotation = + receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp) + + def evalCase(caseDef: CaseDef): Value = + evalPattern(scrutinee, caseDef.pat) + eval(caseDef.guard, thisV, klass) + eval(caseDef.body, thisV, klass) + + /** Abstract evaluation of patterns. + * + * It augments the local environment for bound pattern variables. As symbols are globally + * unique, we can put them in a single environment. + * + * Currently, we assume all cases are reachable, thus all patterns are assumed to match. + */ + def evalPattern(scrutinee: Value, pat: Tree): Value = log("match " + scrutinee.show + " against " + pat.show, printer, (_: Value).show): + val trace2 = Trace.trace.add(pat) + pat match + case Alternative(pats) => + for pat <- pats do evalPattern(scrutinee, pat) + scrutinee + + case bind @ Bind(_, pat) => + val value = evalPattern(scrutinee, pat) + initLocal(bind.symbol, value) + scrutinee + + case UnApply(fun, implicits, pats) => + given Trace = trace2 + + val fun1 = funPart(fun) + val funRef = fun1.tpe.asInstanceOf[TermRef] + val unapplyResTp = funRef.widen.finalResultType + + val receiver = fun1 match + case ident: Ident => + evalType(funRef.prefix, thisV, klass) + case select: Select => + eval(select.qualifier, thisV, klass) + + val implicitValues = evalArgs(implicits.map(Arg.apply), thisV, klass) + // TODO: implicit values may appear before and/or after the scrutinee parameter. + val unapplyRes = call(receiver, funRef.symbol, TraceValue(scrutinee, summon[Trace]) :: implicitValues, funRef.prefix, superType = NoType, needResolve = true) + + if fun.symbol.name == nme.unapplySeq then + var resultTp = unapplyResTp + var elemTp = unapplySeqTypeElemTp(resultTp) + var arity = productArity(resultTp, NoSourcePosition) + var needsGet = false + if (!elemTp.exists && arity <= 0) { + needsGet = true + resultTp = resultTp.select(nme.get).finalResultType + elemTp = unapplySeqTypeElemTp(resultTp.widen) + arity = productSelectorTypes(resultTp, NoSourcePosition).size + } + + var resToMatch = unapplyRes + + if needsGet then + // Get match + val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless) + call(unapplyRes, isEmptyDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true) + + val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless) + resToMatch = call(unapplyRes, getDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true) + end if + + if elemTp.exists then + // sequence match + evalSeqPatterns(resToMatch, resultTp, elemTp, pats) + else + // product sequence match + val selectors = productSelectors(resultTp) + assert(selectors.length <= pats.length) + selectors.init.zip(pats).map { (sel, pat) => + val selectRes = call(resToMatch, sel, Nil, resultTp, superType = NoType, needResolve = true) + evalPattern(selectRes, pat) + } + val seqPats = pats.drop(selectors.length - 1) + val toSeqRes = call(resToMatch, selectors.last, Nil, resultTp, superType = NoType, needResolve = true) + val toSeqResTp = resultTp.memberInfo(selectors.last).finalResultType + evalSeqPatterns(toSeqRes, toSeqResTp, elemTp, seqPats) + end if + + else + // distribute unapply to patterns + if isProductMatch(unapplyResTp, pats.length) then + // product match + val selectors = productSelectors(unapplyResTp) + assert(selectors.length == pats.length) + selectors.zip(pats).map { (sel, pat) => + val selectRes = call(unapplyRes, sel, Nil, unapplyResTp, superType = NoType, needResolve = true) + evalPattern(selectRes, pat) + } + else if unapplyResTp <:< defn.BooleanType then + // Boolean extractor, do nothing + () + else + // Get match + val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless) + call(unapplyRes, isEmptyDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true) + + val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless) + val getRes = call(unapplyRes, getDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true) + if pats.length == 1 then + // single match + evalPattern(getRes, pats.head) + else + val getResTp = getDenot.info.finalResultType + val selectors = productSelectors(getResTp).take(pats.length) + selectors.zip(pats).map { (sel, pat) => + val selectRes = call(unapplyRes, sel, Nil, getResTp, superType = NoType, needResolve = true) + evalPattern(selectRes, pat) + } + end if + end if + end if + scrutinee + + case Ident(nme.WILDCARD) | Ident(nme.WILDCARD_STAR) => + scrutinee + + case Typed(pat, _) => + evalPattern(scrutinee, pat) + + case tree => + // For all other trees, the semantics is normal. + eval(tree, thisV, klass) + + end evalPattern + + /** + * Evaluate a sequence value against sequence patterns. + */ + def evalSeqPatterns(scrutinee: Value, scrutineeType: Type, elemType: Type, pats: List[Tree])(using Trace): Unit = + // call .lengthCompare or .length + val lengthCompareDenot = getMemberMethod(scrutineeType, nme.lengthCompare, lengthCompareType) + if lengthCompareDenot.exists then + call(scrutinee, lengthCompareDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true) + else + val lengthDenot = getMemberMethod(scrutineeType, nme.length, lengthType) + call(scrutinee, lengthDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true) + end if + + // call .apply + val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType)) + val applyRes = call(scrutinee, applyDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true) + + if isWildcardStarArg(pats.last) then + if pats.size == 1 then + // call .toSeq + val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless) + val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true) + evalPattern(toSeqRes, pats.head) + else + // call .drop + val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType)) + val dropRes = call(scrutinee, dropDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true) + for pat <- pats.init do evalPattern(applyRes, pat) + evalPattern(dropRes, pats.last) + end if + else + // no patterns like `xs*` + for pat <- pats do evalPattern(applyRes, pat) + end evalSeqPatterns + + + cases.map(evalCase).join + end patternMatch + /** Handle semantics of leaf nodes * * For leaf nodes, their semantics is determined by their types. @@ -1231,7 +1418,7 @@ object Objects: resolveThis(tref.classSymbol.asClass, thisV, klass) case _ => - throw new Exception("unexpected type: " + tp) + throw new Exception("unexpected type: " + tp + ", Trace:\n" + Trace.show) } /** Evaluate arguments of methods and constructors */ diff --git a/compiler/src/dotty/tools/dotc/transform/init/Trace.scala b/compiler/src/dotty/tools/dotc/transform/init/Trace.scala index 7dfbc0b6cfa5..7f3208dae952 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Trace.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Trace.scala @@ -49,7 +49,12 @@ object Trace: val code = SyntaxHighlighting.highlight(pos.lineContent.trim.nn) i"$code\t$loc" else - tree.show + tree match + case defDef: DefTree => + // The definition can be huge, avoid printing the whole definition. + defDef.symbol.show + case _ => + tree.show val positionMarkerLine = if pos.exists && pos.source.exists then positionMarker(pos) diff --git a/compiler/src/dotty/tools/dotc/transform/init/Util.scala b/compiler/src/dotty/tools/dotc/transform/init/Util.scala index ad7d2afffbaf..5fcc65bcce10 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Util.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Util.scala @@ -26,6 +26,9 @@ object Util: opaque type Arg = Tree | ByNameArg case class ByNameArg(tree: Tree) + object Arg: + def apply(tree: Tree): Arg = tree + extension (arg: Arg) def isByName = arg.isInstanceOf[ByNameArg] def tree: Tree = arg match diff --git a/tests/init-global/neg/global-cycle1.check b/tests/init-global/neg/global-cycle1.check index e0125630c077..faed741761a7 100644 --- a/tests/init-global/neg/global-cycle1.check +++ b/tests/init-global/neg/global-cycle1.check @@ -1,17 +1,15 @@ -- Error: tests/init-global/neg/global-cycle1.scala:1:7 ---------------------------------------------------------------- 1 |object A { // error - |^ - |Cyclic initialization: object A -> object B -> object A. Calling trace: - |-> object A { // error [ global-cycle1.scala:1 ] - | ^ - |-> val a: Int = B.b [ global-cycle1.scala:2 ] - | ^ - |-> object B { [ global-cycle1.scala:5 ] - | ^ - |-> val b: Int = A.a // error [ global-cycle1.scala:6 ] - | ^ -2 | val a: Int = B.b -3 |} + | ^ + | Cyclic initialization: object A -> object B -> object A. Calling trace: + | -> object A { // error [ global-cycle1.scala:1 ] + | ^ + | -> val a: Int = B.b [ global-cycle1.scala:2 ] + | ^ + | -> object B { [ global-cycle1.scala:5 ] + | ^ + | -> val b: Int = A.a // error [ global-cycle1.scala:6 ] + | ^ -- Error: tests/init-global/neg/global-cycle1.scala:6:17 --------------------------------------------------------------- 6 | val b: Int = A.a // error | ^^^ diff --git a/tests/init-global/neg/global-cycle6.scala b/tests/init-global/neg/global-cycle6.scala index 2d4f23c25187..36e3ab0b6a94 100644 --- a/tests/init-global/neg/global-cycle6.scala +++ b/tests/init-global/neg/global-cycle6.scala @@ -1,7 +1,7 @@ object A { // error val n: Int = B.m class Inner { - println(n) + println(n) // error } } diff --git a/tests/init-global/neg/patmat-unapplySeq.check b/tests/init-global/neg/patmat-unapplySeq.check new file mode 100644 index 000000000000..74093b029614 --- /dev/null +++ b/tests/init-global/neg/patmat-unapplySeq.check @@ -0,0 +1,11 @@ +-- Error: tests/init-global/neg/patmat-unapplySeq.scala:8:32 ----------------------------------------------------------- +8 | def apply(i: Int): Box = array(i) // error + | ^^^^^^^^ + |Reading mutable state of object A during initialization of object B. + |Reading mutable state of other static objects is forbidden as it breaks initialization-time irrelevance. Calling trace: + |-> object B: [ patmat-unapplySeq.scala:15 ] + | ^ + |-> case A(b) => [ patmat-unapplySeq.scala:17 ] + | ^^^^ + |-> def apply(i: Int): Box = array(i) // error [ patmat-unapplySeq.scala:8 ] + | ^^^^^^^^ diff --git a/tests/init-global/neg/patmat-unapplySeq.scala b/tests/init-global/neg/patmat-unapplySeq.scala new file mode 100644 index 000000000000..81c853a6e19f --- /dev/null +++ b/tests/init-global/neg/patmat-unapplySeq.scala @@ -0,0 +1,17 @@ +object A: + class Box(var x: Int) + + val array: Array[Box] = new Array(1) + array(0) = new Box(10) + + def length: Int = array.length + def apply(i: Int): Box = array(i) // error + def drop(n: Int): Seq[Box] = array.toSeq + def toSeq: Seq[Box] = array.toSeq + + def unapplySeq(array: Array[Box]): A.type = this + + +object B: + A.array match + case A(b) => diff --git a/tests/init-global/neg/patmat-unapplySeq2.scala b/tests/init-global/neg/patmat-unapplySeq2.scala new file mode 100644 index 000000000000..adab9495db49 --- /dev/null +++ b/tests/init-global/neg/patmat-unapplySeq2.scala @@ -0,0 +1,17 @@ +object A: + class Box(var x: Int) + + val array: Array[Box] = new Array(1) + array(0) = new Box(10) + + def length: Int = array.length + def apply(i: Int): Box = array(i) // error + def drop(n: Int): Seq[Box] = array.toSeq + def toSeq: Seq[Box] = array.toSeq + + def unapplySeq(array: Array[Box]): A.type = this + + +object B: + A.array match + case A(b*) => diff --git a/tests/init-global/neg/patmat.scala b/tests/init-global/neg/patmat.scala new file mode 100644 index 000000000000..126e66e7cf7b --- /dev/null +++ b/tests/init-global/neg/patmat.scala @@ -0,0 +1,36 @@ +object A: // error + val a: Option[Int] = Some(3) + a match + case Some(x) => println(x * 2 + B.a.size) + case None => println(0) + +object B: + val a = 3 :: 4 :: Nil + a match + case x :: xs => + println(x * 2) + if A.a.isEmpty then println(xs.size) + case Nil => + println(0) + +case class Box[T](value: T) +case class Holder[T](value: T) +object C: + (Box(5): Box[Int] | Holder[Int]) match + case Box(x) => x + case Holder(x) => x + + (Box(5): Box[Int] | Holder[Int]) match + case box: Box[Int] => box.value + case holder: Holder[Int] => holder.value + + val a: Int = Inner.b + + object Inner: // error + val b: Int = 10 + + val foo: () => Int = () => C.a + + (Box(foo): Box[() => Int] | Holder[Int]) match + case Box(f) => f() + case Holder(x) => x diff --git a/tests/init-global/neg/t9115.scala b/tests/init-global/neg/t9115.scala index e7cfe09e560c..a3020c6939a8 100644 --- a/tests/init-global/neg/t9115.scala +++ b/tests/init-global/neg/t9115.scala @@ -1,4 +1,4 @@ -object D { +object D { // error def aaa = 1 //that’s the reason class Z (depends: Any) case object D1 extends Z(aaa) // 'null' when calling D.D1 first time // error diff --git a/tests/init-global/neg/t9312.scala b/tests/init-global/neg/t9312.scala index 703cf67e05c4..d88093a2f67a 100644 --- a/tests/init-global/neg/t9312.scala +++ b/tests/init-global/neg/t9312.scala @@ -8,7 +8,7 @@ object DeadLockTest { } - object Parent { + object Parent { // error trait Child { Thread.sleep(2000) // ensure concurrent behavior val parent = Parent diff --git a/tests/init-global/neg/t9360.scala b/tests/init-global/neg/t9360.scala index 291c4dd05db1..2ec0c740d739 100644 --- a/tests/init-global/neg/t9360.scala +++ b/tests/init-global/neg/t9360.scala @@ -2,7 +2,7 @@ class BaseClass(s: String) { def print: Unit = () } -object Obj { +object Obj { // error val s: String = "hello" object AObj extends BaseClass(s) // error diff --git a/tests/init-global/pos/patmat-interpolator.scala b/tests/init-global/pos/patmat-interpolator.scala new file mode 100644 index 000000000000..2df74326b77a --- /dev/null +++ b/tests/init-global/pos/patmat-interpolator.scala @@ -0,0 +1,3 @@ +object Test: + val RootPackage = "_root_/" + val s"${RootPackageName @ _}/" = RootPackage: @unchecked diff --git a/tests/init-global/pos/patmat.scala b/tests/init-global/pos/patmat.scala new file mode 100644 index 000000000000..72a00f373e75 --- /dev/null +++ b/tests/init-global/pos/patmat.scala @@ -0,0 +1,14 @@ +object A: + val a: Option[Int] = Some(3) + a match + case Some(x) => println(x * 2) + case None => println(0) + +object B: + val a = 3 :: 4 :: Nil + a match + case x :: xs => + println(x * 2) + println(xs.size) + case Nil => + println(0)