Skip to content

Fix bounds of encoded type variables in quote patterns #17956

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 60 additions & 14 deletions compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import dotty.tools.dotc.util.Spans._
import dotty.tools.dotc.util.Stats.record
import dotty.tools.dotc.reporting.IllegalVariableInPatternAlternative
import scala.collection.mutable
import scala.collection.SeqMap

/** Type quotes `'{ ... }` and splices `${ ... }` */
trait QuotesAndSplices {
Expand Down Expand Up @@ -202,7 +203,7 @@ trait QuotesAndSplices {
* will return
* ```
* (
* Map(<t$giveni>: Symbol -> <t @ _>: Bind),
* Map(<t$giveni>: Symbol -> <t>: Symbol),
* <'{
* @scala.internal.Quoted.patternType type t
* scala.internal.Quoted.patternHole[List[t]]
Expand All @@ -211,16 +212,10 @@ trait QuotesAndSplices {
* )
* ```
*/
private def splitQuotePattern(quoted: Tree)(using Context): (collection.Map[Symbol, Bind], Tree, List[Tree]) = {
private def splitQuotePattern(quoted: Tree)(using Context): (SeqMap[Symbol, Symbol], Tree, List[Tree]) = {
val ctx0 = ctx

val typeBindings: mutable.Map[Symbol, Bind] = mutable.LinkedHashMap.empty
def getBinding(sym: Symbol): Bind =
typeBindings.getOrElseUpdate(sym, {
val bindingBounds = sym.info
val bsym = newPatternBoundSymbol(sym.name.toString.stripPrefix("$").toTypeName, bindingBounds, quoted.span)
Bind(bsym, untpd.Ident(nme.WILDCARD).withType(bindingBounds)).withSpan(quoted.span)
})
val bindSymMapping: SeqMap[Symbol, Symbol] = unapplyBindingsMapping(quoted)

object splitter extends tpd.TreeMap {
private var variance: Int = 1
Expand Down Expand Up @@ -300,7 +295,7 @@ trait QuotesAndSplices {
report.error(IllegalVariableInPatternAlternative(tdef.symbol.name), tdef.srcPos)
if variance == -1 then
tdef.symbol.addAnnotation(Annotation(New(ref(defn.QuotedRuntimePatterns_fromAboveAnnot.typeRef)).withSpan(tdef.span)))
val bindingType = getBinding(tdef.symbol).symbol.typeRef
val bindingType = bindSymMapping(tdef.symbol).typeRef
val bindingTypeTpe = AppliedType(defn.QuotedTypeClass.typeRef, bindingType :: Nil)
val sym = newPatternBoundSymbol(nameOfSyntheticGiven, bindingTypeTpe, tdef.span, flags = ImplicitVal)(using ctx0)
buff += Bind(sym, untpd.Ident(nme.WILDCARD).withType(bindingTypeTpe)).withSpan(tdef.span)
Expand Down Expand Up @@ -337,7 +332,55 @@ trait QuotesAndSplices {
new TreeTypeMap(typeMap = typeMap).transform(shape1)
}

(typeBindings, shape2, patterns)
(bindSymMapping, shape2, patterns)
}

/** For each type variable defined in the quote pattern we generate an equivalent
* binding that will be as type variable in the encoded `unapply` of the quote pattern.
*
* @return Mapping from type variable symbols defined in the quote pattern into
* type variable definitions for the `unapply` of the quote pattern.
* This mapping retains the original type variable definition order.
*/
private def unapplyBindingsMapping(quoted: Tree)(using Context): SeqMap[Symbol, Symbol] = {
// Collect all existing type variable bindings and create new symbols for them.
// The old info is used, it may contain references to the old symbols.
val (oldBindings, newBindings) = {
val seen = mutable.Set.empty[Symbol]
val oldBindingsBuffer = mutable.LinkedHashSet.empty[Symbol]
val newBindingsBuffer = mutable.ListBuffer.empty[Symbol]
Comment on lines +350 to +351
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we're returning a map from old bindings to new bindings without creating Bind nodes, I think the original idea of directly having a LinkedHashMap here makes more sense, sorry for the back and forth!


new tpd.TreeTraverser {
def traverse(tree: Tree)(using Context): Unit = tree match {
case _: SplicePattern =>
case Select(pat: Bind, _) if tree.symbol.isTypeSplice =>
val sym = tree.tpe.dealias.typeSymbol
if sym.exists then registerNewBindSym(sym)
case tdef: TypeDef =>
if tdef.symbol.hasAnnotation(defn.QuotedRuntimePatterns_patternTypeAnnot) then
registerNewBindSym(tdef.symbol)
traverseChildren(tdef)
case _ =>
traverseChildren(tree)
}
private def registerNewBindSym(sym: Symbol): Unit =
if !seen(sym) then
seen += sym
oldBindingsBuffer += sym
newBindingsBuffer += newSymbol(ctx.owner, sym.name.toString.stripPrefix("$").toTypeName, Case | sym.flags, sym.info, coord = quoted.span)
}.traverse(quoted)
(oldBindingsBuffer.toList, newBindingsBuffer.toList)
}

// Replace symbols in `mapping` in the infos of the new symbol and register GADT bounds.
// GADT bounds need to be added after the info is updated to avoid references to the old symbols.
val newBindingsRefs = newBindings.map(_.typeRef)
for newBindings <- newBindings do
newBindings.info = newBindings.info.subst(oldBindings.toList, newBindingsRefs)
ctx.gadtState.addToConstraint(newBindings) // This must be performed after the info has been updated

// Map into Bind nodes retaining the original order
mutable.LinkedHashMap.from(oldBindings.lazyZip(newBindings))
}

/** Type a quote pattern `case '{ <quoted> } =>` qiven the a current prototype. Typing the pattern
Expand Down Expand Up @@ -427,20 +470,23 @@ trait QuotesAndSplices {
else tpd.Block(typeTypeVariables, pattern)
}

val (typeBindings, shape, splices) = splitQuotePattern(quoted1)
val (bindSymMapping, shape, splices) = splitQuotePattern(quoted1)

class ReplaceBindings extends TypeMap() {
override def apply(tp: Type): Type = tp match {
case tp: TypeRef =>
val tp1 = if (tp.symbol.isTypeSplice) tp.dealias else tp
mapOver(typeBindings.get(tp1.typeSymbol).fold(tp)(_.symbol.typeRef))
mapOver(bindSymMapping.get(tp1.typeSymbol).fold(tp)(_.typeRef))
case tp => mapOver(tp)
}
}
val replaceBindings = new ReplaceBindings
val patType = defn.tupleType(splices.tpes.map(tpe => replaceBindings(tpe.widen)))

val typeBindingsTuple = tpd.hkNestedPairsTypeTree(typeBindings.values.toList)
val typeBinds = bindSymMapping.values.toList.map(sym =>
Bind(sym, untpd.Ident(nme.WILDCARD).withType(sym.info)).withSpan(quoted.span)
)
val typeBindingsTuple = tpd.hkNestedPairsTypeTree(typeBinds)

val replaceBindingsInTree = new TreeMap {
private var bindMap = Map.empty[Symbol, Symbol]
Expand Down
12 changes: 12 additions & 0 deletions tests/pos-macros/quote-pattern-type-variable-bounds.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import quoted.*

def foo(using Quotes)(x: Expr[Int]) =
x match
case '{ type t; type u <: `t`; f[`t`, `u`] } =>
case '{ type u <: `t`; type t; f[`t`, `u`] } =>
case '{ type t; type u <: `t`; g[F[`t`, `u`]] } =>
case '{ type u <: `t`; type t; g[F[`t`, `u`]] } =>

def f[T, U <: T] = ???
def g[T] = ???
type F[T, U <: T]