Skip to content

Commit 059c539

Browse files
Backport "Context Bounds for Polymorphic Functions" to 3.6 (#21972)
Backports #21643 to the 3.6.2. PR submitted by the release tooling.
2 parents 8615aa2 + 386d83d commit 059c539

File tree

6 files changed

+274
-68
lines changed

6 files changed

+274
-68
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+139-63
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ object desugar {
5252
*/
5353
val ContextBoundParam: Property.Key[Unit] = Property.StickyKey()
5454

55+
/** Marks a poly fcuntion apply method, so that we can handle adding evidence parameters to them in a special way
56+
*/
57+
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()
58+
5559
/** What static check should be applied to a Match? */
5660
enum MatchCheck {
5761
case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom
@@ -242,7 +246,7 @@ object desugar {
242246
* def f$default$2[T](x: Int) = x + "m"
243247
*/
244248
private def defDef(meth: DefDef, isPrimaryConstructor: Boolean = false)(using Context): Tree =
245-
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor))
249+
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor).asInstanceOf[DefDef])
246250

247251
/** Drop context bounds in given TypeDef, replacing them with evidence ValDefs that
248252
* get added to a buffer.
@@ -304,10 +308,8 @@ object desugar {
304308
tdef1
305309
end desugarContextBounds
306310

307-
private def elimContextBounds(meth: DefDef, isPrimaryConstructor: Boolean)(using Context): DefDef =
308-
val DefDef(_, paramss, tpt, rhs) = meth
311+
def elimContextBounds(meth: Tree, isPrimaryConstructor: Boolean = false)(using Context): Tree =
309312
val evidenceParamBuf = mutable.ListBuffer[ValDef]()
310-
311313
var seenContextBounds: Int = 0
312314
def freshName(unused: Tree) =
313315
seenContextBounds += 1 // Start at 1 like FreshNameCreator.
@@ -317,7 +319,7 @@ object desugar {
317319
// parameters of the method since shadowing does not affect
318320
// implicit resolution in Scala 3.
319321

320-
val paramssNoContextBounds =
322+
def paramssNoContextBounds(paramss: List[ParamClause]): List[ParamClause] =
321323
val iflag = paramss.lastOption.flatMap(_.headOption) match
322324
case Some(param) if param.mods.isOneOf(GivenOrImplicit) =>
323325
param.mods.flags & GivenOrImplicit
@@ -329,15 +331,32 @@ object desugar {
329331
tparam => desugarContextBounds(tparam, evidenceParamBuf, flags, freshName, paramss)
330332
}(identity)
331333

332-
rhs match
333-
case MacroTree(call) =>
334-
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
335-
case _ =>
336-
addEvidenceParams(
337-
cpy.DefDef(meth)(
338-
name = normalizeName(meth, tpt).asTermName,
339-
paramss = paramssNoContextBounds),
340-
evidenceParamBuf.toList)
334+
meth match
335+
case meth @ DefDef(_, paramss, tpt, rhs) =>
336+
val newParamss = paramssNoContextBounds(paramss)
337+
rhs match
338+
case MacroTree(call) =>
339+
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
340+
case _ =>
341+
addEvidenceParams(
342+
cpy.DefDef(meth)(
343+
name = normalizeName(meth, tpt).asTermName,
344+
paramss = newParamss
345+
),
346+
evidenceParamBuf.toList
347+
)
348+
case meth @ PolyFunction(tparams, fun) =>
349+
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = meth: @unchecked
350+
val Function(vparams: List[untpd.ValDef] @unchecked, rhs) = fun: @unchecked
351+
val newParamss = paramssNoContextBounds(tparams :: vparams :: Nil)
352+
val params = evidenceParamBuf.toList
353+
if params.isEmpty then
354+
meth
355+
else
356+
val boundNames = getBoundNames(params, newParamss)
357+
val recur = fitEvidenceParams(params, nme.apply, boundNames)
358+
val (paramsFst, paramsSnd) = recur(newParamss)
359+
functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs)
341360
end elimContextBounds
342361

343362
def addDefaultGetters(meth: DefDef)(using Context): Tree =
@@ -465,6 +484,74 @@ object desugar {
465484
case _ =>
466485
(Nil, tree)
467486

487+
private def referencesName(vdef: ValDef, names: Set[TermName])(using Context): Boolean =
488+
vdef.tpt.existsSubTree:
489+
case Ident(name: TermName) => names.contains(name)
490+
case _ => false
491+
492+
/** Fit evidence `params` into the `mparamss` parameter lists, making sure
493+
* that all parameters referencing `params` are after them.
494+
* - for methods the final parameter lists are := result._1 ++ result._2
495+
* - for poly functions, each element of the pair contains at most one term
496+
* parameter list
497+
*
498+
* @param params the evidence parameters list that should fit into `mparamss`
499+
* @param methName the name of the method that `mparamss` belongs to
500+
* @param boundNames the names of the evidence parameters
501+
* @param mparamss the original parameter lists of the method
502+
* @return a pair of parameter lists containing all parameter lists in a
503+
* reference-correct order; make sure that `params` is always at the
504+
* intersection of the pair elements; this is relevant, for poly functions
505+
* where `mparamss` is guaranteed to have exectly one term parameter list,
506+
* then each pair element will have at most one term parameter list
507+
*/
508+
private def fitEvidenceParams(
509+
params: List[ValDef],
510+
methName: Name,
511+
boundNames: Set[TermName]
512+
)(mparamss: List[ParamClause])(using Context): (List[ParamClause], List[ParamClause]) = mparamss match
513+
case ValDefs(mparams) :: _ if mparams.exists(referencesName(_, boundNames)) =>
514+
(params :: Nil) -> mparamss
515+
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
516+
val normParams =
517+
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
518+
params.map: param =>
519+
val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit))
520+
param.withMods(param.mods.withFlags(normFlags))
521+
.showing(i"adapted param $result ${result.mods.flags} for ${methName}", Printers.desugar)
522+
else params
523+
((normParams ++ mparams) :: Nil) -> Nil
524+
case mparams :: mparamss1 =>
525+
val (fst, snd) = fitEvidenceParams(params, methName, boundNames)(mparamss1)
526+
(mparams :: fst) -> snd
527+
case Nil =>
528+
Nil -> (params :: Nil)
529+
530+
/** Create a chain of possibly contextual functions from the parameter lists */
531+
private def functionsOf(paramss: List[ParamClause], rhs: Tree)(using Context): Tree = paramss match
532+
case Nil => rhs
533+
case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) =>
534+
val paramTpts = head.map(_.tpt)
535+
val paramNames = head.map(_.name)
536+
val paramsErased = head.map(_.mods.flags.is(Erased))
537+
makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span)
538+
case ValDefs(head) :: rest =>
539+
Function(head, functionsOf(rest, rhs))
540+
case TypeDefs(head) :: rest =>
541+
PolyFunction(head, functionsOf(rest, rhs))
542+
case _ =>
543+
assert(false, i"unexpected paramss $paramss")
544+
EmptyTree
545+
546+
private def getBoundNames(params: List[ValDef], paramss: List[ParamClause])(using Context): Set[TermName] =
547+
var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
548+
for mparams <- paramss; mparam <- mparams do
549+
mparam match
550+
case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) =>
551+
boundNames += tparam.name.toTermName
552+
case _ =>
553+
boundNames
554+
468555
/** Add all evidence parameters in `params` as implicit parameters to `meth`.
469556
* The position of the added parameters is determined as follows:
470557
*
@@ -479,36 +566,23 @@ object desugar {
479566
private def addEvidenceParams(meth: DefDef, params: List[ValDef])(using Context): DefDef =
480567
if params.isEmpty then return meth
481568

482-
var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
483-
for mparams <- meth.paramss; mparam <- mparams do
484-
mparam match
485-
case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) =>
486-
boundNames += tparam.name.toTermName
487-
case _ =>
569+
val boundNames = getBoundNames(params, meth.paramss)
488570

489-
def referencesBoundName(vdef: ValDef): Boolean =
490-
vdef.tpt.existsSubTree:
491-
case Ident(name: TermName) => boundNames.contains(name)
492-
case _ => false
571+
val fitParams = fitEvidenceParams(params, meth.name, boundNames)
493572

494-
def recur(mparamss: List[ParamClause]): List[ParamClause] = mparamss match
495-
case ValDefs(mparams) :: _ if mparams.exists(referencesBoundName) =>
496-
params :: mparamss
497-
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
498-
val normParams =
499-
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
500-
params.map: param =>
501-
val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit))
502-
param.withMods(param.mods.withFlags(normFlags))
503-
.showing(i"adapted param $result ${result.mods.flags} for ${meth.name}", Printers.desugar)
504-
else params
505-
(normParams ++ mparams) :: Nil
506-
case mparams :: mparamss1 =>
507-
mparams :: recur(mparamss1)
508-
case Nil =>
509-
params :: Nil
510-
511-
cpy.DefDef(meth)(paramss = recur(meth.paramss))
573+
if meth.removeAttachment(PolyFunctionApply).isDefined then
574+
// for PolyFunctions we are limited to a single term param list, so we
575+
// reuse the fitEvidenceParams logic to compute the new parameter lists
576+
// and then we add the other parameter lists as function types to the
577+
// return type
578+
val (paramsFst, paramsSnd) = fitParams(meth.paramss)
579+
if ctx.mode.is(Mode.Type) then
580+
cpy.DefDef(meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt))
581+
else
582+
cpy.DefDef(meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs))
583+
else
584+
val (paramsFst, paramsSnd) = fitParams(meth.paramss)
585+
cpy.DefDef(meth)(paramss = paramsFst ++ paramsSnd)
512586
end addEvidenceParams
513587

514588
/** The parameters generated from the contextual bounds of `meth`, as generated by `desugar.defDef` */
@@ -1224,27 +1298,29 @@ object desugar {
12241298
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
12251299
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
12261300
*/
1227-
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
1228-
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
1229-
val paramFlags = fun match
1230-
case fun: FunctionWithMods =>
1231-
// TODO: make use of this in the desugaring when pureFuns is enabled.
1232-
// val isImpure = funFlags.is(Impure)
1233-
1234-
// Function flags to be propagated to each parameter in the desugared method type.
1235-
val givenFlag = fun.mods.flags.toTermFlags & Given
1236-
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1237-
case _ =>
1238-
vparamTypes.map(_ => EmptyFlags)
1239-
1240-
val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1241-
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
1242-
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1243-
}.toList
1244-
1245-
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1246-
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
1247-
)).withSpan(tree.span)
1301+
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = (tree: @unchecked) match
1302+
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) =>
1303+
val paramFlags = fun match
1304+
case fun: FunctionWithMods =>
1305+
// TODO: make use of this in the desugaring when pureFuns is enabled.
1306+
// val isImpure = funFlags.is(Impure)
1307+
1308+
// Function flags to be propagated to each parameter in the desugared method type.
1309+
val givenFlag = fun.mods.flags.toTermFlags & Given
1310+
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1311+
case _ =>
1312+
vparamTypes.map(_ => EmptyFlags)
1313+
1314+
val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1315+
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
1316+
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1317+
}.toList
1318+
1319+
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1320+
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
1321+
.withFlags(Synthetic)
1322+
.withAttachment(PolyFunctionApply, ())
1323+
)).withSpan(tree.span)
12481324
end makePolyFunctionType
12491325

