Skip to content

Commit 20174d7

Browse files
committed
Replace quoted type variables in signature of HOAS pattern result
To be able to construct the lambda returned by the HOAS pattern we need: first resolve the type variables and then use the result to construct the signature of the lambdas. To simplify this transformation, `QuoteMatcher` returns a `Seq[MatchResult]` instead of an untyped `Tuple` containing `Expr[?]`. The tuple is created once we have accumulated and processed all extracted values. Fixes #15165
1 parent dbdca17 commit 20174d7

File tree

8 files changed

+122
-34
lines changed

8 files changed

+122
-34
lines changed

compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ object QuoteMatcher {
109109
/** Sequence of matched expressions.
110110
* These expressions are part of the scrutinee and will be bound to the quote pattern term splices.
111111
*/
112-
type MatchingExprs = Seq[Expr[Any]]
112+
type MatchingExprs = Seq[MatchResult]
113113

114114
/** A map relating equivalent symbols from the scrutinee and the pattern
115115
* For example in
@@ -141,12 +141,13 @@ object QuoteMatcher {
141141
extension (scrutinee0: Tree)
142142

143143
/** Check that the trees match and return the contents from the pattern holes.
144-
* Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes.
144+
* Return a sequence containing all the contents in the holes.
145+
* If it does not match, continues to the `optional` with `None`.
145146
*
146147
* @param scrutinee The tree being matched
147148
* @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes.
148149
* @param `summon[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
149-
* @return `None` if it did not match or `Some(tup: MatchingExprs)` if it matched where `tup` contains the contents of the holes.
150+
* @return The sequence with the contents of the holes of the matched expression.
150151
*/
151152
private def =?= (pattern0: Tree)(using Env, Context): optional[MatchingExprs] =
152153

@@ -205,31 +206,12 @@ object QuoteMatcher {
205206
// Matches an open term and wraps it into a lambda that provides the free variables
206207
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)
207208
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) =>
208-
def hoasClosure = {
209-
val names: List[TermName] = args.map {
210-
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
211-
case arg => arg.symbol.name.asTermName
212-
}
213-
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
214-
val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe)
215-
val meth = newAnonFun(ctx.owner, methTpe)
216-
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
217-
val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap
218-
val body = new TreeMap {
219-
override def transform(tree: Tree)(using Context): Tree =
220-
tree match
221-
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
222-
case tree => super.transform(tree)
223-
}.transform(scrutinee)
224-
TreeOps(body).changeNonLocalOwners(meth)
225-
}
226-
Closure(meth, bodyFn)
227-
}
209+
val env = summon[Env]
228210
val capturedArgs = args.map(_.symbol)
229-
val captureEnv = summon[Env].filter((k, v) => !capturedArgs.contains(v))
211+
val captureEnv = env.filter((k, v) => !capturedArgs.contains(v))
230212
withEnv(captureEnv) {
231213
scrutinee match
232-
case ClosedPatternTerm(scrutinee) => matched(hoasClosure)
214+
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env)
233215
case _ => notMatched
234216
}
235217

@@ -453,16 +435,67 @@ object QuoteMatcher {
453435
accumulator.apply(Set.empty, term)
454436
}
455437

438+
enum MatchResult:
439+
/** Closed pattern extracted value
440+
* @param tree Scrutinee sub-tree that matched
441+
*/
442+
case ClosedTree(tree: Tree)
443+
/** HOAS pattern extracted value
444+
*
445+
* @param tree Scrutinee sub-tree that matched
446+
* @param patternTpe Type of the pattern hole (from the pattern)
447+
* @param args HOAS arguments (from the pattern)
448+
* @param env Mapping between scrutinee and pattern variables
449+
*/
450+
case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)
451+
452+
/** Return the expression that was extracted from a hole.
453+
*
454+
* If it was a closed expression it returns that expression. Otherwise,
455+
* if it is a HOAS pattern, the surrounding lambda is generated using
456+
* `mapTypeHoles` to create the signature of the lambda.
457+
*
458+
* This expression is assumed to be a valid expression in the given splice scope.
459+
*/
460+
def toExpr(mapTypeHoles: TypeMap, spliceScope: Scope)(using Context): Expr[Any] = this match
461+
case MatchResult.ClosedTree(tree) =>
462+
new ExprImpl(tree, spliceScope)
463+
case MatchResult.OpenTree(tree, patternTpe, args, env) =>
464+
val names: List[TermName] = args.map {
465+
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
466+
case arg => arg.symbol.name.asTermName
467+
}
468+
val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr))
469+
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
470+
val meth = newAnonFun(ctx.owner, methTpe)
471+
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
472+
val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap
473+
val body = new TreeMap {
474+
override def transform(tree: Tree)(using Context): Tree =
475+
tree match
476+
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
477+
case tree => super.transform(tree)
478+
}.transform(tree)
479+
TreeOps(body).changeNonLocalOwners(meth)
480+
}
481+
val hoasClosure = Closure(meth, bodyFn)
482+
new ExprImpl(hoasClosure, spliceScope)
483+
456484
private inline def notMatched: optional[MatchingExprs] =
457485
optional.break()
458486

