Skip to content

Commit 24e3fa0

Browse files
committed
More cleanup of poly context bound desugaring
1 parent 7755e3b commit 24e3fa0

File tree

3 files changed

+83
-59
lines changed

3 files changed

+83
-59
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+77-56
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ object desugar {
247247
* def f$default$2[T](x: Int) = x + "m"
248248
*/
249249
private def defDef(meth: DefDef, isPrimaryConstructor: Boolean = false)(using Context): Tree =
250-
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor))
250+
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor).asInstanceOf[DefDef])
251251

252252
/** Drop context bounds in given TypeDef, replacing them with evidence ValDefs that
253253
* get added to a buffer.
@@ -309,10 +309,8 @@ object desugar {
309309
tdef1
310310
end desugarContextBounds
311311

312-
private def elimContextBounds(meth: DefDef, isPrimaryConstructor: Boolean)(using Context): DefDef =
313-
val DefDef(_, paramss, tpt, rhs) = meth
312+
def elimContextBounds(meth: Tree, isPrimaryConstructor: Boolean = false)(using Context): Tree =
314313
val evidenceParamBuf = mutable.ListBuffer[ValDef]()
315-
316314
var seenContextBounds: Int = 0
317315
def freshName(unused: Tree) =
318316
seenContextBounds += 1 // Start at 1 like FreshNameCreator.
@@ -322,7 +320,7 @@ object desugar {
322320
// parameters of the method since shadowing does not affect
323321
// implicit resolution in Scala 3.
324322

325-
val paramssNoContextBounds =
323+
def paramssNoContextBounds(paramss: List[ParamClause]): List[ParamClause] =
326324
val iflag = paramss.lastOption.flatMap(_.headOption) match
327325
case Some(param) if param.mods.isOneOf(GivenOrImplicit) =>
328326
param.mods.flags & GivenOrImplicit
@@ -334,16 +332,29 @@ object desugar {
334332
tparam => desugarContextBounds(tparam, evidenceParamBuf, flags, freshName, paramss)
335333
}(identity)
336334

337-
rhs match
338-
case MacroTree(call) =>
339-
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
340-
case _ =>
341-
addEvidenceParams(
342-
cpy.DefDef(meth)(
343-
name = normalizeName(meth, tpt).asTermName,
344-
paramss = paramssNoContextBounds),
345-
evidenceParamBuf.toList
346-
)
335+
meth match
336+
case meth @ DefDef(_, paramss, tpt, rhs) =>
337+
val newParamss = paramssNoContextBounds(paramss)
338+
rhs match
339+
case MacroTree(call) =>
340+
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
341+
case _ =>
342+
addEvidenceParams(
343+
cpy.DefDef(meth)(
344+
name = normalizeName(meth, tpt).asTermName,
345+
paramss = newParamss
346+
),
347+
evidenceParamBuf.toList
348+
)
349+
case meth @ PolyFunction(tparams, fun) =>
350+
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = meth: @unchecked
351+
val Function(vparams: List[untpd.ValDef] @unchecked, rhs) = fun: @unchecked
352+
val newParamss = paramssNoContextBounds(tparams :: vparams :: Nil)
353+
val params = evidenceParamBuf.toList
354+
val boundNames = getBoundNames(params, newParamss)
355+
val recur = fitEvidenceParams(params, nme.apply, boundNames)
356+
val (paramsFst, paramsSnd) = recur(newParamss)
357+
functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs)
347358
end elimContextBounds
348359

349360
def addDefaultGetters(meth: DefDef)(using Context): Tree =
@@ -471,6 +482,55 @@ object desugar {
471482
case _ =>
472483
(Nil, tree)
473484

485+
private def referencesName(vdef: ValDef, names: Set[TermName])(using Context): Boolean =
486+
vdef.tpt.existsSubTree:
487+
case Ident(name: TermName) => names.contains(name)
488+
case _ => false
489+
490+
/** Fit evidence `params` into the `mparamss` parameter lists */
491+
private def fitEvidenceParams(params: List[ValDef], methName: Name, boundNames: Set[TermName])(mparamss: List[ParamClause])(using Context): (List[ParamClause], List[ParamClause]) = mparamss match
492+
case ValDefs(mparams) :: _ if mparams.exists(referencesName(_, boundNames)) =>
493+
(params :: Nil) -> mparamss
494+
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
495+
val normParams =
496+
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
497+
params.map: param =>
498+
val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit))
499+
param.withMods(param.mods.withFlags(normFlags))
500+
.showing(i"adapted param $result ${result.mods.flags} for ${methName}", Printers.desugar)
501+
else params
502+
((normParams ++ mparams) :: Nil) -> Nil
503+
case mparams :: mparamss1 =>
504+
val (fst, snd) = fitEvidenceParams(params, methName, boundNames)(mparamss1)
505+
(mparams :: fst) -> snd
506+
case Nil =>
507+
Nil -> (params :: Nil)
508+
509+
/** Create a chain of possibly contextual functions from the parameter lists */
510+
private def functionsOf(paramss: List[ParamClause], rhs: Tree)(using Context): Tree = paramss match
511+
case Nil => rhs
512+
case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) =>
513+
val paramTpts = head.map(_.tpt)
514+
val paramNames = head.map(_.name)
515+
val paramsErased = head.map(_.mods.flags.is(Erased))
516+
makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span)
517+
case ValDefs(head) :: rest =>
518+
Function(head, functionsOf(rest, rhs))
519+
case TypeDefs(head) :: rest =>
520+
PolyFunction(head, functionsOf(rest, rhs))
521+
case _ =>
522+
assert(false, i"unexpected paramss $paramss")
523+
EmptyTree
524+
525+
private def getBoundNames(params: List[ValDef], paramss: List[ParamClause])(using Context): Set[TermName] =
526+
var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
527+
for mparams <- paramss; mparam <- mparams do
528+
mparam match
529+
case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) =>
530+
boundNames += tparam.name.toTermName
531+
case _ =>
532+
boundNames
533+
474534
/** Add all evidence parameters in `params` as implicit parameters to `meth`.
475535
* The position of the added parameters is determined as follows:
476536
*
@@ -485,48 +545,9 @@ object desugar {
485545
private def addEvidenceParams(meth: DefDef, params: List[ValDef])(using Context): DefDef =
486546
if params.isEmpty then return meth
487547

488-
var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
489-
for mparams <- meth.paramss; mparam <- mparams do
490-
mparam match
491-
case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) =>
492-
boundNames += tparam.name.toTermName
493-
case _ =>
548+
val boundNames = getBoundNames(params, meth.paramss)
494549

495-
def referencesBoundName(vdef: ValDef): Boolean =
496-
vdef.tpt.existsSubTree:
497-
case Ident(name: TermName) => boundNames.contains(name)
498-
case _ => false
499-
500-
def recur(mparamss: List[ParamClause]): (List[ParamClause], List[ParamClause]) = mparamss match
501-
case ValDefs(mparams) :: _ if mparams.exists(referencesBoundName) =>
502-
(params :: Nil) -> mparamss
503-
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
504-
val normParams =
505-
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
506-
params.map: param =>
507-
val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit))
508-
param.withMods(param.mods.withFlags(normFlags))
509-
.showing(i"adapted param $result ${result.mods.flags} for ${meth.name}", Printers.desugar)
510-
else params
511-
((normParams ++ mparams) :: Nil) -> Nil
512-
case mparams :: mparamss1 =>
513-
val (fst, snd) = recur(mparamss1)
514-
(mparams :: fst) -> snd
515-
case Nil =>
516-
Nil -> (params :: Nil)
517-
518-
def functionsOf(paramss: List[ParamClause], rhs: Tree): Tree = paramss match
519-
case Nil => rhs
520-
case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) =>
521-
val paramTpts = params.map(_.tpt)
522-
val paramNames = params.map(_.name)
523-
val paramsErased = params.map(_.mods.flags.is(Erased))
524-
makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span)
525-
case ValDefs(head) :: rest =>
526-
Function(head, functionsOf(rest, rhs))
527-
case head :: _ =>
528-
assert(false, i"unexpected type parameters when adding evidence parameters to $meth")
529-
EmptyTree
550+
val recur = fitEvidenceParams(params, meth.name, boundNames)
530551

531552
if meth.hasAttachment(PolyFunctionApply) then
532553
meth.removeAttachment(PolyFunctionApply)

compiler/src/dotty/tools/dotc/typer/Typer.scala

+1-3
Original file line numberDiff line numberDiff line change
@@ -1920,7 +1920,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19201920
def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
19211921
val tree1 = desugar.normalizePolyFunction(tree)
19221922
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
1923-
else typedPolyFunctionValue(tree1, pt)
1923+
else typedPolyFunctionValue(desugar.elimContextBounds(tree1).asInstanceOf[untpd.PolyFunction], pt)
19241924

19251925
def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
19261926
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
@@ -1946,7 +1946,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19461946
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
19471947
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
19481948
val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span)
1949-
defdef.putAttachment(desugar.PolyFunctionApply, List.empty)
19501949
typed(desugared, pt)
19511950
else
19521951
val msg =
@@ -1955,7 +1954,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19551954
errorTree(EmptyTree, msg, tree.srcPos)
19561955
case _ =>
19571956
val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span)
1958-
defdef.putAttachment(desugar.PolyFunctionApply, List.empty)
19591957
typed(desugared, pt)
19601958
end typedPolyFunctionValue
19611959

tests/pos/contextbounds-for-poly-functions.scala

+5
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,8 @@ val namedConstraintRef = [X: {Ord as ord}] => (x: ord.T) => x
8686
type DependentCmp = [X: {Ord as ord}] => ord.T => Boolean
8787
type DependentCmp1 = [X: {Ord as ord}] => (ord.T, Int) => ord.T => Boolean
8888
val dependentCmp: DependentCmp = [X: {Ord as ord}] => (x: ord.T) => true
89+
val dependentCmp_1: [X: {Ord as ord}] => ord.T => Boolean = [X: {Ord as ord}] => (x: ord.T) => true
90+
91+
val dependentCmp1: DependentCmp1 = [X: {Ord as ord}] => (x: ord.T, y: Int) => (z: ord.T) => true
92+
val dependentCmp1_1: [X: {Ord as ord}] => (ord.T, Int) => ord.T => Boolean =
93+
[X: {Ord as ord}] => (x: ord.T, y: Int) => (z: ord.T) => true

0 commit comments

Comments
 (0)