Skip to content

Commit 5f0d4a7

Browse files
committed
Add support for some type aliases, when expanding context bounds for poly functions
1 parent 5196efd commit 5f0d4a7

File tree

3 files changed

+64
-27
lines changed

3 files changed

+64
-27
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+17-9
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ object desugar {
5555
/** An attachment key to indicate that a DefDef is a poly function apply
5656
* method definition.
5757
*/
58-
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()
58+
val PolyFunctionApply: Property.Key[List[ValDef]] = Property.StickyKey()
5959

6060
/** What static check should be applied to a Match? */
6161
enum MatchCheck {
@@ -514,17 +514,25 @@ object desugar {
514514
case Nil =>
515515
params :: Nil
516516

517+
// TODO(kπ) is this enough? SHould this be a TreeTraverse-thing?
518+
def pushDownEvidenceParams(tree: Tree): Tree = tree match
519+
case Function(params, body) =>
520+
cpy.Function(tree)(params, pushDownEvidenceParams(body))
521+
case Block(stats, expr) =>
522+
cpy.Block(tree)(stats, pushDownEvidenceParams(expr))
523+
case tree =>
524+
val paramTpts = params.map(_.tpt)
525+
val paramNames = params.map(_.name)
526+
val paramsErased = params.map(_.mods.flags.is(Erased))
527+
makeContextualFunction(paramTpts, paramNames, tree, paramsErased).withSpan(tree.span)
528+
517529
if meth.hasAttachment(PolyFunctionApply) then
518530
meth.removeAttachment(PolyFunctionApply)
519-
val paramTpts = params.map(_.tpt)
520-
val paramNames = params.map(_.name)
521-
val paramsErased = params.map(_.mods.flags.is(Erased))
531+
// (kπ): deffer this until we can type the result?
522532
if ctx.mode.is(Mode.Type) then
523-
val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.tpt, paramsErased)
524-
cpy.DefDef(meth)(tpt = ctxFunction)
533+
cpy.DefDef(meth)(tpt = meth.tpt.withAttachment(PolyFunctionApply, params))
525534
else
526-
val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.rhs, paramsErased)
527-
cpy.DefDef(meth)(rhs = ctxFunction)
535+
cpy.DefDef(meth)(rhs = pushDownEvidenceParams(meth.rhs))
528536
else
529537
cpy.DefDef(meth)(paramss = recur(meth.paramss))
530538
end addEvidenceParams
@@ -1263,7 +1271,7 @@ object desugar {
12631271
RefinedTypeTree(ref(defn.PolyFunctionType), List(
12641272
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
12651273
.withFlags(Synthetic)
1266-
.withAttachment(PolyFunctionApply, ())
1274+
.withAttachment(PolyFunctionApply, List.empty)
12671275
)).withSpan(tree.span)
12681276
end makePolyFunctionType
12691277

compiler/src/dotty/tools/dotc/typer/Typer.scala

+41-15
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ import config.MigrationVersion
5353
import transform.CheckUnused.OriginalName
5454

5555
import scala.annotation.constructorOnly
56-
import dotty.tools.dotc.ast.desugar.PolyFunctionApply
5756

5857
object Typer {
5958

@@ -1951,7 +1950,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19511950
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
19521951
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
19531952
val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span)
1954-
defdef.putAttachment(PolyFunctionApply, ())
1953+
defdef.putAttachment(desugar.PolyFunctionApply, List.empty)
19551954
typed(desugared, pt)
19561955
else
19571956
val msg =
@@ -1960,7 +1959,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19601959
errorTree(EmptyTree, msg, tree.srcPos)
19611960
case _ =>
19621961
val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span)
1963-
defdef.putAttachment(PolyFunctionApply, ())
1962+
defdef.putAttachment(desugar.PolyFunctionApply, List.empty)
19641963
typed(desugared, pt)
19651964
end typedPolyFunctionValue
19661965

@@ -3588,30 +3587,57 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
35883587
case xtree => typedUnnamed(xtree)
35893588

35903589
val unsimplifiedType = result.tpe
3591-
simplify(result, pt, locked)
3592-
result.tpe.stripTypeVar match
3590+
val result1 = simplify(result, pt, locked)
3591+
result1.tpe.stripTypeVar match
35933592
case e: ErrorType if !unsimplifiedType.isErroneous => errorTree(xtree, e.msg, xtree.srcPos)
3594-
case _ => result
3593+
case _ => result1
35953594
catch case ex: TypeError =>
35963595
handleTypeError(ex)
35973596
}
35983597
}
35993598