459487
private inline def matched: MatchingExprs =
460488
Seq.empty
461489

462490
private inline def matched(tree: Tree)(using Context): MatchingExprs =
463-
Seq(new ExprImpl(tree, SpliceScope.getCurrent))
491+
Seq(MatchResult.ClosedTree(tree))
492+
493+
private def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): MatchingExprs =
494+
Seq(MatchResult.OpenTree(tree, patternTpe, args, env))
464495

465496
extension (self: MatchingExprs)
466-
private inline def &&& (that: MatchingExprs): MatchingExprs = self ++ that
497+
/** Concatenates the contents of two successful matchings */
498+
def &&& (that: MatchingExprs): MatchingExprs = self ++ that
499+
end extension
467500

468501
}

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3137,18 +3137,27 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
31373137
ctx1.gadtState.addToConstraint(typeHoles)
31383138
ctx1
31393139

3140-
val matchings = QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1)
3141-
31423140
// After matching and doing all subtype checks, we have to approximate all the type bindings
31433141
// that we have found, seal them in a quoted.Type and add them to the result
31443142
def typeHoleApproximation(sym: Symbol) =
31453143
val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot)
31463144
val fullBounds = ctx1.gadt.fullBounds(sym)
3147-
val tp = if fromAboveAnnot then fullBounds.hi else fullBounds.lo
3148-
reflect.TypeReprMethods.asType(tp)
3149-
matchings.map { tup =>
3150-
val results = typeHoles.map(typeHoleApproximation) ++ tup
3151-
Tuple.fromIArray(results.toArray.asInstanceOf[IArray[Object]])
3145+
if fromAboveAnnot then fullBounds.hi else fullBounds.lo
3146+
3147+
QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1).map { matchings =>
3148+
import QuoteMatcher.MatchResult.*
3149+
lazy val spliceScope = SpliceScope.getCurrent
3150+
val typeHoleApproximations = typeHoles.map(typeHoleApproximation)
3151+
val typeHoleMapping = Map(typeHoles.zip(typeHoleApproximations)*)
3152+
val typeHoleMap = new Types.TypeMap {
3153+
def apply(tp: Types.Type): Types.Type = tp match
3154+
case Types.TypeRef(Types.NoPrefix, _) => typeHoleMapping.getOrElse(tp.typeSymbol, tp)
3155+
case _ => mapOver(tp)
3156+
}
3157+
val matchedExprs = matchings.map(_.toExpr(typeHoleMap, spliceScope))
3158+
val matchedTypes = typeHoleApproximations.map(reflect.TypeReprMethods.asType)
3159+
val results = matchedTypes ++ matchedExprs
3160+
Tuple.fromIArray(IArray.unsafeFromArray(results.toArray))
31523161
}
31533162
}
31543163

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import scala.quoted.*
2+
3+
inline def valToFun[T](inline expr: T): T =
4+
${ impl('expr) }
5+
6+
def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
7+
expr match
8+
case '{ { val ident = ($a: α); $rest(ident): T } } =>
9+
'{ { (y: α) => $rest(y) }.apply(???) }

tests/pos-macros/i15165a/Test_2.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def test = valToFun {
2+
val a: Int = 1
3+
a + 1
4+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import scala.quoted.*
2+
3+
inline def valToFun[T](inline expr: T): T =
4+
${ impl('expr) }
5+
6+
def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
7+
expr match
8+
case '{ { val ident = ($a: α); $rest(ident): T } } =>
9+
'{
10+
{ (y: α) =>
11+
${
12+
val bound = '{ ${ rest }(y) }
13+
Expr.betaReduce(bound)
14+
}
15+
}.apply($a)
16+
}

tests/pos-macros/i15165b/Test_2.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def test = valToFun {
2+
val a: Int = 1
3+
a + 1
4+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import scala.quoted.*
2+
3+
inline def valToFun[T](inline expr: T): T =
4+
${ impl('expr) }
5+
6+
def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
7+
expr match
8+
case '{ type α; { val ident = ($a: `α`); $rest(ident): `α` & T } } =>
9+
'{ { (y: α) => $rest(y) }.apply(???) }

tests/pos-macros/i15165c/Test_2.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def test = valToFun {
2+
val a: Int = 1
3+
a + 1
4+
}

0 commit comments

Comments
 (0)