Skip to content

Commit 5611522

Browse files
Context Bounds for Polymorphic Functions (#21643)
Implement the `#6` point form SIP-64 i.e. --- ### 6. Context Bounds for Polymorphic Functions Currently, context bounds can be used in methods, but not in function types or function literals. It would be nice propose to drop this irregularity and allow context bounds also in these places. Example: ```scala type Comparer = [X: Ord] => (x: X, y: X) => Boolean val less: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 ``` The expansion of such context bounds is analogous to the expansion in method types, except that instead of adding a using clause in a method, we insert a context function type. For instance, the `type` and `val` definitions above would expand to ```scala type Comparer = [X] => (x: X, y: X) => Ord[X] ?=> Boolean val less: Comparer = [X] => (x: X, y: X) => (ord: Ord[X]) ?=> ord.compare(x, y) < 0 ``` The expansion of using clauses does look inside alias types. For instance, here is a variation of the previous example that uses a parameterized type alias: ```scala type Cmp[X] = (x: X, y: X) => Boolean type Comparer2 = [X: Ord] => Cmp[X] ``` The expansion of the right hand side of `Comparer2` expands the `Cmp[X]` alias and then inserts the context function at the same place as what's done for `Comparer`.
2 parents cc4a324 + 952eff7 commit 5611522

File tree

6 files changed

+274
-68
lines changed

6 files changed

+274
-68
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+139-63
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ object desugar {
5252
*/
5353
val ContextBoundParam: Property.Key[Unit] = Property.StickyKey()
5454

55+
/** Marks a poly fcuntion apply method, so that we can handle adding evidence parameters to them in a special way
56+
*/
57+
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()
58+
5559
/** What static check should be applied to a Match? */
5660
enum MatchCheck {
5761
case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom
@@ -242,7 +246,7 @@ object desugar {
242246
* def f$default$2[T](x: Int) = x + "m"
243247
*/
244248
private def defDef(meth: DefDef, isPrimaryConstructor: Boolean = false)(using Context): Tree =
245-
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor))
249+
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor).asInstanceOf[DefDef])
246250

