@@ -206,13 +206,7 @@ trait QuotesAndSplices {
206
206
private def splitQuotePattern (quoted : Tree )(using Context ): (collection.Map [Symbol , Bind ], Tree , List [Tree ]) = {
207
207
val ctx0 = ctx
208
208
209
- val typeBindings : mutable.Map [Symbol , Bind ] = mutable.LinkedHashMap .empty
210
- def getBinding (sym : Symbol ): Bind =
211
- typeBindings.getOrElseUpdate(sym, {
212
- val bindingBounds = sym.info
213
- val bsym = newPatternBoundSymbol(sym.name.toString.stripPrefix(" $" ).toTypeName, bindingBounds, quoted.span)
214
- Bind (bsym, untpd.Ident (nme.WILDCARD ).withType(bindingBounds)).withSpan(quoted.span)
215
- })
209
+ val bindSymMapping : collection.Map [Symbol , Bind ] = unapplyBindingsMapping(quoted)
216
210
217
211
object splitter extends tpd.TreeMap {
218
212
private var variance : Int = 1
@@ -292,13 +286,14 @@ trait QuotesAndSplices {
292
286
report.error(IllegalVariableInPatternAlternative (tdef.symbol.name), tdef.srcPos)
293
287
if variance == - 1 then
294
288
tdef.symbol.addAnnotation(Annotation (New (ref(defn.QuotedRuntimePatterns_fromAboveAnnot .typeRef)).withSpan(tdef.span)))
295
- val bindingType = getBinding (tdef.symbol).symbol.typeRef
289
+ val bindingType = bindSymMapping (tdef.symbol).symbol.typeRef
296
290
val bindingTypeTpe = AppliedType (defn.QuotedTypeClass .typeRef, bindingType :: Nil )
297
291
val sym = newPatternBoundSymbol(nameOfSyntheticGiven, bindingTypeTpe, tdef.span, flags = ImplicitVal )(using ctx0)
298
292
buff += Bind (sym, untpd.Ident (nme.WILDCARD ).withType(bindingTypeTpe)).withSpan(tdef.span)
299
293
super .transform(tdef)
300
294
}
301
295
}
296
+
302
297
val shape0 = splitter.transform(quoted)
303
298
val patterns = (splitter.typePatBuf.iterator ++ splitter.freshTypePatBuf.iterator ++ splitter.patBuf.iterator).toList
304
299
val freshTypeBindings = splitter.freshTypeBindingsBuff.result()
@@ -329,7 +324,42 @@ trait QuotesAndSplices {
329
324
new TreeTypeMap (typeMap = typeMap).transform(shape1)
330
325
}
331
326
332
- (typeBindings, shape2, patterns)
327
+ (bindSymMapping, shape2, patterns)
328
+ }
329
+
330
+ private def unapplyBindingsMapping (quoted : Tree )(using Context ): collection.Map [Symbol , Bind ] = {
331
+ val mapping = mutable.LinkedHashMap .empty[Symbol , Symbol ]
332
+ new tpd.TreeTraverser {
333
+ def traverse (tree : Tree )(using Context ): Unit = tree match {
334
+ case _ : SplicePattern =>
335
+ case Select (pat : Bind , _) if tree.symbol.isTypeSplice =>
336
+ val sym = tree.tpe.dealias.typeSymbol
337
+ if sym.exists then registerNewBindSym(sym)
338
+ case tdef : TypeDef =>
339
+ if tdef.symbol.hasAnnotation(defn.QuotedRuntimePatterns_patternTypeAnnot ) then
340
+ registerNewBindSym(tdef.symbol)
341
+ traverseChildren(tdef)
342
+ case _ =>
343
+ traverseChildren(tree)
344
+ }
345
+ private def registerNewBindSym (sym : Symbol ): Unit =
346
+ if ! mapping.contains(sym) then
347
+ mapping(sym) = newSymbol(ctx.owner, sym.name.toString.stripPrefix(" $" ).toTypeName, Case | sym.flags, sym.info, coord = quoted.span)
348
+ }.traverse(quoted)
349
+
350
+ // Replace symbols in `mapping` in the infos of the new symbol and register GADT bounds.
351
+ // GADT bounds need to be added after the info is updated to avoid references to the old symbols.
352
+ var oldBindings : List [Symbol ] = mapping.keys.toList
353
+ var newBindingsRefs : List [Type ] = mapping.values.toList.map(_.typeRef)
354
+ for newBindings <- mapping.values do
355
+ newBindings.info = newBindings.info.subst(oldBindings, newBindingsRefs)
356
+ ctx.gadtState.addToConstraint(newBindings)
357
+
358
+ // Map into Bind nodes retaining the original order
359
+ val mapping2 : mutable.Map [Symbol , Bind ] = mutable.LinkedHashMap .empty
360
+ for (oldSym, newSym) <- mapping do
361
+ mapping2(oldSym) = Bind (newSym, untpd.Ident (nme.WILDCARD ).withType(newSym.info)).withSpan(quoted.span)
362
+ mapping2
333
363
}
334
364
335
365
/** Type a quote pattern `case '{ <quoted> } =>` qiven the a current prototype. Typing the pattern
@@ -451,6 +481,11 @@ trait QuotesAndSplices {
451
481
452
482
val decoded = QuotePatterns .decode(patternUnapply)
453
483
decoded.foreach(QuotePatterns .checkPattern)
484
+ // val encoded = QuotePatterns.encode(decoded.get)
485
+ // println(patternUnapply.show)
486
+ // println(encoded.show)
487
+ // println(decoded.get.show)
488
+
454
489
decoded.getOrElse(patternUnapply)
455
490
}
456
491
}
0 commit comments