Skip to content

Commit 6c5de07

Browse files
committed
Add quoted pattern type splices runtime
1 parent e0dbba6 commit 6c5de07

File tree

6 files changed

+186
-24
lines changed

6 files changed

+186
-24
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
3939

4040
def Context_withoutColors(self: Context): Context = ctx.fresh.setSetting(ctx.settings.color, "never")
4141

42+
def Context_GADT_setFreshGADTBounds(self: Context): Context =
43+
self.fresh.setFreshGADTBounds.addMode(Mode.GADTflexible)
44+
45+
def Context_GADT_addToConstraint(self: Context)(syms: List[Symbol]): Boolean =
46+
self.gadt.addToConstraint(syms)
47+
48+
def Context_GADT_approximation(self: Context)(sym: Symbol, fromBelow: Boolean): Type =
49+
self.gadt.approximation(sym, fromBelow)
50+
4251
//
4352
// REPORTING
4453
//

library/src-3.x/scala/internal/Quoted.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,6 @@ object Quoted {
2424
/** A splice of a name in a quoted pattern is desugared by wrapping getting this annotation */
2525
class patternBindHole extends Annotation
2626

27+
class patternType extends Annotation
28+
2729
}

library/src-3.x/scala/internal/quoted/Matcher.scala

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ object Matcher {
4747
}
4848