247251
/** Drop context bounds in given TypeDef, replacing them with evidence ValDefs that
248252
* get added to a buffer.
@@ -304,10 +308,8 @@ object desugar {
304308
tdef1
305309
end desugarContextBounds
306310

307-
private def elimContextBounds(meth: DefDef, isPrimaryConstructor: Boolean)(using Context): DefDef =
308-
val DefDef(_, paramss, tpt, rhs) = meth
311+
def elimContextBounds(meth: Tree, isPrimaryConstructor: Boolean = false)(using Context): Tree =
309312
val evidenceParamBuf = mutable.ListBuffer[ValDef]()
310-
311313
var seenContextBounds: Int = 0
312314
def freshName(unused: Tree) =
313315
seenContextBounds += 1 // Start at 1 like FreshNameCreator.
@@ -317,7 +319,7 @@ object desugar {
317319
// parameters of the method since shadowing does not affect
318320
// implicit resolution in Scala 3.
319321

320-
val paramssNoContextBounds =
322+
def paramssNoContextBounds(paramss: List[ParamClause]): List[ParamClause] =
321323
val iflag = paramss.lastOption.flatMap(_.headOption) match
322324
case Some(param) if param.mods.isOneOf(GivenOrImplicit) =>
323325
param.mods.flags & GivenOrImplicit
@@ -329,15 +331,32 @@ object desugar {
329331
tparam => desugarContextBounds(tparam, evidenceParamBuf, flags, freshName, paramss)
330332
}(identity)
331333

332-
rhs match
333-
case MacroTree(call) =>
334-
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
335-
case _ =>
336-
addEvidenceParams(
337-
cpy.DefDef(meth)(
338-
name = normalizeName(meth, tpt).asTermName,
339-
paramss = paramssNoContextBounds),
340-
evidenceParamBuf.toList)
334+
meth match
335+
case meth @ DefDef(_, paramss, tpt, rhs) =>
336+
val newParamss = paramssNoContextBounds(paramss)
337+
rhs match
338+
case MacroTree(call) =>
339+
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
340+
case _ =>
341+
addEvidenceParams(
342+
cpy.DefDef(meth)(
343+
name = normalizeName(meth, tpt).asTermName,
344+
paramss = newParamss
345+
),
346+
evidenceParamBuf.toList
347+
)
348+
case meth @ PolyFunction(tparams, fun) =>
349+
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = meth: @unchecked
350+
val Function(vparams: List[untpd.ValDef] @unchecked, rhs) = fun: @unchecked
351+
val newParamss = paramssNoContextBounds(tparams :: vparams :: Nil)
352+
val params = evidenceParamBuf.toList
353+
if params.isEmpty then
354+
meth
355+
else
356+
val boundNames = getBoundNames(params, newParamss)
357+
val recur = fitEvidenceParams(params, nme.apply, boundNames)
358+
val (paramsFst, paramsSnd) = recur(newParamss)
359+
functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs)
341360
end elimContextBounds
342361

343362
def addDefaultGetters(meth: DefDef)(using Context): Tree =
@@ -465,6 +484,74 @@ object desugar {
465484
case _ =>
466485
(Nil, tree)
467486

487+
private def referencesName(vdef: ValDef, names: Set[TermName])(using Context): Boolean =
488+
vdef.tpt.existsSubTree:
489+
case Ident(name: TermName) => names.contains(name)
490+
case _ => false
491+
492+
/** Fit evidence `params` into the `mparamss` parameter lists, making sure
493+
* that all parameters referencing `params` are after them.
494+
* - for methods the final parameter lists are := result._1 ++ result._2
495+
* - for poly functions, each element of the pair contains at most one term
496+
* parameter list
497+
*
498+
* @param params the evidence parameters list that should fit into `mparamss`
499+
* @param methName the name of the method that `mparamss` belongs to
500+
* @param boundNames the names of the evidence parameters
501+
* @param mparamss the original parameter lists of the method
502+
* @return a pair of parameter lists containing all parameter lists in a
503+
* reference-correct order; make sure that `params` is always at the
504+
* intersection of the pair elements; this is relevant, for poly functions
505+
* where `mparamss` is guaranteed to have exectly one term parameter list,
506+
* then each pair element will have at most one term parameter list
507+
*/
508+
private def fitEvidenceParams(
509+
params: List[ValDef],
510+
methName: Name,
511+
boundNames: Set[TermName]
512+
)(mparamss: List[ParamClause])(using Context): (List[ParamClause], List[ParamClause]) = mparamss match
513+
case ValDefs(mparams) :: _ if mparams.exists(referencesName(_, boundNames)) =>
514+
(params :: Nil) -> mparamss
515+
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
516+
val normParams =
517+
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
518+
params.map: param =>
519+
val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit))
520+
param.withMods(param.mods.withFlags(normFlags))
521+
.showing(i"adapted param $result ${result.mods.flags} for ${methName}", Printers.desugar)
522+
else params
523+
((normParams ++ mparams) :: Nil) -> Nil
524+
case mparams :: mparamss1 =>
525+
val (fst, snd) = fitEvidenceParams(params, methName, boundNames)(mparamss1)
526+
(mparams :: fst) -> snd
527+
case Nil =>
528+
Nil -> (params :: Nil)
529+
530+
/** Create a chain of possibly contextual functions from the parameter lists */
531+
private def functionsOf(paramss: List[ParamClause], rhs: Tree)(using Context): Tree = paramss match
532+
case Nil => rhs
533+
case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) =>
534+
val paramTpts = head.map(_.tpt)
535+
val paramNames = head.map(_.name)
536+
val paramsErased = head.map(_.mods.flags.is(Erased))
537+
makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span)
538+
case ValDefs(head) :: rest =>
539+
Function(head, functionsOf(rest, rhs))
540+
case TypeDefs(head) :: rest =>
541+
PolyFunction(head, functionsOf(rest, rhs))
542+
case _ =>
543+
assert(false, i"unexpected paramss $paramss")
544+
EmptyTree
545+
546+
private def getBoundNames(params: List[ValDef], paramss: List[ParamClause])(using Context): Set[TermName] =
547+
var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
548+
for mparams <- paramss; mparam <- mparams do
549+
mparam match
550+
case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) =>
551+
boundNames += tparam.name.toTermName
552+
case _ =>
553+
boundNames
554+
468555
/** Add all evidence parameters in `params` as implicit parameters to `meth`.
469556
* The position of the added parameters is determined as follows:
470557
*
@@ -479,36 +566,23 @@ object desugar {
479566
private def addEvidenceParams(meth: DefDef, params: List[ValDef])(using Context): DefDef =
480567
if params.isEmpty then return meth
481568

482-
var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
483-
for mparams <- meth.paramss; mparam <- mparams do
484-
mparam match
485-
case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) =>
486-
boundNames += tparam.name.toTermName
487-
case _ =>
569+
val boundNames = getBoundNames(params, meth.paramss)
488570

489-
def referencesBoundName(vdef: ValDef): Boolean =
490-
vdef.tpt.existsSubTree:
491-
case Ident(name: TermName) => boundNames.contains(name)
492-
case _ => false
571+
val fitParams = fitEvidenceParams(params, meth.name, boundNames)
493572

494-
def recur(mparamss: List[ParamClause]): List[ParamClause] = mparamss match
495-
case ValDefs(mparams) :: _ if mparams.exists(referencesBoundName) =>
496-
params :: mparamss
497-
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
498-
val normParams =
499-
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
500-
params.map: param =>
501-
val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit))
502-
param.withMods(param.mods.withFlags(normFlags))
503-
.showing(i"adapted param $result ${result.mods.flags} for ${meth.name}", Printers.desugar)
504-
else params
505-
(normParams ++ mparams) :: Nil
506-
case mparams :: mparamss1 =>
507-
mparams :: recur(mparamss1)
508-
case Nil =>
509-
params :: Nil
510-
511-
cpy.DefDef(meth)(paramss = recur(meth.paramss))
573+
if meth.removeAttachment(PolyFunctionApply).isDefined then
574+
// for PolyFunctions we are limited to a single term param list, so we
575+
// reuse the fitEvidenceParams logic to compute the new parameter lists
576+
// and then we add the other parameter lists as function types to the
577+
// return type
578+
val (paramsFst, paramsSnd) = fitParams(meth.paramss)
579+
if ctx.mode.is(Mode.Type) then
580+
cpy.DefDef(meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt))
581+
else
582+
cpy.DefDef(meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs))
583+
else
584+
val (paramsFst, paramsSnd) = fitParams(meth.paramss)
585+
cpy.DefDef(meth)(paramss = paramsFst ++ paramsSnd)
512586
end addEvidenceParams
513587

514588
/** The parameters generated from the contextual bounds of `meth`, as generated by `desugar.defDef` */
@@ -1224,27 +1298,29 @@ object desugar {
12241298
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
12251299
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
12261300
*/
1227-
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
1228-
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
1229-
val paramFlags = fun match
1230-
case fun: FunctionWithMods =>
1231-
// TODO: make use of this in the desugaring when pureFuns is enabled.
1232-
// val isImpure = funFlags.is(Impure)
1233-
1234-
// Function flags to be propagated to each parameter in the desugared method type.
1235-
val givenFlag = fun.mods.flags.toTermFlags & Given
1236-
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1237-
case _ =>
1238-
vparamTypes.map(_ => EmptyFlags)
1239-
1240-
val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1241-
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
1242-
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1243-
}.toList
1244-
1245-
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1246-
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
1247-
)).withSpan(tree.span)
1301+
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = (tree: @unchecked) match
1302+
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) =>
1303+
val paramFlags = fun match
1304+
case fun: FunctionWithMods =>
1305+
// TODO: make use of this in the desugaring when pureFuns is enabled.
1306+
// val isImpure = funFlags.is(Impure)
1307+
1308+
// Function flags to be propagated to each parameter in the desugared method type.
1309+
val givenFlag = fun.mods.flags.toTermFlags & Given
1310+
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1311+
case _ =>
1312+
vparamTypes.map(_ => EmptyFlags)
1313+
1314+
val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1315+
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
1316+
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1317+
}.toList
1318+
1319+
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1320+
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
1321+
.withFlags(Synthetic)
1322+
.withAttachment(PolyFunctionApply, ())
1323+
)).withSpan(tree.span)
12481324
end makePolyFunctionType
12491325

