@@ -6,15 +6,18 @@ import core.*
6
6
import Contexts .*
7
7
import Symbols .*
8
8
import Types .*
9
+ import Denotations .Denotation
9
10
import StdNames .*
11
+ import Names .TermName
10
12
import NameKinds .OuterSelectName
11
13
import NameKinds .SuperAccessorName
12
14
13
15
import ast .tpd .*
14
- import util .SourcePosition
16
+ import util .{ SourcePosition , NoSourcePosition }
15
17
import config .Printers .init as printer
16
18
import reporting .StoreReporter
17
19
import reporting .trace as log
20
+ import typer .Applications .*
18
21
19
22
import Errors .*
20
23
import Trace .*
@@ -249,7 +252,7 @@ object Objects:
249
252
val joinedTrace = data.pendingTraces.slice(index + 1 , data.checkingObjects.size).foldLeft(pendingTrace) { (a, acc) => acc ++ a }
250
253
val callTrace = Trace .buildStacktrace(joinedTrace, " Calling trace:\n " )
251
254
val cycle = data.checkingObjects.slice(index, data.checkingObjects.size)
252
- val pos = clazz.defTree
255
+ val pos = clazz.defTree.sourcePos.focus
253
256
report.warning(" Cyclic initialization: " + cycle.map(_.klass.show).mkString(" -> " ) + " -> " + clazz.show + " . " + callTrace, pos)
254
257
end if
255
258
data.checkingObjects(index)
@@ -834,11 +837,10 @@ object Objects:
834
837
835
838
/** Handle local variable definition, `val x = e` or `var x = e`.
836
839
*
837
- * @param ref The value for `this` where the variable is defined.
838
840
* @param sym The symbol of the variable.
839
841
* @param value The value of the initializer.
840
842
*/
841
- def initLocal (ref : Ref , sym : Symbol , value : Value ): Contextual [Unit ] = log(" initialize local " + sym.show + " with " + value.show, printer) {
843
+ def initLocal (sym : Symbol , value : Value ): Contextual [Unit ] = log(" initialize local " + sym.show + " with " + value.show, printer) {
842
844
if sym.is(Flags .Mutable ) then
843
845
val addr = Heap .localVarAddr(summon[Regions .Data ], sym, State .currentObject)
844
846
Env .setLocalVar(sym, addr)
@@ -870,9 +872,6 @@ object Objects:
870
872
case _ =>
871
873
report.warning(" [Internal error] Variable not found " + sym.show + " \n env = " + env.show + " . Calling trace:\n " + Trace .show, Trace .position)
872
874
Bottom
873
- else if sym.isPatternBound then
874
- // TODO: handle patterns
875
- Cold
876
875
else
877
876
given Env .Data = env
878
877
// Assume forward reference check is doing a good job
@@ -1113,11 +1112,9 @@ object Objects:
1113
1112
else
1114
1113
eval(arg, thisV, klass)
1115
1114
1116
- case Match (selector, cases) =>
1117
- eval(selector, thisV, klass)
1118
- // TODO: handle pattern match properly
1119
- report.warning(" [initChecker] Pattern match is skipped. Trace:\n " + Trace .show, expr)
1120
- Bottom
1115
+ case Match (scrutinee, cases) =>
1116
+ val scrutineeValue = eval(scrutinee, thisV, klass)
1117
+ patternMatch(scrutineeValue, cases, thisV, klass)
1121
1118
1122
1119
case Return (expr, from) =>
1123
1120
Returns .handle(from.symbol, eval(expr, thisV, klass))
@@ -1151,7 +1148,7 @@ object Objects:
1151
1148
// local val definition
1152
1149
val rhs = eval(vdef.rhs, thisV, klass)
1153
1150
val sym = vdef.symbol
1154
- initLocal(thisV. asInstanceOf [ Ref ], vdef.symbol, rhs)
1151
+ initLocal(vdef.symbol, rhs)
1155
1152
Bottom
1156
1153
1157
1154
case ddef : DefDef =>
@@ -1173,6 +1170,196 @@ object Objects:
1173
1170
Bottom
1174
1171
}
1175
1172
1173
+ /** Evaluate the cases against the scrutinee value.
1174
+ *
1175
+ * It returns the scrutinee in most cases. The main effect of the function is for its side effects of adding bindings
1176
+ * to the environment.
1177
+ *
1178
+ * See https://docs.scala-lang.org/scala3/reference/changed-features/pattern-matching.html
1179
+ *
1180
+ * @param scrutinee The abstract value of the scrutinee.
1181
+ * @param cases The cases to match.
1182
+ * @param thisV The value for `C.this` where `C` is represented by `klass`.
1183
+ * @param klass The enclosing class where the type `tp` is located.
1184
+ */
1185
+ def patternMatch (scrutinee : Value , cases : List [CaseDef ], thisV : Value , klass : ClassSymbol ): Contextual [Value ] =
1186
+ // expected member types for `unapplySeq`
1187
+ def lengthType = ExprType (defn.IntType )
1188
+ def lengthCompareType = MethodType (List (defn.IntType ), defn.IntType )
1189
+ def applyType (elemTp : Type ) = MethodType (List (defn.IntType ), elemTp)
1190
+ def dropType (elemTp : Type ) = MethodType (List (defn.IntType ), defn.CollectionSeqType .appliedTo(elemTp))
1191
+ def toSeqType (elemTp : Type ) = ExprType (defn.CollectionSeqType .appliedTo(elemTp))
1192
+
1193
+ def getMemberMethod (receiver : Type , name : TermName , tp : Type ): Denotation =
1194
+ receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)
1195
+
1196
+ def evalCase (caseDef : CaseDef ): Value =
1197
+ evalPattern(scrutinee, caseDef.pat)
1198
+ eval(caseDef.guard, thisV, klass)
1199
+ eval(caseDef.body, thisV, klass)
1200
+
1201
+ /** Abstract evaluation of patterns.
1202
+ *
1203
+ * It augments the local environment for bound pattern variables. As symbols are globally
1204
+ * unique, we can put them in a single environment.
1205
+ *
1206
+ * Currently, we assume all cases are reachable, thus all patterns are assumed to match.
1207
+ */
1208
+ def evalPattern (scrutinee : Value , pat : Tree ): Value = log(" match " + scrutinee.show + " against " + pat.show, printer, (_ : Value ).show):
1209
+ val trace2 = Trace .trace.add(pat)
1210
+ pat match
1211
+ case Alternative (pats) =>
1212
+ for pat <- pats do evalPattern(scrutinee, pat)
1213
+ scrutinee
1214
+
1215
+ case bind @ Bind (_, pat) =>
1216
+ val value = evalPattern(scrutinee, pat)
1217
+ initLocal(bind.symbol, value)
1218
+ scrutinee
1219
+
1220
+ case UnApply (fun, implicits, pats) =>
1221
+ given Trace = trace2
1222
+
1223
+ val fun1 = funPart(fun)
1224
+ val funRef = fun1.tpe.asInstanceOf [TermRef ]
1225
+ val unapplyResTp = funRef.widen.finalResultType
1226
+
1227
+ val receiver = fun1 match
1228
+ case ident : Ident =>
1229
+ evalType(funRef.prefix, thisV, klass)
1230
+ case select : Select =>
1231
+ eval(select.qualifier, thisV, klass)
1232
+
1233
+ val implicitValues = evalArgs(implicits.map(Arg .apply), thisV, klass)
1234
+ // TODO: implicit values may appear before and/or after the scrutinee parameter.
1235
+ val unapplyRes = call(receiver, funRef.symbol, TraceValue (scrutinee, summon[Trace ]) :: implicitValues, funRef.prefix, superType = NoType , needResolve = true )
1236
+
1237
+ if fun.symbol.name == nme.unapplySeq then
1238
+ var resultTp = unapplyResTp
1239
+ var elemTp = unapplySeqTypeElemTp(resultTp)
1240
+ var arity = productArity(resultTp, NoSourcePosition )
1241
+ var needsGet = false
1242
+ if (! elemTp.exists && arity <= 0 ) {
1243
+ needsGet = true
1244
+ resultTp = resultTp.select(nme.get).finalResultType
1245
+ elemTp = unapplySeqTypeElemTp(resultTp.widen)
1246
+ arity = productSelectorTypes(resultTp, NoSourcePosition ).size
1247
+ }
1248
+
1249
+ var resToMatch = unapplyRes
1250
+
1251
+ if needsGet then
1252
+ // Get match
1253
+ val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
1254
+ call(unapplyRes, isEmptyDenot.symbol, Nil , unapplyResTp, superType = NoType , needResolve = true )
1255
+
1256
+ val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
1257
+ resToMatch = call(unapplyRes, getDenot.symbol, Nil , unapplyResTp, superType = NoType , needResolve = true )
1258
+ end if
1259
+
1260
+ if elemTp.exists then
1261
+ // sequence match
1262
+ evalSeqPatterns(resToMatch, resultTp, elemTp, pats)
1263
+ else
1264
+ // product sequence match
1265
+ val selectors = productSelectors(resultTp)
1266
+ assert(selectors.length <= pats.length)
1267
+ selectors.init.zip(pats).map { (sel, pat) =>
1268
+ val selectRes = call(resToMatch, sel, Nil , resultTp, superType = NoType , needResolve = true )
1269
+ evalPattern(selectRes, pat)
1270
+ }
1271
+ val seqPats = pats.drop(selectors.length - 1 )
1272
+ val toSeqRes = call(resToMatch, selectors.last, Nil , resultTp, superType = NoType , needResolve = true )
1273
+ val toSeqResTp = resultTp.memberInfo(selectors.last).finalResultType
1274
+ evalSeqPatterns(toSeqRes, toSeqResTp, elemTp, seqPats)
1275
+ end if
1276
+
1277
+ else
1278
+ // distribute unapply to patterns
1279
+ if isProductMatch(unapplyResTp, pats.length) then
1280
+ // product match
1281
+ val selectors = productSelectors(unapplyResTp)
1282
+ assert(selectors.length == pats.length)
1283
+ selectors.zip(pats).map { (sel, pat) =>
1284
+ val selectRes = call(unapplyRes, sel, Nil , unapplyResTp, superType = NoType , needResolve = true )
1285
+ evalPattern(selectRes, pat)
1286
+ }
1287
+ else if unapplyResTp <:< defn.BooleanType then
1288
+ // Boolean extractor, do nothing
1289
+ ()
1290
+ else
1291
+ // Get match
1292
+ val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
1293
+ call(unapplyRes, isEmptyDenot.symbol, Nil , unapplyResTp, superType = NoType , needResolve = true )
1294
+
1295
+ val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
1296
+ val getRes = call(unapplyRes, getDenot.symbol, Nil , unapplyResTp, superType = NoType , needResolve = true )
1297
+ if pats.length == 1 then
1298
+ // single match
1299
+ evalPattern(getRes, pats.head)
1300
+ else
1301
+ val getResTp = getDenot.info.finalResultType
1302
+ val selectors = productSelectors(getResTp).take(pats.length)
1303
+ selectors.zip(pats).map { (sel, pat) =>
1304
+ val selectRes = call(unapplyRes, sel, Nil , getResTp, superType = NoType , needResolve = true )
1305
+ evalPattern(selectRes, pat)
1306
+ }
1307
+ end if
1308
+ end if
1309
+ end if
1310
+ scrutinee
1311
+
1312
+ case Ident (nme.WILDCARD ) | Ident (nme.WILDCARD_STAR ) =>
1313
+ scrutinee
1314
+
1315
+ case Typed (pat, _) =>
1316
+ evalPattern(scrutinee, pat)
1317
+
1318
+ case tree =>
1319
+ // For all other trees, the semantics is normal.
1320
+ eval(tree, thisV, klass)
1321
+
1322
+ end evalPattern
1323
+
1324
+ /**
1325
+ * Evaluate a sequence value against sequence patterns.
1326
+ */
1327
+ def evalSeqPatterns (scrutinee : Value , scrutineeType : Type , elemType : Type , pats : List [Tree ])(using Trace ): Unit =
1328
+ // call .lengthCompare or .length
1329
+ val lengthCompareDenot = getMemberMethod(scrutineeType, nme.lengthCompare, lengthCompareType)
1330
+ if lengthCompareDenot.exists then
1331
+ call(scrutinee, lengthCompareDenot.symbol, TraceValue (Bottom , summon[Trace ]) :: Nil , scrutineeType, superType = NoType , needResolve = true )
1332
+ else
1333
+ val lengthDenot = getMemberMethod(scrutineeType, nme.length, lengthType)
1334
+ call(scrutinee, lengthDenot.symbol, Nil , scrutineeType, superType = NoType , needResolve = true )
1335
+ end if
1336
+
1337
+ // call .apply
1338
+ val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType))
1339
+ val applyRes = call(scrutinee, applyDenot.symbol, TraceValue (Bottom , summon[Trace ]) :: Nil , scrutineeType, superType = NoType , needResolve = true )
1340
+
1341
+ if isWildcardStarArg(pats.last) then
1342
+ if pats.size == 1 then
1343
+ // call .toSeq
1344
+ val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
1345
+ val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil , scrutineeType, superType = NoType , needResolve = true )
1346
+ evalPattern(toSeqRes, pats.head)
1347
+ else
1348
+ // call .drop
1349
+ val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
1350
+ val dropRes = call(scrutinee, dropDenot.symbol, TraceValue (Bottom , summon[Trace ]) :: Nil , scrutineeType, superType = NoType , needResolve = true )
1351
+ for pat <- pats.init do evalPattern(applyRes, pat)
1352
+ evalPattern(dropRes, pats.last)
1353
+ end if
1354
+ else
1355
+ // no patterns like `xs*`
1356
+ for pat <- pats do evalPattern(applyRes, pat)
1357
+ end evalSeqPatterns
1358
+
1359
+
1360
+ cases.map(evalCase).join
1361
+ end patternMatch
1362
+
1176
1363
/** Handle semantics of leaf nodes
1177
1364
*
1178
1365
* For leaf nodes, their semantics is determined by their types.
@@ -1231,7 +1418,7 @@ object Objects:
1231
1418
resolveThis(tref.classSymbol.asClass, thisV, klass)
1232
1419
1233
1420
case _ =>
1234
- throw new Exception (" unexpected type: " + tp)
1421
+ throw new Exception (" unexpected type: " + tp + " , Trace: \n " + Trace .show )
1235
1422
}
1236
1423
1237
1424
/** Evaluate arguments of methods and constructors */
0 commit comments