4949
/** Check that all trees match with =#= and concatenate the results with && */
50-
def (scrutinees: List[Tree]) =##= (patterns: List[Tree]) given Env: Matching =
50+
def (scrutinees: List[Tree]) =##= (patterns: List[Tree]) given Context, Env: Matching =
5151
matchLists(scrutinees, patterns)(_ =#= _)
5252

5353
/** Check that the trees match and return the contents from the pattern holes.
@@ -58,7 +58,17 @@ object Matcher {
5858
* @param `the[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
5959
* @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
6060
*/
61-
def (scrutinee: Tree) =#= (pattern: Tree) given Env: Matching = {
61+
def (scrutinee0: Tree) =#= (pattern0: Tree) given Context, Env: Matching = {
62+
63+
/** Normalieze the tree */
64+
def normalize(tree: Tree): Tree = tree match {
65+
case Block(Nil, expr) => normalize(expr)
66+
case Inlined(_, Nil, expr) => normalize(expr)
67+
case _ => tree
68+
}
69+
70+
val scrutinee = normalize(scrutinee0)
71+
val pattern = normalize(pattern0)
6272

6373
/** Check that both are `val` or both are `lazy val` or both are `var` **/
6474
def checkValFlags(): Boolean = {
@@ -80,14 +90,7 @@ object Matcher {
8090
def hasBindAnnotation(sym: Symbol) =
8191
sym.annots.exists { case Apply(Select(New(TypeIdent("patternBindHole")),"<init>"),List()) => true; case _ => true }
8292

83-
/** Normalieze the tree */
84-
def normalize(tree: Tree): Tree = tree match {
85-
case Block(Nil, expr) => normalize(expr)
86-
case Inlined(_, Nil, expr) => normalize(expr)
87-
case _ => tree
88-
}
89-
90-
(normalize(scrutinee), normalize(pattern)) match {
93+
(scrutinee, pattern) match {
9194

9295
// Match a scala.internal.Quoted.patternHole typed as a repeated argument and return the scrutinee tree
9396
case (IsTerm(scrutinee @ Typed(s, tpt1)), Typed(TypeApply(patternHole, tpt :: Nil), tpt2))
@@ -112,6 +115,9 @@ object Matcher {
112115
case (Typed(expr1, tpt1), Typed(expr2, tpt2)) =>
113116
expr1 =#= expr2 && tpt1 =#= tpt2
114117

118+
case (scrutinee, Typed(expr2, _)) =>
119+
scrutinee =#= expr2
120+
115121
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || the[Env].apply((scrutinee.symbol, pattern.symbol)) =>
116122
matched
117123

@@ -144,9 +150,6 @@ object Matcher {
144150
case (While(cond1, body1), While(cond2, body2)) =>
145151
cond1 =#= cond2 && body1 =#= body2
146152

147-
case (NamedArg(name1, expr1), NamedArg(name2, expr2)) if name1 == name2 =>
148-
expr1 =#= expr2
149-
150153
case (New(tpt1), New(tpt2)) =>
151154
tpt1 =#= tpt2
152155

@@ -159,10 +162,11 @@ object Matcher {
159162
case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size =>
160163
elems1 =##= elems2
161164

165+
// TODO is this case required
162166
case (IsTypeTree(scrutinee @ TypeIdent(_)), IsTypeTree(pattern @ TypeIdent(_))) if scrutinee.symbol == pattern.symbol =>
163167
matched
164168

165-
case (IsInferred(scrutinee), IsInferred(pattern)) if scrutinee.tpe <:< pattern.tpe =>
169+
case (IsTypeTree(scrutinee), IsTypeTree(pattern)) if scrutinee.tpe <:< pattern.tpe =>
166170
matched
167171

168172
case (Applied(tycon1, args1), Applied(tycon2, args2)) =>
@@ -173,7 +177,7 @@ object Matcher {
173177
if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol)
174178
else matched
175179
def rhsEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
176-
bindMatch && tpt1 =#= tpt2 && (treeOptMatches(rhs1, rhs2) given rhsEnv)
180+
bindMatch && tpt1 =#= tpt2 && (treeOptMatches(rhs1, rhs2) given (the[Context], rhsEnv))
177181

178182
case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
179183
val bindMatch =
@@ -229,15 +233,15 @@ object Matcher {
229233
}
230234
}
231235

232-
def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Matching = {
236+
def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Context, Env: Matching = {
233237
(scrutinee, pattern) match {
234238
case (Some(x), Some(y)) => x =#= y
235239
case (None, None) => matched
236240
case _ => notMatched
237241
}
238242
}
239243

240-
def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Matching = {
244+
def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Context, Env: Matching = {
241245
val (caseEnv, patternMatch) = scrutinee.pattern =%= pattern.pattern
242246
withEnv(caseEnv) {
243247
patternMatch &&
@@ -256,7 +260,7 @@ object Matcher {
256260
* @return The new environment containing the bindings defined in this pattern tuppled with
257261
* `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
258262
*/
259-
def (scrutinee: Pattern) =%= (pattern: Pattern) given Env: (Env, Matching) = (scrutinee, pattern) match {
263+
def (scrutinee: Pattern) =%= (pattern: Pattern) given Context, Env: (Env, Matching) = (scrutinee, pattern) match {
260264
case (Pattern.Value(v1), Pattern.Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil))
261265
if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" =>
262266
(the[Env], matched(v1.seal))
@@ -266,7 +270,7 @@ object Matcher {
266270

267271
case (Pattern.Bind(name1, body1), Pattern.Bind(name2, body2)) =>
268272
val bindEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
269-
(body1 =%= body2) given bindEnv
273+
(body1 =%= body2) given (the[Context], bindEnv)
270274

271275
case (Pattern.Unapply(fun1, implicits1, patterns1), Pattern.Unapply(fun2, implicits2, patterns2)) =>
272276
val (patEnv, patternsMatch) = foldPatterns(patterns1, patterns2)
@@ -302,16 +306,33 @@ object Matcher {
302306
(the[Env], notMatched)
303307
}
304308

305-
def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Matching) = {
309+
def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Context, Env: (Env, Matching) = {
306310
if (patterns1.size != patterns2.size) (the[Env], notMatched)
307311
else patterns1.zip(patterns2).foldLeft((the[Env], matched)) { (acc, x) =>
308-
val (env, res) = (x._1 =%= x._2) given acc._1
312+
val (env, res) = (x._1 =%= x._2) given (the[Context], acc._1)
309313
(env, acc._2 && res)
310314
}
311315
}
312316

317+
def isTypeBinding(tree: Tree): Boolean = tree match {
318+
case IsTypeDef(tree) =>
319+
tree.symbol.annots.exists(_.symbol.owner.fullName == "scala.internal.Quoted$.patternType")
320+
case _ => false
321+
}
322+
313323
implied for Env = Set.empty
314-
(scrutineeExpr.unseal =#= patternExpr.unseal).asOptionOfTuple.asInstanceOf[Option[Tup]]
324+
val res = patternExpr.unseal.underlyingArgument match {
325+
case Block(typeBindings, pattern) if typeBindings.forall(isTypeBinding) =>
326+
implicit val ctx2 = reflection.kernel.Context_GADT_setFreshGADTBounds(rootContext)
327+
val bindingSymbols = typeBindings.map(_.symbol(ctx2))
328+
reflection.kernel.Context_GADT_addToConstraint(ctx2)(bindingSymbols)
329+
val matchings = scrutineeExpr.unseal.underlyingArgument =#= pattern
330+
val constainedTypes = bindingSymbols.map(s => reflection.kernel.Context_GADT_approximation(ctx2)(s, true))
331+
constainedTypes.foldRight(matchings)((x, acc) => matched(x.seal) && acc)
332+
case pattern =>
333+
scrutineeExpr.unseal.underlyingArgument =#= pattern
334+
}
335+
res.asOptionOfTuple.asInstanceOf[Option[Tup]]
315336
}
316337

317338
/** Result of matching a part of an expression */

library/src/scala/tasty/reflect/Kernel.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ trait Kernel {
149149
/** Returns a new context where printColors is false */
150150
def Context_withoutColors(self: Context): Context
151151

152+
def Context_GADT_setFreshGADTBounds(self: Context): Context
153+
def Context_GADT_addToConstraint(self: Context)(syms: List[Symbol]): Boolean
154+
def Context_GADT_approximation(self: Context)(sym: Symbol, fromBelow: Boolean): Type
152155

153156
//
154157
// REPORTING

tests/run-macros/quote-matcher-runtime.check

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Result: Some(List())
1616

1717
Scrutinee: 1
1818
Pattern: (1: scala.Int)
19-
Result: None
19+
Result: Some(List())
2020

2121
Scrutinee: 3
2222
Pattern: scala.internal.Quoted.patternHole[scala.Int]
@@ -714,3 +714,118 @@ Pattern: try scala.internal.Quoted.patternHole[scala.Int] finally {
714714
}
715715
Result: Some(List(Expr(1), Expr(2)))
716716

717+
Scrutinee: scala.List.apply[scala.Int]((1, 2, 3: scala.<repeated>[scala.Int])).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x)))
718+
Pattern: {
719+
@scala.internal.Quoted.patternType type T
720+
scala.internal.Quoted.patternHole[scala.List[scala.Int]].foreach[T](scala.internal.Quoted.patternHole[scala.Function1[scala.Int, T]])
721+
}
722+
Result: Some(List(Type(scala.Unit), Expr(scala.List.apply[scala.Int]((1, 2, 3: scala.<repeated>[scala.Int]))), Expr(((x: scala.Int) => scala.Predef.println(x)))))
723+
724+
Scrutinee: scala.List.apply[scala.Int]((1, 2, 3: scala.<repeated>[scala.Int])).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x)))
725+
Pattern: {
726+
@scala.internal.Quoted.patternType type T = scala.Unit
727+
scala.internal.Quoted.patternHole[scala.List[scala.Int]].foreach[T](scala.internal.Quoted.patternHole[scala.Function1[scala.Int, T]])
728+
}
729+
Result: Some(List(Type(scala.Unit), Expr(scala.List.apply[scala.Int]((1, 2, 3: scala.<repeated>[scala.Int]))), Expr(((x: scala.Int) => scala.Predef.println(x)))))
730+
731+
Scrutinee: scala.List.apply[scala.Int]((1, 2, 3: scala.<repeated>[scala.Int])).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x)))
732+
Pattern: {
733+
@scala.internal.Quoted.patternType type T <: scala.Predef.String
734+
scala.internal.Quoted.patternHole[scala.List[scala.Int]].foreach[T](scala.internal.Quoted.patternHole[scala.Function1[scala.Int, T]])
735+
}
736+
Result: None
737+
738+
Scrutinee: {
739+
val a: scala.Int = 4
740+
val b: scala.Int = 4
741+
()
742+
}
743+
Pattern: {
744+
@scala.internal.Quoted.patternType type T
745+
val a: T = scala.internal.Quoted.patternHole[T]
746+
val b: T = scala.internal.Quoted.patternHole[T]
747+
()
748+
}
749+
Result: Some(List(Type(scala.Int), Expr(4), Expr(4)))
750+
751+
Scrutinee: {
752+
val a: scala.Int = 4
753+
val b: scala.Int = 5
754+
()
755+
}
756+
Pattern: {
757+
@scala.internal.Quoted.patternType type T
758+
val a: T = scala.internal.Quoted.patternHole[T]
759+
val b: T = scala.internal.Quoted.patternHole[T]
760+
()
761+
}
762+
Result: Some(List(Type(scala.Int), Expr(4), Expr(5)))
763+
764+
Scrutinee: {
765+
val a: scala.Int = 4
766+
val b: scala.Predef.String = "x"
767+
()
768+
}
769+
Pattern: {
770+
@scala.internal.Quoted.patternType type T
771+
val a: T = scala.internal.Quoted.patternHole[T]
772+
val b: T = scala.internal.Quoted.patternHole[T]
773+
()
774+
}
775+
Result: Some(List(Type(scala.Int | java.lang.String), Expr(4), Expr("x")))
776+
777+
Scrutinee: {
778+
val a: scala.Int = 4
779+
val b: scala.Predef.String = "x"
780+
()
781+
}
782+
Pattern: {
783+
@scala.internal.Quoted.patternType type T <: scala.Int
784+
val a: T = scala.internal.Quoted.patternHole[T]
785+
val b: T = scala.internal.Quoted.patternHole[T]
786+
()
787+
}
788+
Result: None
789+
790+
Scrutinee: scala.List.apply[scala.Int]((1, 2, 3: scala.<repeated>[scala.Int])).map[scala.Double, scala.collection.immutable.List[scala.Double]](((x: scala.Int) => x.toDouble./(2)))(scala.collection.immutable.List.canBuildFrom[scala.Double]).map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((y: scala.Double) => y.toString()))(scala.collection.immutable.List.canBuildFrom[java.lang.String])
791+
Pattern: {
792+
@scala.internal.Quoted.patternType type T
793+
@scala.internal.Quoted.patternType type U
794+
@scala.internal.Quoted.patternType type V
795+
796+
(scala.internal.Quoted.patternHole[scala.List[T]].map[U, scala.collection.immutable.List[U]](scala.internal.Quoted.patternHole[scala.Function1[T, U]])(scala.collection.immutable.List.canBuildFrom[U]).map[V, scala.collection.immutable.List[V]](scala.internal.Quoted.patternHole[scala.Function1[U, V]])(scala.collection.immutable.List.canBuildFrom[V]): scala.collection.immutable.List[scala.Any])
797+
}
798+
Result: Some(List(Type(scala.Int), Type(scala.Double), Type(java.lang.String), Expr(scala.List.apply[scala.Int]((1, 2, 3: scala.<repeated>[scala.Int]))), Expr(((x: scala.Int) => x.toDouble./(2))), Expr(((y: scala.Double) => y.toString()))))
799+
800+
Scrutinee: ((x: scala.Int) => x)
801+
Pattern: {
802+
@scala.internal.Quoted.patternType type T
803+
804+
(scala.internal.Quoted.patternHole[scala.Function1[T, T]]: scala.Function1[scala.Nothing, scala.Any])
805+
}
806+
Result: Some(List(Type(scala.Int), Expr(((x: scala.Int) => x))))
807+
808+
Scrutinee: ((x: scala.Int) => x.toString())
809+
Pattern: {
810+
@scala.internal.Quoted.patternType type T
811+
812+
(scala.internal.Quoted.patternHole[scala.Function1[T, T]]: scala.Function1[scala.Nothing, scala.Any])
813+
}
814+
Result: None
815+
816+
Scrutinee: ((x: scala.Any) => scala.Predef.???)
817+
Pattern: {
818+
@scala.internal.Quoted.patternType type T
819+
820+
(scala.internal.Quoted.patternHole[scala.Function1[T, T]]: scala.Function1[scala.Nothing, scala.Any])
821+
}
822+
Result: Some(List(Type(scala.Nothing), Expr(((x: scala.Any) => scala.Predef.???))))
823+
824+
Scrutinee: ((x: scala.Nothing) => (1: scala.Any))
825+
Pattern: {
826+
@scala.internal.Quoted.patternType type T
827+
828+
(scala.internal.Quoted.patternHole[scala.Function1[T, T]]: scala.Function1[scala.Nothing, scala.Any])
829+
}
830+
Result: None
831+

tests/run-macros/quote-matcher-runtime/quoted_2.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import Macros._
33

44
import scala.internal.quoted.Matcher._
55

6-
import scala.internal.Quoted.{patternHole, patternBindHole}
6+
import scala.internal.Quoted.{patternHole, patternBindHole, patternType}
77

88
object Test {
99

@@ -134,6 +134,18 @@ object Test {
134134
matches(try 1 finally 2, try 1 finally 2)
135135
matches(try 1 catch { case _ => 2 }, try patternHole[Int] catch { case _ => patternHole[Int] })
136136
matches(try 1 finally 2, try patternHole[Int] finally patternHole[Int])
137+
matches(List(1, 2, 3).foreach(x => println(x)), { @patternType type T; patternHole[List[Int]].foreach[T](patternHole[Int => T]) })
138+
matches(List(1, 2, 3).foreach(x => println(x)), { @patternType type T = Unit; patternHole[List[Int]].foreach[T](patternHole[Int => T]) })
139+
matches(List(1, 2, 3).foreach(x => println(x)), { @patternType type T <: String; patternHole[List[Int]].foreach[T](patternHole[Int => T]) })
140+
matches({ val a: Int = 4; val b: Int = 4 }, { @patternType type T; { val a: T = patternHole[T]; val b: T = patternHole[T] } })
141+
matches({ val a: Int = 4; val b: Int = 5 }, { @patternType type T; { val a: T = patternHole[T]; val b: T = patternHole[T] } })
142+
matches({ val a: Int = 4; val b: String = "x" }, { @patternType type T; { val a: T = patternHole[T]; val b: T = patternHole[T] } })
143+
matches({ val a: Int = 4; val b: String = "x" }, { @patternType type T <: Int; { val a: T = patternHole[T]; val b: T = patternHole[T] } })
144+
matches(List(1, 2, 3).map(x => x.toDouble / 2).map(y => y.toString), { @patternType type T; @patternType type U; @patternType type V; patternHole[List[T]].map(patternHole[T => U]).map(patternHole[U => V]) })
145+
matches((x: Int) => x, { @patternType type T; patternHole[T => T] })
146+
matches((x: Int) => x.toString, { @patternType type T; patternHole[T => T] })
147+
matches((x: Any) => ???, { @patternType type T; patternHole[T => T] })
148+
matches((x: Nothing) => (1 : Any), { @patternType type T; patternHole[T => T] })
137149

138150
}
139151
}

0 commit comments

Comments
 (0)