12501326
/** Invent a name for an anonympus given of type or template `impl`. */

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -3459,7 +3459,7 @@ object Parsers {
34593459
*
34603460
* TypTypeParamClause::= ‘[’ TypTypeParam {‘,’ TypTypeParam} ‘]’
34613461
* TypTypeParam ::= {Annotation}
3462-
* (id | ‘_’) [HkTypeParamClause] TypeBounds
3462+
* (id | ‘_’) [HkTypeParamClause] TypeAndCtxBounds
34633463
*
34643464
* HkTypeParamClause ::= ‘[’ HkTypeParam {‘,’ HkTypeParam} ‘]’
34653465
* HkTypeParam ::= {Annotation} [‘+’ | ‘-’]
@@ -3490,7 +3490,9 @@ object Parsers {
34903490
else ident().toTypeName
34913491
val hkparams = typeParamClauseOpt(ParamOwner.Hk)
34923492
val bounds =
3493-
if paramOwner.acceptsCtxBounds then typeAndCtxBounds(name) else typeBounds()
3493+
if paramOwner.acceptsCtxBounds then typeAndCtxBounds(name)
3494+
else if in.featureEnabled(Feature.modularity) && paramOwner == ParamOwner.Type then typeAndCtxBounds(name)
3495+
else typeBounds()
34943496
TypeDef(name, lambdaAbstract(hkparams, bounds)).withMods(mods)
34953497
}
34963498
}

