Skip to content

Commit ea2f748

Browse files
committed
Replace quoted type variables in signature of HOAS pattern result
To be able to construct the lambda returned by the HOAS pattern we need: first resolve the type variables and then use the result to construct the signature of the lambdas. To simplify this transformation, `QuoteMatcher` returns a `Seq[MatchResult]` instead of an untyped `Tuple` containing `Expr[?]`. The tuple is created once we have accumulated and processed all extracted values. Fixes #15165
1 parent d298a3b commit ea2f748

File tree

8 files changed

+105
-39
lines changed

8 files changed

+105
-39
lines changed

compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ object QuoteMatcher {
121121

122122
private def withEnv[T](env: Env)(body: Env ?=> T): T = body(using env)
123123

124-
def treeMatch(scrutineeTerm: Tree, patternTerm: Tree)(using Context): Option[Tuple] =
124+
def treeMatch(scrutineeTerm: Tree, patternTerm: Tree)(using Context): Option[Seq[MatchResult]] =
125125
given Env = Map.empty
126126
scrutineeTerm =?= patternTerm
127127

@@ -203,31 +203,12 @@ object QuoteMatcher {
203203
// Matches an open term and wraps it into a lambda that provides the free variables
204204
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)
205205
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) =>
206-
def hoasClosure = {
207-
val names: List[TermName] = args.map {
208-
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
209-
case arg => arg.symbol.name.asTermName
210-
}
211-
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
212-
val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe)
213-
val meth = newAnonFun(ctx.owner, methTpe)
214-
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
215-
val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap
216-
val body = new TreeMap {
217-
override def transform(tree: Tree)(using Context): Tree =
218-
tree match
219-
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
220-
case tree => super.transform(tree)
221-
}.transform(scrutinee)
222-
TreeOps(body).changeNonLocalOwners(meth)
223-
}
224-
Closure(meth, bodyFn)
225-
}
206+
val env = summon[Env]
226207
val capturedArgs = args.map(_.symbol)
227208
val captureEnv = summon[Env].filter((k, v) => !capturedArgs.contains(v))
228209
withEnv(captureEnv) {
229210
scrutinee match
230-
case ClosedPatternTerm(scrutinee) => matched(hoasClosure)
211+
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env)
231212
case _ => notMatched
232213
}
233214

@@ -452,20 +433,52 @@ object QuoteMatcher {
452433
accumulator.apply(Set.empty, term)
453434
}
454435

436+
enum MatchResult:
437+
case ClosedTree(tree: Tree)
438+
case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)
439+
440+
def toExpr(mapTypeHoles: TypeMap)(using Context): Expr[Any] = this match
441+
case MatchResult.ClosedTree(tree) =>
442+
new ExprImpl(tree, SpliceScope.getCurrent)
443+
case MatchResult.OpenTree(tree, patternTpe, args, env) =>
444+
def hoasClosure = {
445+
val names: List[TermName] = args.map {
446+
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
447+
case arg => arg.symbol.name.asTermName
448+
}
449+
val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr))
450+
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
451+
val meth = newAnonFun(ctx.owner, methTpe)
452+
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
453+
val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap
454+
val body = new TreeMap {
455+
override def transform(tree: Tree)(using Context): Tree =
456+
tree match
457+
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
458+
case tree => super.transform(tree)
459+
}.transform(tree)
460+
TreeOps(body).changeNonLocalOwners(meth)
461+
}
462+
Closure(meth, bodyFn)
463+
}
464+
new ExprImpl(hoasClosure, SpliceScope.getCurrent)
465+
455466
/** Result of matching a part of an expression */
456-
private type Matching = Option[Tuple]
467+
private type Matching = Option[Seq[MatchResult]]
457468

458469
private object Matching {
459470

460471
def notMatched: Matching = None
461472

462-
val matched: Matching = Some(Tuple())
473+
val matched: Matching = Some(Seq())
463474

464475
def matched(tree: Tree)(using Context): Matching =
465-
Some(Tuple1(new ExprImpl(tree, SpliceScope.getCurrent)))
476+
Some(Seq(MatchResult.ClosedTree(tree)))
477+
478+
def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): Matching =
479+
Some(Seq(MatchResult.OpenTree(tree, patternTpe, args, env)))
466480

