Skip to content

Commit 82022b8

Browse files
committed
Fix bounds of encoded type variables in quote patterns
When we encode quote patterns into `unapply` methods, we need to create a copy of each type variable. One copy is kept within the quote(in the next stage) and the other is used in the `unapply` method to define the usual pattern type variables. When creating the latter we copied the symbols but did not update the infos. This implies that if type variables would be bounded by each other, the bounds of the copies would be the original types instead of the copies. We need to update those references. To update the info we now create all the symbols in one pass and the update all their infos in a second pass. This also implies that we cannot use the `newPatternBoundSymbol` to create the symbol as this constructor will register the info into GADT bounds. Instead we use the plain `newSymbol`. Then in the second pass, when we have updated the infos, we register the symbol into GADT bounds. Note that the code in the added test does compiles correctly, but it had the inconsistent bounds. This test is added in case we need to manually inspect the bounds latter. This test does fail to compile in #17935 if this fix is not applied.
1 parent 229dc12 commit 82022b8

File tree

2 files changed

+64
-9
lines changed

2 files changed

+64
-9
lines changed

compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,7 @@ trait QuotesAndSplices {
202202
private def splitQuotePattern(quoted: Tree)(using Context): (collection.Map[Symbol, Bind], Tree, List[Tree]) = {
203203
val ctx0 = ctx
204204

205-
val typeBindings: mutable.Map[Symbol, Bind] = mutable.LinkedHashMap.empty
206-
def getBinding(sym: Symbol): Bind =
207-
typeBindings.getOrElseUpdate(sym, {
208-
val bindingBounds = sym.info
209-
val bsym = newPatternBoundSymbol(sym.name.toString.stripPrefix("$").toTypeName, bindingBounds, quoted.span)
210-
Bind(bsym, untpd.Ident(nme.WILDCARD).withType(bindingBounds)).withSpan(quoted.span)
211-
})
205+
val bindSymMapping: collection.Map[Symbol, Bind] = unapplyBindingsMapping(quoted)
212206

213207
object splitter extends tpd.TreeMap {
214208
private var variance: Int = 1
@@ -288,7 +282,7 @@ trait QuotesAndSplices {
288282
report.error(IllegalVariableInPatternAlternative(tdef.symbol.name), tdef.srcPos)
289283
if variance == -1 then
290284
tdef.symbol.addAnnotation(Annotation(New(ref(defn.QuotedRuntimePatterns_fromAboveAnnot.typeRef)).withSpan(tdef.span)))
291-
val bindingType = getBinding(tdef.symbol).symbol.typeRef
285+
val bindingType = bindSymMapping(tdef.symbol).symbol.typeRef
292286
val bindingTypeTpe = AppliedType(defn.QuotedTypeClass.typeRef, bindingType :: Nil)
293287
val sym = newPatternBoundSymbol(nameOfSyntheticGiven, bindingTypeTpe, tdef.span, flags = ImplicitVal)(using ctx0)
294288
buff += Bind(sym, untpd.Ident(nme.WILDCARD).withType(bindingTypeTpe)).withSpan(tdef.span)
@@ -325,7 +319,56 @@ trait QuotesAndSplices {
325319
new TreeTypeMap(typeMap = typeMap).transform(shape1)
326320
}
327321

328-
(typeBindings, shape2, patterns)
322+
(bindSymMapping, shape2, patterns)
323+
}
324+
325+
/** For each type variable defined in the quote pattern we generate an equivalent
326+
* binding that will be as type variable in the encoded `unapply` of the quote pattern.
327+
*
328+
* @return Mapping from type variable symbols defined in the quote pattern into
329+
* type variable `Bind` definitions for the `unapply` of the quote pattern.
330+
* This mapping retains the original type variable definition order.
331+
*/
332+
private def unapplyBindingsMapping(quoted: Tree)(using Context): collection.Map[Symbol, Bind] = {
333+
// Collect all existing type variable bindings and create new symbols for them.
334+
// The old info is used, it may contain references to the old symbols.
335+
val (oldBindings, newBindings) = {
336+
val seen = mutable.Set.empty[Symbol]
337+
val oldBindingsBuffer = mutable.LinkedHashSet.empty[Symbol]
338+
val newBindingsBuffer = mutable.ListBuffer.empty[Symbol]
339+
340+
new tpd.TreeTraverser {
341+
def traverse(tree: Tree)(using Context): Unit = tree match {
342+
case _: SplicePattern =>
343+
case Select(pat: Bind, _) if tree.symbol.isTypeSplice =>
344+
val sym = tree.tpe.dealias.typeSymbol
345+
if sym.exists then registerNewBindSym(sym)
346+
case tdef: TypeDef =>
347+
if tdef.symbol.hasAnnotation(defn.QuotedRuntimePatterns_patternTypeAnnot) then
348+
registerNewBindSym(tdef.symbol)
349+
traverseChildren(tdef)
350+
case _ =>
351+
traverseChildren(tree)
352+
}
353+
private def registerNewBindSym(sym: Symbol): Unit =
354+
if !seen(sym) then
355+
seen += sym
356+
oldBindingsBuffer += sym
357+
newBindingsBuffer += newSymbol(ctx.owner, sym.name.toString.stripPrefix("$").toTypeName, Case | sym.flags, sym.info, coord = quoted.span)
358+
}.traverse(quoted)
359+
(oldBindingsBuffer.toList, newBindingsBuffer.toList)
360+
}
361+
362+
// Replace symbols in `mapping` in the infos of the new symbol and register GADT bounds.
363+
// GADT bounds need to be added after the info is updated to avoid references to the old symbols.
364+
val newBindingsRefs = newBindings.map(_.typeRef)
365+
for newBindings <- newBindings do
366+
newBindings.info = newBindings.info.subst(oldBindings.toList, newBindingsRefs)
367+
ctx.gadtState.addToConstraint(newBindings) // This must be performed after the info has been updated
368+
369+
// Map into Bind nodes retaining the original order
370+
val newBindingBinds = newBindings.map(newSym => Bind(newSym, untpd.Ident(nme.WILDCARD).withType(newSym.info)).withSpan(quoted.span))
371+
mutable.LinkedHashMap.from(oldBindings.lazyZip(newBindingBinds))
329372
}
330373

331374
/** Type a quote pattern `case '{ <quoted> } =>` qiven the a current prototype. Typing the pattern
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import quoted.*
2+
3+
def foo(using Quotes)(x: Expr[Int]) =
4+
x match
5+
case '{ type t; type u <: `t`; f[`t`, `u`] } =>
6+
case '{ type u <: `t`; type t; f[`t`, `u`] } =>
7+
case '{ type t; type u <: `t`; g[F[`t`, `u`]] } =>
8+
case '{ type u <: `t`; type t; g[F[`t`, `u`]] } =>
9+
10+
def f[T, U <: T] = ???
11+
def g[T] = ???
12+
type F[T, U <: T]

0 commit comments

Comments
 (0)