@@ -6,15 +6,18 @@ import core.*
66import Contexts .*
77import Symbols .*
88import Types .*
9+ import Denotations .Denotation
910import StdNames .*
11+ import Names .TermName
1012import NameKinds .OuterSelectName
1113import NameKinds .SuperAccessorName
1214
1315import ast .tpd .*
14- import util .SourcePosition
16+ import util .{ SourcePosition , NoSourcePosition }
1517import config .Printers .init as printer
1618import reporting .StoreReporter
1719import reporting .trace as log
20+ import typer .Applications .*
1821
1922import Errors .*
2023import Trace .*
@@ -249,7 +252,7 @@ object Objects:
249252 val joinedTrace = data.pendingTraces.slice(index + 1 , data.checkingObjects.size).foldLeft(pendingTrace) { (a, acc) => acc ++ a }
250253 val callTrace = Trace .buildStacktrace(joinedTrace, " Calling trace:\n " )
251254 val cycle = data.checkingObjects.slice(index, data.checkingObjects.size)
252- val pos = clazz.defTree
255+ val pos = clazz.defTree.sourcePos.focus
253256 report.warning(" Cyclic initialization: " + cycle.map(_.klass.show).mkString(" -> " ) + " -> " + clazz.show + " . " + callTrace, pos)
254257 end if
255258 data.checkingObjects(index)
@@ -834,11 +837,10 @@ object Objects:
834837
835838 /** Handle local variable definition, `val x = e` or `var x = e`.
836839 *
837- * @param ref The value for `this` where the variable is defined.
838840 * @param sym The symbol of the variable.
839841 * @param value The value of the initializer.
840842 */
841- def initLocal (ref : Ref , sym : Symbol , value : Value ): Contextual [Unit ] = log(" initialize local " + sym.show + " with " + value.show, printer) {
843+ def initLocal (sym : Symbol , value : Value ): Contextual [Unit ] = log(" initialize local " + sym.show + " with " + value.show, printer) {
842844 if sym.is(Flags .Mutable ) then
843845 val addr = Heap .localVarAddr(summon[Regions .Data ], sym, State .currentObject)
844846 Env .setLocalVar(sym, addr)
@@ -870,9 +872,6 @@ object Objects:
870872 case _ =>
871873 report.warning(" [Internal error] Variable not found " + sym.show + " \n env = " + env.show + " . Calling trace:\n " + Trace .show, Trace .position)
872874 Bottom
873- else if sym.isPatternBound then
874- // TODO: handle patterns
875- Cold
876875 else
877876 given Env .Data = env
878877 // Assume forward reference check is doing a good job
@@ -1113,11 +1112,9 @@ object Objects:
11131112 else
11141113 eval(arg, thisV, klass)
11151114
1116- case Match (selector, cases) =>
1117- eval(selector, thisV, klass)
1118- // TODO: handle pattern match properly
1119- report.warning(" [initChecker] Pattern match is skipped. Trace:\n " + Trace .show, expr)
1120- Bottom
1115+ case Match (scrutinee, cases) =>
1116+ val scrutineeValue = eval(scrutinee, thisV, klass)
1117+ patternMatch(scrutineeValue, cases, thisV, klass)
11211118
11221119 case Return (expr, from) =>
11231120 Returns .handle(from.symbol, eval(expr, thisV, klass))
@@ -1151,7 +1148,7 @@ object Objects:
11511148 // local val definition
11521149 val rhs = eval(vdef.rhs, thisV, klass)
11531150 val sym = vdef.symbol
1154- initLocal(thisV. asInstanceOf [ Ref ], vdef.symbol, rhs)
1151+ initLocal(vdef.symbol, rhs)
11551152 Bottom
11561153
11571154 case ddef : DefDef =>
@@ -1173,6 +1170,196 @@ object Objects:
11731170 Bottom
11741171 }
11751172
1173+ /** Evaluate the cases against the scrutinee value.
1174+ *
1175+ * It returns the scrutinee in most cases. The main effect of the function is for its side effects of adding bindings
1176+ * to the environment.
1177+ *
1178+ * See https://docs.scala-lang.org/scala3/reference/changed-features/pattern-matching.html
1179+ *
1180+ * @param scrutinee The abstract value of the scrutinee.
1181+ * @param cases The cases to match.
1182+ * @param thisV The value for `C.this` where `C` is represented by `klass`.
1183+ * @param klass The enclosing class where the type `tp` is located.
1184+ */
1185+ def patternMatch (scrutinee : Value , cases : List [CaseDef ], thisV : Value , klass : ClassSymbol ): Contextual [Value ] =
1186+ // expected member types for `unapplySeq`
1187+ def lengthType = ExprType (defn.IntType )
1188+ def lengthCompareType = MethodType (List (defn.IntType ), defn.IntType )
1189+ def applyType (elemTp : Type ) = MethodType (List (defn.IntType ), elemTp)
1190+ def dropType (elemTp : Type ) = MethodType (List (defn.IntType ), defn.CollectionSeqType .appliedTo(elemTp))
1191+ def toSeqType (elemTp : Type ) = ExprType (defn.CollectionSeqType .appliedTo(elemTp))
1192+
1193+ def getMemberMethod (receiver : Type , name : TermName , tp : Type ): Denotation =
1194+ receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)
1195+
1196+ def evalCase (caseDef : CaseDef ): Value =
1197+ evalPattern(scrutinee, caseDef.pat)
1198+ eval(caseDef.guard, thisV, klass)
1199+ eval(caseDef.body, thisV, klass)
1200+
1201+ /** Abstract evaluation of patterns.
1202+ *
1203+ * It augments the local environment for bound pattern variables. As symbols are globally
1204+ * unique, we can put them in a single environment.
1205+ *
1206+ * Currently, we assume all cases are reachable, thus all patterns are assumed to match.
1207+ */
1208+ def evalPattern (scrutinee : Value , pat : Tree ): Value = log(" match " + scrutinee.show + " against " + pat.show, printer, (_ : Value ).show):
1209+ val trace2 = Trace .trace.add(pat)
1210+ pat match
1211+ case Alternative (pats) =>
1212+ for pat <- pats do evalPattern(scrutinee, pat)
1213+ scrutinee
1214+
1215+ case bind @ Bind (_, pat) =>
1216+ val value = evalPattern(scrutinee, pat)
1217+ initLocal(bind.symbol, value)
1218+ scrutinee
1219+
1220+ case UnApply (fun, implicits, pats) =>
1221+ given Trace = trace2
1222+
1223+ val fun1 = funPart(fun)
1224+ val funRef = fun1.tpe.asInstanceOf [TermRef ]
1225+ val unapplyResTp = funRef.widen.finalResultType
1226+
1227+ val receiver = fun1 match
1228+ case ident : Ident =>
1229+ evalType(funRef.prefix, thisV, klass)
1230+ case select : Select =>
1231+ eval(select.qualifier, thisV, klass)
1232+
1233+ val implicitValues = evalArgs(implicits.map(Arg .apply), thisV, klass)
1234+ // TODO: implicit values may appear before and/or after the scrutinee parameter.
1235+ val unapplyRes = call(receiver, funRef.symbol, TraceValue (scrutinee, summon[Trace ]) :: implicitValues, funRef.prefix, superType = NoType , needResolve = true )
1236+
1237+ if fun.symbol.name == nme.unapplySeq then
1238+ var resultTp = unapplyResTp
1239+ var elemTp = unapplySeqTypeElemTp(resultTp)
1240+ var arity = productArity(resultTp, NoSourcePosition )
1241+ var needsGet = false
1242+ if (! elemTp.exists && arity <= 0 ) {
1243+ needsGet = true
1244+ resultTp = resultTp.select(nme.get).finalResultType
1245+ elemTp = unapplySeqTypeElemTp(resultTp.widen)
1246+ arity = productSelectorTypes(resultTp, NoSourcePosition ).size
1247+ }
1248+
1249+ var resToMatch = unapplyRes
1250+
1251+ if needsGet then
1252+ // Get match
1253+ val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
1254+ call(unapplyRes, isEmptyDenot.symbol, Nil , unapplyResTp, superType = NoType , needResolve = true )
1255+
1256+ val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
1257+ resToMatch = call(unapplyRes, getDenot.symbol, Nil , unapplyResTp, superType = NoType , needResolve = true )
1258+ end if
1259+
1260+ if elemTp.exists then
1261+ // sequence match
1262+ evalSeqPatterns(resToMatch, resultTp, elemTp, pats)
1263+ else
1264+ // product sequence match
1265+ val selectors = productSelectors(resultTp)
1266+ assert(selectors.length <= pats.length)
1267+ selectors.init.zip(pats).map { (sel, pat) =>
1268+ val selectRes = call(resToMatch, sel, Nil , resultTp, superType = NoType , needResolve = true )
1269+ evalPattern(selectRes, pat)
1270+ }
1271+ val seqPats = pats.drop(selectors.length - 1 )
1272+ val toSeqRes = call(resToMatch, selectors.last, Nil , resultTp, superType = NoType , needResolve = true )
1273+ val toSeqResTp = resultTp.memberInfo(selectors.last).finalResultType
1274+ evalSeqPatterns(toSeqRes, toSeqResTp, elemTp, seqPats)
1275+ end if
1276+
1277+ else
1278+ // distribute unapply to patterns
1279+ if isProductMatch(unapplyResTp, pats.length) then
1280+ // product match
1281+ val selectors = productSelectors(unapplyResTp)
1282+ assert(selectors.length == pats.length)
1283+ selectors.zip(pats).map { (sel, pat) =>
1284+ val selectRes = call(unapplyRes, sel, Nil , unapplyResTp, superType = NoType , needResolve = true )
1285+ evalPattern(selectRes, pat)
1286+ }
1287+ else if unapplyResTp <:< defn.BooleanType then
1288+ // Boolean extractor, do nothing
1289+ ()
1290+ else
1291+ // Get match
1292+ val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
1293+ call(unapplyRes, isEmptyDenot.symbol, Nil , unapplyResTp, superType = NoType , needResolve = true )
1294+
1295+ val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
1296+ val getRes = call(unapplyRes, getDenot.symbol, Nil , unapplyResTp, superType = NoType , needResolve = true )
1297+ if pats.length == 1 then
1298+ // single match
1299+ evalPattern(getRes, pats.head)
1300+ else
1301+ val getResTp = getDenot.info.finalResultType
1302+ val selectors = productSelectors(getResTp).take(pats.length)
1303+ selectors.zip(pats).map { (sel, pat) =>
1304+ val selectRes = call(unapplyRes, sel, Nil , getResTp, superType = NoType , needResolve = true )
1305+ evalPattern(selectRes, pat)
1306+ }
1307+ end if
1308+ end if
1309+ end if
1310+ scrutinee
1311+
1312+ case Ident (nme.WILDCARD ) | Ident (nme.WILDCARD_STAR ) =>
1313+ scrutinee
1314+
1315+ case Typed (pat, _) =>
1316+ evalPattern(scrutinee, pat)
1317+
1318+ case tree =>
1319+ // For all other trees, the semantics is normal.
1320+ eval(tree, thisV, klass)
1321+
1322+ end evalPattern
1323+
1324+ /**
1325+ * Evaluate a sequence value against sequence patterns.
1326+ */
1327+ def evalSeqPatterns (scrutinee : Value , scrutineeType : Type , elemType : Type , pats : List [Tree ])(using Trace ): Unit =
1328+ // call .lengthCompare or .length
1329+ val lengthCompareDenot = getMemberMethod(scrutineeType, nme.lengthCompare, lengthCompareType)
1330+ if lengthCompareDenot.exists then
1331+ call(scrutinee, lengthCompareDenot.symbol, TraceValue (Bottom , summon[Trace ]) :: Nil , scrutineeType, superType = NoType , needResolve = true )
1332+ else
1333+ val lengthDenot = getMemberMethod(scrutineeType, nme.length, lengthType)
1334+ call(scrutinee, lengthDenot.symbol, Nil , scrutineeType, superType = NoType , needResolve = true )
1335+ end if
1336+
1337+ // call .apply
1338+ val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType))
1339+ val applyRes = call(scrutinee, applyDenot.symbol, TraceValue (Bottom , summon[Trace ]) :: Nil , scrutineeType, superType = NoType , needResolve = true )
1340+
1341+ if isWildcardStarArg(pats.last) then
1342+ if pats.size == 1 then
1343+ // call .toSeq
1344+ val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
1345+ val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil , scrutineeType, superType = NoType , needResolve = true )
1346+ evalPattern(toSeqRes, pats.head)
1347+ else
1348+ // call .drop
1349+ val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
1350+ val dropRes = call(scrutinee, dropDenot.symbol, TraceValue (Bottom , summon[Trace ]) :: Nil , scrutineeType, superType = NoType , needResolve = true )
1351+ for pat <- pats.init do evalPattern(applyRes, pat)
1352+ evalPattern(dropRes, pats.last)
1353+ end if
1354+ else
1355+ // no patterns like `xs*`
1356+ for pat <- pats do evalPattern(applyRes, pat)
1357+ end evalSeqPatterns
1358+
1359+
1360+ cases.map(evalCase).join
1361+ end patternMatch
1362+
11761363 /** Handle semantics of leaf nodes
11771364 *
11781365 * For leaf nodes, their semantics is determined by their types.
@@ -1231,7 +1418,7 @@ object Objects:
12311418 resolveThis(tref.classSymbol.asClass, thisV, klass)
12321419
12331420 case _ =>
1234- throw new Exception (" unexpected type: " + tp)
1421+ throw new Exception (" unexpected type: " + tp + " , Trace: \n " + Trace .show )
12351422 }
12361423
12371424 /** Evaluate arguments of methods and constructors */
0 commit comments