Skip to content

Commit a736592

Browse files
committed
Make the expandion of context bounds for poly types slightly more elegant
1 parent 5f0d4a7 commit a736592

File tree

3 files changed

+75
-33
lines changed

3 files changed

+75
-33
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+30-25
Original file line numberDiff line numberDiff line change
@@ -527,8 +527,7 @@ object desugar {
527527
makeContextualFunction(paramTpts, paramNames, tree, paramsErased).withSpan(tree.span)
528528

529529
if meth.hasAttachment(PolyFunctionApply) then
530-
meth.removeAttachment(PolyFunctionApply)
531-
// (kπ): deffer this until we can type the result?
530+
// meth.removeAttachment(PolyFunctionApply)
532531
if ctx.mode.is(Mode.Type) then
533532
cpy.DefDef(meth)(tpt = meth.tpt.withAttachment(PolyFunctionApply, params))
534533
else
@@ -1250,29 +1249,35 @@ object desugar {
12501249
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
12511250
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
12521251
*/
1253-
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
1254-
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
1255-
val paramFlags = fun match
1256-
case fun: FunctionWithMods =>
1257-
// TODO: make use of this in the desugaring when pureFuns is enabled.
1258-
// val isImpure = funFlags.is(Impure)
1259-
1260-
// Function flags to be propagated to each parameter in the desugared method type.
1261-
val givenFlag = fun.mods.flags.toTermFlags & Given
1262-
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1263-
case _ =>
1264-
vparamTypes.map(_ => EmptyFlags)
1265-
1266-
val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1267-
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
1268-
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1269-
}.toList
1270-
1271-
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1272-
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
1273-
.withFlags(Synthetic)
1274-
.withAttachment(PolyFunctionApply, List.empty)
1275-
)).withSpan(tree.span)
1252+
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = tree match
1253+
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) =>
1254+
val paramFlags = fun match
1255+
case fun: FunctionWithMods =>
1256+
// TODO: make use of this in the desugaring when pureFuns is enabled.
1257+
// val isImpure = funFlags.is(Impure)
1258+
1259+
// Function flags to be propagated to each parameter in the desugared method type.
1260+
val givenFlag = fun.mods.flags.toTermFlags & Given
1261+
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1262+
case _ =>
1263+
vparamTypes.map(_ => EmptyFlags)
1264+
1265+
val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1266+
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
1267+
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1268+
}.toList
1269+
1270+
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1271+
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
1272+
.withFlags(Synthetic)
1273+
.withAttachment(PolyFunctionApply, List.empty)
1274+
)).withSpan(tree.span)
1275+
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, res) =>
1276+
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1277+
DefDef(nme.apply, tparams :: Nil, res, EmptyTree)
1278+
.withFlags(Synthetic)
1279+
.withAttachment(PolyFunctionApply, List.empty)
1280+
)).withSpan(tree.span)
12761281
end makePolyFunctionType
12771282

12781283
/** Invent a name for an anonympus given of type or template `impl`. */

compiler/src/dotty/tools/dotc/typer/Typer.scala

+44-7
Original file line numberDiff line numberDiff line change
@@ -3598,14 +3598,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
35983598

35993599
private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = tpe.dealias match {
36003600
case tpe: MethodType =>
3601-
MethodType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3601+
tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
36023602
case tpe: PolyType =>
3603-
PolyType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3603+
tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
36043604
case tpe: RefinedType =>
3605-
// TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement
3606-
RefinedType(pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span))
3605+
tpe.derivedRefinedType(
3606+
pushDownDeferredEvidenceParams(tpe.parent, params, span),
3607+
tpe.refinedName,
3608+
pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)
3609+
)
36073610
case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
3608-
AppliedType(tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
3611+
tpe.derivedAppliedType(tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
36093612
case tpe =>
36103613
val paramNames = params.map(_.name)
36113614
val paramTpts = params.map(_.tpt)
@@ -3614,18 +3617,52 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
36143617
typed(ctxFunction).tpe
36153618
}
36163619

3617-
private def addDownDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = {
3620+
private def extractTopMethodTermParams(tpe: Type)(using Context): (List[TermName], List[Type]) = tpe match {
3621+
case tpe: MethodType =>
3622+
tpe.paramNames -> tpe.paramInfos
3623+
case tpe: RefinedType if defn.isFunctionType(tpe.parent) =>
3624+
extractTopMethodTermParams(tpe.refinedInfo)
3625+
case _ =>
3626+
Nil -> Nil
3627+
}
3628+
3629+
private def removeTopMethodTermParams(tpe: Type)(using Context): Type = tpe match {
3630+
case tpe: MethodType =>
3631+
tpe.resultType
3632+
case tpe: RefinedType if defn.isFunctionType(tpe.parent) =>
3633+
tpe.derivedRefinedType(tpe.parent, tpe.refinedName, removeTopMethodTermParams(tpe.refinedInfo))
3634+
case tpe: AppliedType if defn.isFunctionType(tpe) =>
3635+
tpe.args.last
3636+
case _ =>
3637+
tpe
3638+
}
3639+
3640+
private def healToPolyFunctionType(tree: Tree)(using Context): Tree = tree match {
3641+
case defdef: DefDef if defdef.name == nme.apply && defdef.paramss.forall(_.forall(_.symbol.flags.is(TypeParam))) && defdef.paramss.size == 1 =>
3642+
val (names, types) = extractTopMethodTermParams(defdef.tpt.tpe)
3643+
val newTpe = removeTopMethodTermParams(defdef.tpt.tpe)
3644+
val newParams = names.lazyZip(types).map((name, tpe) => SyntheticValDef(name, TypeTree(tpe), flags = SyntheticTermParam))
3645+
val newDefDef = cpy.DefDef(defdef)(paramss = defdef.paramss ++ List(newParams), tpt = untpd.TypeTree(newTpe))
3646+
val nestedCtx = ctx.fresh.setNewTyperState()
3647+
typed(newDefDef)(using nestedCtx)
3648+
case _ => tree
3649+
}
3650+
3651+
private def addDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = {
36183652
tree.getAttachment(desugar.PolyFunctionApply) match
36193653
case Some(params) if params.nonEmpty =>
36203654
tree.removeAttachment(desugar.PolyFunctionApply)
36213655
val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
36223656
TypeTree(tpe).withSpan(tree.span) -> tpe
3657+
// case Some(params) if params.isEmpty =>
3658+
// println(s"tree: $tree")
3659+
// healToPolyFunctionType(tree) -> pt
36233660
case _ => tree -> pt
36243661
}
36253662

36263663
/** Interpolate and simplify the type of the given tree. */
36273664
protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
3628-
val (tree1, pt1) = addDownDeferredEvidenceParams(tree, pt)
3665+
val (tree1, pt1) = addDeferredEvidenceParams(tree, pt)
36293666
if !tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
36303667
if !tree1.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
36313668
|| tree1.isDef // ... unless tree is a definition

tests/pos/contextbounds-for-poly-functions.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type CmpWeak[X] = X => Boolean
3232
type Comparer2Weak = [X: Ord] => X => CmpWeak[X]
3333
val less4_0: [X: Ord] => X => X => Boolean =
3434
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
35-
val less4: Comparer2Weak =
35+
val less4_1: Comparer2Weak =
3636
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
3737

3838
val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

0 commit comments

Comments
 (0)