3599+
private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = tpe.dealias match {
3600+
case tpe: MethodType =>
3601+
MethodType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3602+
case tpe: PolyType =>
3603+
PolyType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3604+
case tpe: RefinedType =>
3605+
// TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement
3606+
RefinedType(pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span))
3607+
case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
3608+
AppliedType(tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
3609+
case tpe =>
3610+
val paramNames = params.map(_.name)
3611+
val paramTpts = params.map(_.tpt)
3612+
val paramsErased = params.map(_.mods.flags.is(Erased))
3613+
val ctxFunction = desugar.makeContextualFunction(paramTpts, paramNames, untpd.TypedSplice(TypeTree(tpe.dealias)), paramsErased).withSpan(span)
3614+
typed(ctxFunction).tpe
3615+
}
3616+
3617+
private def addDownDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = {
3618+
tree.getAttachment(desugar.PolyFunctionApply) match
3619+
case Some(params) if params.nonEmpty =>
3620+
tree.removeAttachment(desugar.PolyFunctionApply)
3621+
val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
3622+
TypeTree(tpe).withSpan(tree.span) -> tpe
3623+
case _ => tree -> pt
3624+
}
3625+
36003626
/** Interpolate and simplify the type of the given tree. */
3601-
protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): tree.type =
3602-
if !tree.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
3603-
if !tree.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
3604-
|| tree.isDef // ... unless tree is a definition
3627+
protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
3628+
val (tree1, pt1) = addDownDeferredEvidenceParams(tree, pt)
3629+
if !tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
3630+
if !tree1.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
3631+
|| tree1.isDef // ... unless tree is a definition
36053632
then
3606-
interpolateTypeVars(tree, pt, locked)
3607-
val simplified = tree.tpe.simplified
3608-
if !MatchType.thatReducesUsingGadt(tree.tpe) then // needs a GADT cast. i15743
3633+
interpolateTypeVars(tree1, pt1, locked)
3634+
val simplified = tree1.tpe.simplified
3635+
if !MatchType.thatReducesUsingGadt(tree1.tpe) then // needs a GADT cast. i15743
36093636
tree.overwriteType(simplified)
3610-
tree
3637+
tree1
36113638

36123639
protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = {
36133640
val defn.FunctionOf(formals, _, true) = pt.dropDependentRefinement: @unchecked
3614-
println(i"make contextual function $tree / $pt")
36153641
val paramNamesOrNil = pt match
36163642
case RefinedType(_, _, rinfo: MethodType) => rinfo.paramNames
36173643
case _ => Nil

tests/pos/contextbounds-for-poly-functions.scala

+6-3
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@ val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
2828
// type Comparer2 = [X: Ord] => Cmp[X]
2929
// val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
3030

31-
// type CmpWeak[X] = (x: X, y: X) => Boolean
32-
// type Comparer2Weak = [X: Ord] => (x: X) => CmpWeak[X]
33-
// val less4: Comparer2Weak = [X: Ord] => (x: X) => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
31+
type CmpWeak[X] = X => Boolean
32+
type Comparer2Weak = [X: Ord] => X => CmpWeak[X]
33+
val less4_0: [X: Ord] => X => X => Boolean =
34+
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
35+
val less4: Comparer2Weak =
36+
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
3437

3538
val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
3639

0 commit comments

Comments
 (0)