12501326
/** Invent a name for an anonympus given of type or template `impl`. */

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -3460,7 +3460,7 @@ object Parsers {
34603460
*
34613461
* TypTypeParamClause::= ‘[’ TypTypeParam {‘,’ TypTypeParam} ‘]’
34623462
* TypTypeParam ::= {Annotation}
3463-
* (id | ‘_’) [HkTypeParamClause] TypeBounds
3463+
* (id | ‘_’) [HkTypeParamClause] TypeAndCtxBounds
34643464
*
34653465
* HkTypeParamClause ::= ‘[’ HkTypeParam {‘,’ HkTypeParam} ‘]’
34663466
* HkTypeParam ::= {Annotation} [‘+’ | ‘-’]
@@ -3491,7 +3491,9 @@ object Parsers {
34913491
else ident().toTypeName
34923492
val hkparams = typeParamClauseOpt(ParamOwner.Hk)
34933493
val bounds =
3494-
if paramOwner.acceptsCtxBounds then typeAndCtxBounds(name) else typeBounds()
3494+
if paramOwner.acceptsCtxBounds then typeAndCtxBounds(name)
3495+
else if in.featureEnabled(Feature.modularity) && paramOwner == ParamOwner.Type then typeAndCtxBounds(name)
3496+
else typeBounds()
34953497
TypeDef(name, lambdaAbstract(hkparams, bounds)).withMods(mods)
34963498
}
34973499
}

