Skip to content

Commit 3fb0f98

Browse files
committed
Fix bounds of encoded type variables in quote patterns
When we encode quote patterns into `unapply` methods, we need to create a copy of each type variable. One copy is kept within the quote(in the next stage) and the other is used in the `unapply` method to define the usual pattern type variables. When creating the latter we copied the symbols but did not update the infos. This implies that if type variables would be bounded by each other, the bounds of the copies would be the original types instead of the copies. We need to update those references. To update the info we now create all the symbols in one pass and the update all their infos in a second pass. This also implies that we cannot use the `newPatternBoundSymbol` to create the symbol as this constructor will register the info into GADT bounds. Instead we use the plain `newSymbol`. Then in the second pass, when we have updated the infos, we register the symbol into GADT bounds. Note that the code in the added test does compiles correctly, but it had the inconsistent bounds. This test is added in case we need to manually inspect the bounds latter. This test does fail to compile in #17935 if this fix is not applied.
1 parent d0f9a51 commit 3fb0f98

File tree

2 files changed

+64
-9
lines changed

2 files changed

+64
-9
lines changed

compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,7 @@ trait QuotesAndSplices {
214214
private def splitQuotePattern(quoted: Tree)(using Context): (collection.Map[Symbol, Bind], Tree, List[Tree]) = {
215215
val ctx0 = ctx
216216

217-
val typeBindings: mutable.Map[Symbol, Bind] = mutable.LinkedHashMap.empty
218-
def getBinding(sym: Symbol): Bind =
219-
typeBindings.getOrElseUpdate(sym, {
220-
val bindingBounds = sym.info
221-
val bsym = newPatternBoundSymbol(sym.name.toString.stripPrefix("$").toTypeName, bindingBounds, quoted.span)
222-
Bind(bsym, untpd.Ident(nme.WILDCARD).withType(bindingBounds)).withSpan(quoted.span)
223-
})
217+
val bindSymMapping: collection.Map[Symbol, Bind] = unapplyBindingsMapping(quoted)
224218

225219
object splitter extends tpd.TreeMap {
226220
private var variance: Int = 1
@@ -300,7 +294,7 @@ trait QuotesAndSplices {
300294
report.error(IllegalVariableInPatternAlternative(tdef.symbol.name), tdef.srcPos)
301295
if variance == -1 then
302296
tdef.symbol.addAnnotation(Annotation(New(ref(defn.QuotedRuntimePatterns_fromAboveAnnot.typeRef)).withSpan(tdef.span)))
303-
val bindingType = getBinding(tdef.symbol).symbol.typeRef
297+
val bindingType = bindSymMapping(tdef.symbol).symbol.typeRef
304298
val bindingTypeTpe = AppliedType(defn.QuotedTypeClass.typeRef, bindingType :: Nil)
305299
val sym = newPatternBoundSymbol(nameOfSyntheticGiven, bindingTypeTpe, tdef.span, flags = ImplicitVal)(using ctx0)
306300
buff += Bind(sym, untpd.Ident(nme.WILDCARD).withType(bindingTypeTpe)).withSpan(tdef.span)
@@ -337,7 +331,56 @@ trait QuotesAndSplices {
337331
new TreeTypeMap(typeMap = typeMap).transform(shape1)
338332
}
339333

340-
(typeBindings, shape2, patterns)
334+
(bindSymMapping, shape2, patterns)
335+
}
336+
337+
/** For each type variable defined in the quote pattern we generate an equivalent
338+
* binding that will be as type variable in the encoded `unapply` of the quote pattern.
339+
*
340+
* @return Mapping from type variable symbols defined in the quote pattern into
341+
* type variable `Bind` definitions for the `unapply` of the quote pattern.
342+
* This mapping retains the original type variable definition order.
343+
*/
344+
private def unapplyBindingsMapping(quoted: Tree)(using Context): collection.Map[Symbol, Bind] = {
345+
// Collect all existing type variable bindings and create new symbols for them.
346+
// The old info is used, it may contain references to the old symbols.
347+
val (oldBindings, newBindings) = {
348+
val seen = mutable.Set.empty[Symbol]
349+
val oldBindingsBuffer = mutable.LinkedHashSet.empty[Symbol]
350+
val newBindingsBuffer = mutable.ListBuffer.empty[Symbol]
351+
352+
new tpd.TreeTraverser {
353+
def traverse(tree: Tree)(using Context): Unit = tree match {
354+
case _: SplicePattern =>
355+
case Select(pat: Bind, _) if tree.symbol.isTypeSplice =>
356+
val sym = tree.tpe.dealias.typeSymbol
357+
if sym.exists then registerNewBindSym(sym)
358+
case tdef: TypeDef =>
359+
if tdef.symbol.hasAnnotation(defn.QuotedRuntimePatterns_patternTypeAnnot) then
360+
registerNewBindSym(tdef.symbol)
361+
traverseChildren(tdef)
362+
case _ =>
363+
traverseChildren(tree)
364+
}
365+
private def registerNewBindSym(sym: Symbol): Unit =
366+
if !seen(sym) then
367+
seen += sym
368+
oldBindingsBuffer += sym
369+
newBindingsBuffer += newSymbol(ctx.owner, sym.name.toString.stripPrefix("$").toTypeName, Case | sym.flags, sym.info, coord = quoted.span)
370+
}.traverse(quoted)
371+
(oldBindingsBuffer.toList, newBindingsBuffer.toList)
372+
}
373+
374+
// Replace symbols in `mapping` in the infos of the new symbol and register GADT bounds.
375+
// GADT bounds need to be added after the info is updated to avoid references to the old symbols.
376+
val newBindingsRefs = newBindings.map(_.typeRef)
377+
for newBindings <- newBindings do
378+
newBindings.info = newBindings.info.subst(oldBindings.toList, newBindingsRefs)
379+
ctx.gadtState.addToConstraint(newBindings) // This must be performed after the info has been updated
380+
381+
// Map into Bind nodes retaining the original order
382+
val newBindingBinds = newBindings.map(newSym => Bind(newSym, untpd.Ident(nme.WILDCARD).withType(newSym.info)).withSpan(quoted.span))
383+
mutable.LinkedHashMap.from(oldBindings.lazyZip(newBindingBinds))
341384
}
342385

343386
/** Type a quote pattern `case '{ <quoted> } =>` qiven the a current prototype. Typing the pattern
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import quoted.*
2+
3+
def foo(using Quotes)(x: Expr[Int]) =
4+
x match
5+
case '{ type t; type u <: `t`; f[`t`, `u`] } =>
6+
case '{ type u <: `t`; type t; f[`t`, `u`] } =>
7+
case '{ type t; type u <: `t`; g[F[`t`, `u`]] } =>
8+
case '{ type u <: `t`; type t; g[F[`t`, `u`]] } =>
9+
10+
def f[T, U <: T] = ???
11+
def g[T] = ???
12+
type F[T, U <: T]

0 commit comments

Comments
 (0)