467481
extension (self: Matching)
468-
def asOptionOfTuple: Option[Tuple] = self
469482

470483
/** Concatenates the contents of two successful matchings or return a `notMatched` */
471484
def &&& (that: => Matching): Matching = self match {

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3137,20 +3137,27 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
31373137
ctx1.gadtState.addToConstraint(typeHoles)
31383138
ctx1
31393139

3140-
val matchings = QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1)
3141-
3142-
if typeHoles.isEmpty then matchings
3143-
else {
3144-
// After matching and doing all subtype checks, we have to approximate all the type bindings
3145-
// that we have found, seal them in a quoted.Type and add them to the result
3146-
def typeHoleApproximation(sym: Symbol) =
3147-
val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot)
3148-
val fullBounds = ctx1.gadt.fullBounds(sym)
3149-
val tp = if fromAboveAnnot then fullBounds.hi else fullBounds.lo
3150-
reflect.TypeReprMethods.asType(tp)
3151-
matchings.map { tup =>
3152-
Tuple.fromIArray(typeHoles.map(typeHoleApproximation).toArray.asInstanceOf[IArray[Object]]) ++ tup
3140+
// After matching and doing all subtype checks, we have to approximate all the type bindings
3141+
// that we have found, seal them in a quoted.Type and add them to the result
3142+
def typeHoleApproximation(sym: Symbol) =
3143+
val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot)
3144+
val fullBounds = ctx1.gadt.fullBounds(sym)
3145+
if fromAboveAnnot then fullBounds.hi else fullBounds.lo
3146+
3147+
QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1).map { matchings =>
3148+
import QuoteMatcher.MatchResult.*
3149+
lazy val spliceScope = SpliceScope.getCurrent
3150+
val typeHoleApproximations = typeHoles.map(typeHoleApproximation)
3151+
val typeHoleMapping = Map(typeHoles.zip(typeHoleApproximations)*)
3152+
val typeHoleMap = new Types.TypeMap {
3153+
def apply(tp: Types.Type): Types.Type = tp match
3154+
case Types.TypeRef(Types.NoPrefix, _) => typeHoleMapping.getOrElse(tp.typeSymbol, tp)
3155+
case _ => mapOver(tp)
31533156
}
3157+
val matchedExprs = matchings.map(_.toExpr(typeHoleMap))
3158+
val matchedTypes = typeHoleApproximations.map(reflect.TypeReprMethods.asType)
3159+
val results = matchedTypes ++ matchedExprs
3160+
Tuple.fromIArray(results.toArray.asInstanceOf[IArray[Object]])
31543161
}
31553162
}
31563163

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import scala.quoted.*
2+
3+
inline def valToFun[T](inline expr: T): T =
4+
${ impl('expr) }
5+
6+
def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
7+
expr match
8+
case '{ { val ident = ($a: α); $rest(ident): T } } =>
9+
'{ { (y: α) => $rest(y) }.apply(???) }

tests/pos-macros/i15165a/Test_2.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def test = valToFun {
2+
val a: Int = 1
3+
a + 1
4+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import scala.quoted.*
2+
3+
inline def valToFun[T](inline expr: T): T =
4+
${ impl('expr) }
5+
6+
def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
7+
expr match
8+
case '{ { val ident = ($a: α); $rest(ident): T } } =>
9+
'{
10+
{ (y: α) =>
11+
${
12+
val bound = '{ ${ rest }(y) }
13+
Expr.betaReduce(bound)
14+
}
15+
}.apply($a)
16+
}

tests/pos-macros/i15165b/Test_2.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def test = valToFun {
2+
val a: Int = 1
3+
a + 1
4+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import scala.quoted.*
2+
3+
inline def valToFun[T](inline expr: T): T =
4+
${ impl('expr) }
5+
6+
def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
7+
expr match
8+
case '{ type α; { val ident = ($a: `α`); $rest(ident): `α` & T } } =>
9+
'{ { (y: α) => $rest(y) }.apply(???) }

tests/pos-macros/i15165c/Test_2.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def test = valToFun {
2+
val a: Int = 1
3+
a + 1
4+
}

0 commit comments

Comments
 (0)