Skip to content

Commit 35ed5dd

Browse files
committed
generate defs for given patterns
1 parent a956774 commit 35ed5dd

File tree

6 files changed

+113
-64
lines changed

6 files changed

+113
-64
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+80-53
Original file line numberDiff line numberDiff line change
@@ -1087,62 +1087,89 @@ object desugar {
10871087
* If the original pattern variable carries a type annotation, so does the corresponding
10881088
* ValDef or DefDef.
10891089
*/
1090-
def makePatDef(original: Tree, mods: Modifiers, pat: Tree, rhs: Tree)(using Context): Tree = pat match {
1091-
case IdPattern(named, tpt) =>
1092-
derivedValDef(original, named, tpt, rhs, mods)
1093-
case _ =>
1094-
def isTuplePattern(arity: Int): Boolean = pat match {
1095-
case Tuple(pats) if pats.size == arity =>
1096-
pats.forall(isVarPattern)
1097-
case _ => false
1098-
}
1099-
val isMatchingTuple: Tree => Boolean = {
1100-
case Tuple(es) => isTuplePattern(es.length)
1101-
case _ => false
1102-
}
1090+
def makePatDef(original: Tree, mods: Modifiers, pat: Tree, rhs: Tree)(using Context): Tree =
11031091

1104-
// We can only optimize `val pat = if (...) e1 else e2` if:
1105-
// - `e1` and `e2` are both tuples of arity N
1106-
// - `pat` is a tuple of N variables or wildcard patterns like `(x1, x2, ..., xN)`
1107-
val tupleOptimizable = forallResults(rhs, isMatchingTuple)
1092+
def singleGivenPattern(named: Bind) =
1093+
val Bind(_, Typed(_, tpt)) = named: @unchecked
1094+
report.error(
1095+
em"""please use an alias given, such as
1096+
|${hl("given")} ${tpt} = $rhs""".stripMargin, original)
1097+
derivedValDef(original, Ident(named.name).withSpan(pat.span), tpt, rhs, mods | Given | Lazy | Final)
1098+
1099+
pat match {
1100+
case IdPattern(named, tpt) =>
1101+
derivedValDef(original, named, tpt, rhs, mods)
1102+
case Parens(named: Bind) if named.mods.is(Given) =>
1103+
singleGivenPattern(named)
1104+
case named: Bind if named.mods.is(Given) =>
1105+
singleGivenPattern(named)
1106+
case _ =>
1107+
def isTuplePattern(arity: Int): Boolean = pat match {
1108+
case Tuple(pats) if pats.size == arity =>
1109+
pats.forall(pat => isVarPattern(pat) || isGivenPattern(pat))
1110+
case _ => false
1111+
}
1112+
val isMatchingTuple: Tree => Boolean = {
1113+
case Tuple(es) => isTuplePattern(es.length)
1114+
case _ => false
1115+
}
11081116

1109-
val vars =
1110-
if (tupleOptimizable) // include `_`
1111-
pat match {
1112-
case Tuple(pats) =>
1113-
pats.map { case id: Ident => id -> TypeTree() }
1117+
// We can only optimize `val pat = if (...) e1 else e2` if:
1118+
// - `e1` and `e2` are both tuples of arity N
1119+
// - `pat` is a tuple of N variables or wildcard patterns like `(x1, x2, ..., xN)`
1120+
val tupleOptimizable = forallResults(rhs, isMatchingTuple)
1121+
1122+
val vars =
1123+
if (tupleOptimizable) // include `_`
1124+
pat match {
1125+
case Tuple(pats) =>
1126+
pats.map {
1127+
case id: Ident => id -> TypeTree()
1128+
case bind @ Bind(_, Typed(_, tpt)) => bind -> tpt
1129+
}
1130+
}
1131+
else
1132+
getVariables(pat) // no `_`
1133+
1134+
val ids = for ((named, _) <- vars) yield Ident(named.name)
1135+
val caseDef = CaseDef(pat, EmptyTree, makeTuple(ids))
1136+
val matchExpr =
1137+
if (tupleOptimizable) rhs
1138+
else Match(makeSelector(rhs, MatchCheck.IrrefutablePatDef), caseDef :: Nil)
1139+
1140+
def defMods(named: NameTree, mods: Modifiers) =
1141+
named match {
1142+
case named: Bind if named.mods.is(Given) => mods | Given
1143+
case _ => mods
11141144
}
1115-
else getVariables(pat) // no `_`
1116-
1117-
val ids = for ((named, _) <- vars) yield Ident(named.name)
1118-
val caseDef = CaseDef(pat, EmptyTree, makeTuple(ids))
1119-
val matchExpr =
1120-
if (tupleOptimizable) rhs
1121-
else Match(makeSelector(rhs, MatchCheck.IrrefutablePatDef), caseDef :: Nil)
1122-
vars match {
1123-
case Nil if !mods.is(Lazy) =>
1124-
matchExpr
1125-
case (named, tpt) :: Nil =>
1126-
derivedValDef(original, named, tpt, matchExpr, mods)
1127-
case _ =>
1128-
val tmpName = UniqueName.fresh()
1129-
val patMods =
1130-
mods & Lazy | Synthetic | (if (ctx.owner.isClass) PrivateLocal else EmptyFlags)
1131-
val firstDef =
1132-
ValDef(tmpName, TypeTree(), matchExpr)
1133-
.withSpan(pat.span.union(rhs.span)).withMods(patMods)
1134-
val useSelectors = vars.length <= 22
1135-
def selector(n: Int) =
1136-
if useSelectors then Select(Ident(tmpName), nme.selectorName(n))
1137-
else Apply(Select(Ident(tmpName), nme.apply), Literal(Constant(n)) :: Nil)
1138-
val restDefs =
1139-
for (((named, tpt), n) <- vars.zipWithIndex if named.name != nme.WILDCARD)
1140-
yield
1141-
if (mods.is(Lazy)) derivedDefDef(original, named, tpt, selector(n), mods &~ Lazy)
1142-
else derivedValDef(original, named, tpt, selector(n), mods)
1143-
flatTree(firstDef :: restDefs)
1144-
}
1145-
}
1145+
1146+
vars match {
1147+
case Nil if !mods.is(Lazy) =>
1148+
matchExpr
1149+
case (named, tpt) :: Nil =>
1150+
derivedValDef(original, named, tpt, matchExpr, defMods(named, mods))
1151+
case _ =>
1152+
val tmpName = UniqueName.fresh()
1153+
val patMods =
1154+
mods & Lazy | Synthetic | (if (ctx.owner.isClass) PrivateLocal else EmptyFlags)
1155+
val firstDef =
1156+
ValDef(tmpName, TypeTree(), matchExpr)
1157+
.withSpan(pat.span.union(rhs.span)).withMods(patMods)
1158+
val useSelectors = vars.length <= 22
1159+
def selector(n: Int) =
1160+
if useSelectors then Select(Ident(tmpName), nme.selectorName(n))
1161+
else Apply(Select(Ident(tmpName), nme.apply), Literal(Constant(n)) :: Nil)
1162+
val restDefs =
1163+
for (((named, tpt), n) <- vars.zipWithIndex if named.name != nme.WILDCARD)
1164+
yield
1165+
val mods1 = defMods(named, mods)
1166+
if (mods1.is(Lazy)) derivedDefDef(original, named, tpt, selector(n), mods1 &~ Lazy)
1167+
else derivedValDef(original, named, tpt, selector(n), mods1)
1168+
flatTree(firstDef :: restDefs)
1169+
}
1170+
}
1171+
1172+
end makePatDef
11461173

11471174
/** Expand variable identifier x to x @ _ */
11481175
def patternVar(tree: Tree)(using Context): Bind = {

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

+5
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,11 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
164164
case _ => false
165165
}
166166

167+
def isGivenPattern(pat: Tree): Boolean = unsplice(pat) match {
168+
case x: Bind => x.mods.is(Given) && x.name.isVarPattern && !isBackquoted(x)
169+
case _ => false
170+
}
171+
167172
/** The first constructor definition in `stats` */
168173
def firstConstructor(stats: List[Tree]): Tree = stats match {
169174
case (meth: DefDef) :: _ if meth.name.isConstructorName => meth

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+4-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import Contexts._
1515
import Names._
1616
import NameKinds.WildcardParamName
1717
import NameOps._
18-
import ast.{Positioned, Trees}
18+
import ast.{Positioned, Trees, desugar}
1919
import ast.Trees._
2020
import typer.ImportInfo
2121
import StdNames._
@@ -2703,8 +2703,9 @@ object Parsers {
27032703
case GIVEN =>
27042704
atSpan(in.offset) {
27052705
val givenMod = atSpan(in.skipToken())(Mod.Given())
2706-
val typed = Typed(Ident(nme.WILDCARD), refinedType())
2707-
Bind(nme.WILDCARD, typed).withMods(addMod(Modifiers(), givenMod))
2706+
val tpt = refinedType()
2707+
val typed = Typed(Ident(nme.WILDCARD), tpt)
2708+
Bind(desugar.inventGivenOrExtensionName(tpt), typed).withMods(addMod(Modifiers(), givenMod))
27082709
}
27092710
case _ =>
27102711
if (isLiteral) literal(inPattern = true)

compiler/src/dotty/tools/dotc/typer/Typer.scala

+3-8
Original file line numberDiff line numberDiff line change
@@ -2006,12 +2006,7 @@ class Typer extends Namer
20062006
tpd.cpy.UnApply(body1)(fn, Nil,
20072007
typed(untpd.Bind(tree.name, untpd.TypedSplice(arg)).withSpan(tree.span), arg.tpe) :: Nil)
20082008
case _ =>
2009-
var name = tree.name
2010-
if (name == nme.WILDCARD && tree.mods.is(Given)) {
2011-
val Typed(_, tpt) = tree.body: @unchecked
2012-
name = desugar.inventGivenOrExtensionName(tpt)
2013-
}
2014-
if (name == nme.WILDCARD) body1
2009+
if (tree.name == nme.WILDCARD) body1
20152010
else {
20162011
// In `x @ Nil`, `Nil` is a _stable identifier pattern_ and will be compiled
20172012
// to an `==` test, so the type of `x` is unrelated to the type of `Nil`.
@@ -2037,11 +2032,11 @@ class Typer extends Namer
20372032
// See also #5649.
20382033
then body1.tpe
20392034
else pt & body1.tpe
2040-
val sym = newPatternBoundSymbol(name, symTp, tree.span)
2035+
val sym = newPatternBoundSymbol(tree.name, symTp, tree.span)
20412036
if (pt == defn.ImplicitScrutineeTypeRef || tree.mods.is(Given)) sym.setFlag(Given)
20422037
if (ctx.mode.is(Mode.InPatternAlternative))
20432038
report.error(i"Illegal variable ${sym.name} in pattern alternative", tree.srcPos)
2044-
assignType(cpy.Bind(tree)(name, body1), sym)
2039+
assignType(cpy.Bind(tree)(tree.name, body1), sym)
20452040
}
20462041
}
20472042
}

tests/neg/i11897.scala

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
case class B(b: Boolean)
2+
3+
def test =
4+
val (given B) = B(false) // error
5+
val given B = B(false) // error

tests/run/i11897.scala

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
case class A(i: Int)
2+
case class B(b: Boolean)
3+
case class C(s: String)
4+
case class D(c: C)
5+
case class E(i: Int)
6+
case class F(i: Int, e: E)
7+
8+
@main def Test =
9+
val (x, given A) = (1, A(23))
10+
val (_, given B) = (true, B(false))
11+
val D(given C) = D(C("c"))
12+
val F(y, given E) = F(47, E(93))
13+
assert(summon[A] == A(23))
14+
assert(summon[B] == B(false))
15+
assert(summon[C] == C("c"))
16+
assert(summon[E] == E(93))

0 commit comments

Comments
 (0)