@@ -3598,14 +3598,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
3598
3598
3599
3599
private def pushDownDeferredEvidenceParams (tpe : Type , params : List [untpd.ValDef ], span : Span )(using Context ): Type = tpe.dealias match {
3600
3600
case tpe : MethodType =>
3601
- MethodType (tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3601
+ tpe.derivedLambdaType (tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3602
3602
case tpe : PolyType =>
3603
- PolyType (tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3603
+ tpe.derivedLambdaType (tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3604
3604
case tpe : RefinedType =>
3605
- // TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement
3606
- RefinedType (pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span))
3605
+ tpe.derivedRefinedType(
3606
+ pushDownDeferredEvidenceParams(tpe.parent, params, span),
3607
+ tpe.refinedName,
3608
+ pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)
3609
+ )
3607
3610
case tpe @ AppliedType (tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
3608
- AppliedType ( tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
3611
+ tpe.derivedAppliedType( tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
3609
3612
case tpe =>
3610
3613
val paramNames = params.map(_.name)
3611
3614
val paramTpts = params.map(_.tpt)
@@ -3614,18 +3617,52 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
3614
3617
typed(ctxFunction).tpe
3615
3618
}
3616
3619
3617
- private def addDownDeferredEvidenceParams (tree : Tree , pt : Type )(using Context ): (Tree , Type ) = {
3620
+ private def extractTopMethodTermParams (tpe : Type )(using Context ): (List [TermName ], List [Type ]) = tpe match {
3621
+ case tpe : MethodType =>
3622
+ tpe.paramNames -> tpe.paramInfos
3623
+ case tpe : RefinedType if defn.isFunctionType(tpe.parent) =>
3624
+ extractTopMethodTermParams(tpe.refinedInfo)
3625
+ case _ =>
3626
+ Nil -> Nil
3627
+ }
3628
+
3629
+ private def removeTopMethodTermParams (tpe : Type )(using Context ): Type = tpe match {
3630
+ case tpe : MethodType =>
3631
+ tpe.resultType
3632
+ case tpe : RefinedType if defn.isFunctionType(tpe.parent) =>
3633
+ tpe.derivedRefinedType(tpe.parent, tpe.refinedName, removeTopMethodTermParams(tpe.refinedInfo))
3634
+ case tpe : AppliedType if defn.isFunctionType(tpe) =>
3635
+ tpe.args.last
3636
+ case _ =>
3637
+ tpe
3638
+ }
3639
+
3640
+ private def healToPolyFunctionType (tree : Tree )(using Context ): Tree = tree match {
3641
+ case defdef : DefDef if defdef.name == nme.apply && defdef.paramss.forall(_.forall(_.symbol.flags.is(TypeParam ))) && defdef.paramss.size == 1 =>
3642
+ val (names, types) = extractTopMethodTermParams(defdef.tpt.tpe)
3643
+ val newTpe = removeTopMethodTermParams(defdef.tpt.tpe)
3644
+ val newParams = names.lazyZip(types).map((name, tpe) => SyntheticValDef (name, TypeTree (tpe), flags = SyntheticTermParam ))
3645
+ val newDefDef = cpy.DefDef (defdef)(paramss = defdef.paramss ++ List (newParams), tpt = untpd.TypeTree (newTpe))
3646
+ val nestedCtx = ctx.fresh.setNewTyperState()
3647
+ typed(newDefDef)(using nestedCtx)
3648
+ case _ => tree
3649
+ }
3650
+
3651
+ private def addDeferredEvidenceParams (tree : Tree , pt : Type )(using Context ): (Tree , Type ) = {
3618
3652
tree.getAttachment(desugar.PolyFunctionApply ) match
3619
3653
case Some (params) if params.nonEmpty =>
3620
3654
tree.removeAttachment(desugar.PolyFunctionApply )
3621
3655
val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
3622
3656
TypeTree (tpe).withSpan(tree.span) -> tpe
3657
+ // case Some(params) if params.isEmpty =>
3658
+ // println(s"tree: $tree")
3659
+ // healToPolyFunctionType(tree) -> pt
3623
3660
case _ => tree -> pt
3624
3661
}
3625
3662
3626
3663
/** Interpolate and simplify the type of the given tree. */
3627
3664
protected def simplify (tree : Tree , pt : Type , locked : TypeVars )(using Context ): Tree =
3628
- val (tree1, pt1) = addDownDeferredEvidenceParams (tree, pt)
3665
+ val (tree1, pt1) = addDeferredEvidenceParams (tree, pt)
3629
3666
if ! tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
3630
3667
if ! tree1.tpe.widen.isInstanceOf [MethodOrPoly ] // wait with simplifying until method is fully applied
3631
3668
|| tree1.isDef // ... unless tree is a definition
0 commit comments