compiler/src/dotty/tools/dotc/typer/Typer.scala

+2-3
Original file line numberDiff line numberDiff line change
@@ -1917,7 +1917,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19171917
def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
19181918
val tree1 = desugar.normalizePolyFunction(tree)
19191919
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
1920-
else typedPolyFunctionValue(tree1, pt)
1920+
else typedPolyFunctionValue(desugar.elimContextBounds(tree1).asInstanceOf[untpd.PolyFunction], pt)
19211921

19221922
def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
19231923
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
@@ -2471,7 +2471,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
24712471
val TypeDef(_, impl: Template) = typed(refineClsDef): @unchecked
24722472
val refinements1 = impl.body
24732473
val seen = mutable.Set[Symbol]()
2474-
for (refinement <- refinements1) { // TODO: get clarity whether we want to enforce these conditions
2474+
for refinement <- refinements1 do // TODO: get clarity whether we want to enforce these conditions
24752475
typr.println(s"adding refinement $refinement")
24762476
checkRefinementNonCyclic(refinement, refineCls, seen)
24772477
val rsym = refinement.symbol
@@ -2485,7 +2485,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
24852485
val member = refineCls.info.member(rsym.name)
24862486
if (member.isOverloaded)
24872487
report.error(OverloadInRefinement(rsym), refinement.srcPos)
2488-
}
24892488
assignType(cpy.RefinedTypeTree(tree)(tpt1, refinements1), tpt1, refinements1, refineCls)
24902489
}
24912490

0 commit comments

Comments
 (0)