compiler/src/dotty/tools/dotc/typer/Typer.scala

+2-3
Original file line numberDiff line numberDiff line change
@@ -1920,7 +1920,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19201920
def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
19211921
val tree1 = desugar.normalizePolyFunction(tree)
19221922
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
1923-
else typedPolyFunctionValue(tree1, pt)
1923+
else typedPolyFunctionValue(desugar.elimContextBounds(tree1).asInstanceOf[untpd.PolyFunction], pt)
19241924

19251925
def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
19261926
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
@@ -2474,7 +2474,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
24742474
val TypeDef(_, impl: Template) = typed(refineClsDef): @unchecked
24752475
val refinements1 = impl.body
24762476
val seen = mutable.Set[Symbol]()
2477-
for (refinement <- refinements1) { // TODO: get clarity whether we want to enforce these conditions
2477+
for refinement <- refinements1 do // TODO: get clarity whether we want to enforce these conditions
24782478
typr.println(s"adding refinement $refinement")
24792479
checkRefinementNonCyclic(refinement, refineCls, seen)
24802480
val rsym = refinement.symbol
@@ -2488,7 +2488,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
24882488
val member = refineCls.info.member(rsym.name)
24892489
if (member.isOverloaded)
24902490
report.error(OverloadInRefinement(rsym), refinement.srcPos)
2491-
}
24922491
assignType(cpy.RefinedTypeTree(tree)(tpt1, refinements1), tpt1, refinements1, refineCls)
24932492
}
24942493

0 commit comments

Comments
 (0)