Skip to content

Commit 04eae14

Browse files
authored
Pattern match support in checking global objects (#18127)
Pattern match in checking global objects
2 parents ca29cdc + 4cfcacf commit 04eae14

14 files changed

+322
-31
lines changed

compiler/src/dotty/tools/dotc/transform/init/Objects.scala

Lines changed: 201 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,18 @@ import core.*
66
import Contexts.*
77
import Symbols.*
88
import Types.*
9+
import Denotations.Denotation
910
import StdNames.*
11+
import Names.TermName
1012
import NameKinds.OuterSelectName
1113
import NameKinds.SuperAccessorName
1214

1315
import ast.tpd.*
14-
import util.SourcePosition
16+
import util.{ SourcePosition, NoSourcePosition }
1517
import config.Printers.init as printer
1618
import reporting.StoreReporter
1719
import reporting.trace as log
20+
import typer.Applications.*
1821

1922
import Errors.*
2023
import Trace.*
@@ -249,7 +252,7 @@ object Objects:
249252
val joinedTrace = data.pendingTraces.slice(index + 1, data.checkingObjects.size).foldLeft(pendingTrace) { (a, acc) => acc ++ a }
250253
val callTrace = Trace.buildStacktrace(joinedTrace, "Calling trace:\n")
251254
val cycle = data.checkingObjects.slice(index, data.checkingObjects.size)
252-
val pos = clazz.defTree
255+
val pos = clazz.defTree.sourcePos.focus
253256
report.warning("Cyclic initialization: " + cycle.map(_.klass.show).mkString(" -> ") + " -> " + clazz.show + ". " + callTrace, pos)
254257
end if
255258
data.checkingObjects(index)
@@ -834,11 +837,10 @@ object Objects:
834837

835838
/** Handle local variable definition, `val x = e` or `var x = e`.
836839
*
837-
* @param ref The value for `this` where the variable is defined.
838840
* @param sym The symbol of the variable.
839841
* @param value The value of the initializer.
840842
*/
841-
def initLocal(ref: Ref, sym: Symbol, value: Value): Contextual[Unit] = log("initialize local " + sym.show + " with " + value.show, printer) {
843+
def initLocal(sym: Symbol, value: Value): Contextual[Unit] = log("initialize local " + sym.show + " with " + value.show, printer) {
842844
if sym.is(Flags.Mutable) then
843845
val addr = Heap.localVarAddr(summon[Regions.Data], sym, State.currentObject)
844846
Env.setLocalVar(sym, addr)
@@ -870,9 +872,6 @@ object Objects:
870872
case _ =>
871873
report.warning("[Internal error] Variable not found " + sym.show + "\nenv = " + env.show + ". Calling trace:\n" + Trace.show, Trace.position)
872874
Bottom
873-
else if sym.isPatternBound then
874-
// TODO: handle patterns
875-
Cold
876875
else
877876
given Env.Data = env
878877
// Assume forward reference check is doing a good job
@@ -1113,11 +1112,9 @@ object Objects:
11131112
else
11141113
eval(arg, thisV, klass)
11151114

1116-
case Match(selector, cases) =>
1117-
eval(selector, thisV, klass)
1118-
// TODO: handle pattern match properly
1119-
report.warning("[initChecker] Pattern match is skipped. Trace:\n" + Trace.show, expr)
1120-
Bottom
1115+
case Match(scrutinee, cases) =>
1116+
val scrutineeValue = eval(scrutinee, thisV, klass)
1117+
patternMatch(scrutineeValue, cases, thisV, klass)
11211118

11221119
case Return(expr, from) =>
11231120
Returns.handle(from.symbol, eval(expr, thisV, klass))
@@ -1151,7 +1148,7 @@ object Objects:
11511148
// local val definition
11521149
val rhs = eval(vdef.rhs, thisV, klass)
11531150
val sym = vdef.symbol
1154-
initLocal(thisV.asInstanceOf[Ref], vdef.symbol, rhs)
1151+
initLocal(vdef.symbol, rhs)
11551152
Bottom
11561153

11571154
case ddef : DefDef =>
@@ -1173,6 +1170,196 @@ object Objects:
11731170
Bottom
11741171
}
11751172

1173+
/** Evaluate the cases against the scrutinee value.
1174+
*
1175+
* It returns the scrutinee in most cases. The main effect of the function is for its side effects of adding bindings
1176+
* to the environment.
1177+
*
1178+
* See https://docs.scala-lang.org/scala3/reference/changed-features/pattern-matching.html
1179+
*
1180+
* @param scrutinee The abstract value of the scrutinee.
1181+
* @param cases The cases to match.
1182+
* @param thisV The value for `C.this` where `C` is represented by `klass`.
1183+
* @param klass The enclosing class where the type `tp` is located.
1184+
*/
1185+
def patternMatch(scrutinee: Value, cases: List[CaseDef], thisV: Value, klass: ClassSymbol): Contextual[Value] =
1186+
// expected member types for `unapplySeq`
1187+
def lengthType = ExprType(defn.IntType)
1188+
def lengthCompareType = MethodType(List(defn.IntType), defn.IntType)
1189+
def applyType(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
1190+
def dropType(elemTp: Type) = MethodType(List(defn.IntType), defn.CollectionSeqType.appliedTo(elemTp))
1191+
def toSeqType(elemTp: Type) = ExprType(defn.CollectionSeqType.appliedTo(elemTp))
1192+
1193+
def getMemberMethod(receiver: Type, name: TermName, tp: Type): Denotation =
1194+
receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)
1195+
1196+
def evalCase(caseDef: CaseDef): Value =
1197+
evalPattern(scrutinee, caseDef.pat)
1198+
eval(caseDef.guard, thisV, klass)
1199+
eval(caseDef.body, thisV, klass)
1200+
1201+
/** Abstract evaluation of patterns.
1202+
*
1203+
* It augments the local environment for bound pattern variables. As symbols are globally
1204+
* unique, we can put them in a single environment.
1205+
*
1206+
* Currently, we assume all cases are reachable, thus all patterns are assumed to match.
1207+
*/
1208+
def evalPattern(scrutinee: Value, pat: Tree): Value = log("match " + scrutinee.show + " against " + pat.show, printer, (_: Value).show):
1209+
val trace2 = Trace.trace.add(pat)
1210+
pat match
1211+
case Alternative(pats) =>
1212+
for pat <- pats do evalPattern(scrutinee, pat)
1213+
scrutinee
1214+
1215+
case bind @ Bind(_, pat) =>
1216+
val value = evalPattern(scrutinee, pat)
1217+
initLocal(bind.symbol, value)
1218+
scrutinee
1219+
1220+
case UnApply(fun, implicits, pats) =>
1221+
given Trace = trace2
1222+
1223+
val fun1 = funPart(fun)
1224+
val funRef = fun1.tpe.asInstanceOf[TermRef]
1225+
val unapplyResTp = funRef.widen.finalResultType
1226+
1227+
val receiver = fun1 match
1228+
case ident: Ident =>
1229+
evalType(funRef.prefix, thisV, klass)
1230+
case select: Select =>
1231+
eval(select.qualifier, thisV, klass)
1232+
1233+
val implicitValues = evalArgs(implicits.map(Arg.apply), thisV, klass)
1234+
// TODO: implicit values may appear before and/or after the scrutinee parameter.
1235+
val unapplyRes = call(receiver, funRef.symbol, TraceValue(scrutinee, summon[Trace]) :: implicitValues, funRef.prefix, superType = NoType, needResolve = true)
1236+
1237+
if fun.symbol.name == nme.unapplySeq then
1238+
var resultTp = unapplyResTp
1239+
var elemTp = unapplySeqTypeElemTp(resultTp)
1240+
var arity = productArity(resultTp, NoSourcePosition)
1241+
var needsGet = false
1242+
if (!elemTp.exists && arity <= 0) {
1243+
needsGet = true
1244+
resultTp = resultTp.select(nme.get).finalResultType
1245+
elemTp = unapplySeqTypeElemTp(resultTp.widen)
1246+
arity = productSelectorTypes(resultTp, NoSourcePosition).size
1247+
}
1248+
1249+
var resToMatch = unapplyRes
1250+
1251+
if needsGet then
1252+
// Get match
1253+
val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
1254+
call(unapplyRes, isEmptyDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
1255+
1256+
val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
1257+
resToMatch = call(unapplyRes, getDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
1258+
end if
1259+
1260+
if elemTp.exists then
1261+
// sequence match
1262+
evalSeqPatterns(resToMatch, resultTp, elemTp, pats)
1263+
else
1264+
// product sequence match
1265+
val selectors = productSelectors(resultTp)
1266+
assert(selectors.length <= pats.length)
1267+
selectors.init.zip(pats).map { (sel, pat) =>
1268+
val selectRes = call(resToMatch, sel, Nil, resultTp, superType = NoType, needResolve = true)
1269+
evalPattern(selectRes, pat)
1270+
}
1271+
val seqPats = pats.drop(selectors.length - 1)
1272+
val toSeqRes = call(resToMatch, selectors.last, Nil, resultTp, superType = NoType, needResolve = true)
1273+
val toSeqResTp = resultTp.memberInfo(selectors.last).finalResultType
1274+
evalSeqPatterns(toSeqRes, toSeqResTp, elemTp, seqPats)
1275+
end if
1276+
1277+
else
1278+
// distribute unapply to patterns
1279+
if isProductMatch(unapplyResTp, pats.length) then
1280+
// product match
1281+
val selectors = productSelectors(unapplyResTp)
1282+
assert(selectors.length == pats.length)
1283+
selectors.zip(pats).map { (sel, pat) =>
1284+
val selectRes = call(unapplyRes, sel, Nil, unapplyResTp, superType = NoType, needResolve = true)
1285+
evalPattern(selectRes, pat)
1286+
}
1287+
else if unapplyResTp <:< defn.BooleanType then
1288+
// Boolean extractor, do nothing
1289+
()
1290+
else
1291+
// Get match
1292+
val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
1293+
call(unapplyRes, isEmptyDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
1294+
1295+
val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
1296+
val getRes = call(unapplyRes, getDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
1297+
if pats.length == 1 then
1298+
// single match
1299+
evalPattern(getRes, pats.head)
1300+
else
1301+
val getResTp = getDenot.info.finalResultType
1302+
val selectors = productSelectors(getResTp).take(pats.length)
1303+
selectors.zip(pats).map { (sel, pat) =>
1304+
val selectRes = call(unapplyRes, sel, Nil, getResTp, superType = NoType, needResolve = true)
1305+
evalPattern(selectRes, pat)
1306+
}
1307+
end if
1308+
end if
1309+
end if
1310+
scrutinee
1311+
1312+
case Ident(nme.WILDCARD) | Ident(nme.WILDCARD_STAR) =>
1313+
scrutinee
1314+
1315+
case Typed(pat, _) =>
1316+
evalPattern(scrutinee, pat)
1317+
1318+
case tree =>
1319+
// For all other trees, the semantics is normal.
1320+
eval(tree, thisV, klass)
1321+
1322+
end evalPattern
1323+
1324+
/**
1325+
* Evaluate a sequence value against sequence patterns.
1326+
*/
1327+
def evalSeqPatterns(scrutinee: Value, scrutineeType: Type, elemType: Type, pats: List[Tree])(using Trace): Unit =
1328+
// call .lengthCompare or .length
1329+
val lengthCompareDenot = getMemberMethod(scrutineeType, nme.lengthCompare, lengthCompareType)
1330+
if lengthCompareDenot.exists then
1331+
call(scrutinee, lengthCompareDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)
1332+
else
1333+
val lengthDenot = getMemberMethod(scrutineeType, nme.length, lengthType)
1334+
call(scrutinee, lengthDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
1335+
end if
1336+
1337+
// call .apply
1338+
val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType))
1339+
val applyRes = call(scrutinee, applyDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)
1340+
1341+
if isWildcardStarArg(pats.last) then
1342+
if pats.size == 1 then
1343+
// call .toSeq
1344+
val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
1345+
val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
1346+
evalPattern(toSeqRes, pats.head)
1347+
else
1348+
// call .drop
1349+
val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
1350+
val dropRes = call(scrutinee, dropDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)
1351+
for pat <- pats.init do evalPattern(applyRes, pat)
1352+
evalPattern(dropRes, pats.last)
1353+
end if
1354+
else
1355+
// no patterns like `xs*`
1356+
for pat <- pats do evalPattern(applyRes, pat)
1357+
end evalSeqPatterns
1358+
1359+
1360+
cases.map(evalCase).join
1361+
end patternMatch
1362+
11761363
/** Handle semantics of leaf nodes
11771364
*
11781365
* For leaf nodes, their semantics is determined by their types.
@@ -1231,7 +1418,7 @@ object Objects:
12311418
resolveThis(tref.classSymbol.asClass, thisV, klass)
12321419

12331420
case _ =>
1234-
throw new Exception("unexpected type: " + tp)
1421+
throw new Exception("unexpected type: " + tp + ", Trace:\n" + Trace.show)
12351422
}
12361423

12371424
/** Evaluate arguments of methods and constructors */

compiler/src/dotty/tools/dotc/transform/init/Trace.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,12 @@ object Trace:
4949
val code = SyntaxHighlighting.highlight(pos.lineContent.trim.nn)
5050
i"$code\t$loc"
5151
else
52-
tree.show
52+
tree match
53+
case defDef: DefTree =>
54+
// The definition can be huge, avoid printing the whole definition.
55+
defDef.symbol.show
56+
case _ =>
57+
tree.show
5358
val positionMarkerLine =
5459
if pos.exists && pos.source.exists then
5560
positionMarker(pos)

compiler/src/dotty/tools/dotc/transform/init/Util.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ object Util:
2626
opaque type Arg = Tree | ByNameArg
2727
case class ByNameArg(tree: Tree)
2828

29+
object Arg:
30+
def apply(tree: Tree): Arg = tree
31+
2932
extension (arg: Arg)
3033
def isByName = arg.isInstanceOf[ByNameArg]
3134
def tree: Tree = arg match

tests/init-global/neg/global-cycle1.check

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
-- Error: tests/init-global/neg/global-cycle1.scala:1:7 ----------------------------------------------------------------
22
1 |object A { // error
3-
|^
4-
|Cyclic initialization: object A -> object B -> object A. Calling trace:
5-
|-> object A { // error [ global-cycle1.scala:1 ]
6-
| ^
7-
|-> val a: Int = B.b [ global-cycle1.scala:2 ]
8-
| ^
9-
|-> object B { [ global-cycle1.scala:5 ]
10-
| ^
11-
|-> val b: Int = A.a // error [ global-cycle1.scala:6 ]
12-
| ^
13-
2 | val a: Int = B.b
14-
3 |}
3+
| ^
4+
| Cyclic initialization: object A -> object B -> object A. Calling trace:
5+
| -> object A { // error [ global-cycle1.scala:1 ]
6+
| ^
7+
| -> val a: Int = B.b [ global-cycle1.scala:2 ]
8+
| ^
9+
| -> object B { [ global-cycle1.scala:5 ]
10+
| ^
11+
| -> val b: Int = A.a // error [ global-cycle1.scala:6 ]
12+
| ^
1513
-- Error: tests/init-global/neg/global-cycle1.scala:6:17 ---------------------------------------------------------------
1614
6 | val b: Int = A.a // error
1715
| ^^^

tests/init-global/neg/global-cycle6.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
object A { // error
22
val n: Int = B.m
33
class Inner {
4-
println(n)
4+
println(n) // error
55
}
66
}
77

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
-- Error: tests/init-global/neg/patmat-unapplySeq.scala:8:32 -----------------------------------------------------------
2+
8 | def apply(i: Int): Box = array(i) // error
3+
| ^^^^^^^^
4+
|Reading mutable state of object A during initialization of object B.
5+
|Reading mutable state of other static objects is forbidden as it breaks initialization-time irrelevance. Calling trace:
6+
|-> object B: [ patmat-unapplySeq.scala:15 ]
7+
| ^
8+
|-> case A(b) => [ patmat-unapplySeq.scala:17 ]
9+
| ^^^^
10+
|-> def apply(i: Int): Box = array(i) // error [ patmat-unapplySeq.scala:8 ]
11+
| ^^^^^^^^
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
object A:
2+
class Box(var x: Int)
3+
4+
val array: Array[Box] = new Array(1)
5+
array(0) = new Box(10)
6+
7+
def length: Int = array.length
8+
def apply(i: Int): Box = array(i) // error
9+
def drop(n: Int): Seq[Box] = array.toSeq
10+
def toSeq: Seq[Box] = array.toSeq
11+
12+
def unapplySeq(array: Array[Box]): A.type = this
13+
14+
15+
object B:
16+
A.array match
17+
case A(b) =>
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
object A:
2+
class Box(var x: Int)
3+
4+
val array: Array[Box] = new Array(1)
5+
array(0) = new Box(10)
6+
7+
def length: Int = array.length
8+
def apply(i: Int): Box = array(i) // error
9+
def drop(n: Int): Seq[Box] = array.toSeq
10+
def toSeq: Seq[Box] = array.toSeq
11+
12+
def unapplySeq(array: Array[Box]): A.type = this
13+
14+
15+
object B:
16+
A.array match
17+
case A(b*) =>

0 commit comments

Comments
 (0)