@@ -52,6 +52,10 @@ object desugar {
52
52
*/
53
53
val ContextBoundParam : Property .Key [Unit ] = Property .StickyKey ()
54
54
55
+ /** Marks a poly fcuntion apply method, so that we can handle adding evidence parameters to them in a special way
56
+ */
57
+ val PolyFunctionApply : Property .Key [Unit ] = Property .StickyKey ()
58
+
55
59
/** What static check should be applied to a Match? */
56
60
enum MatchCheck {
57
61
case None , Exhaustive , IrrefutablePatDef , IrrefutableGenFrom
@@ -242,7 +246,7 @@ object desugar {
242
246
* def f$default$2[T](x: Int) = x + "m"
243
247
*/
244
248
private def defDef (meth : DefDef , isPrimaryConstructor : Boolean = false )(using Context ): Tree =
245
- addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor))
249
+ addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor). asInstanceOf [ DefDef ] )
246
250
247
251
/** Drop context bounds in given TypeDef, replacing them with evidence ValDefs that
248
252
* get added to a buffer.
@@ -304,10 +308,8 @@ object desugar {
304
308
tdef1
305
309
end desugarContextBounds
306
310
307
- private def elimContextBounds (meth : DefDef , isPrimaryConstructor : Boolean )(using Context ): DefDef =
308
- val DefDef (_, paramss, tpt, rhs) = meth
311
+ def elimContextBounds (meth : Tree , isPrimaryConstructor : Boolean = false )(using Context ): Tree =
309
312
val evidenceParamBuf = mutable.ListBuffer [ValDef ]()
310
-
311
313
var seenContextBounds : Int = 0
312
314
def freshName (unused : Tree ) =
313
315
seenContextBounds += 1 // Start at 1 like FreshNameCreator.
@@ -317,7 +319,7 @@ object desugar {
317
319
// parameters of the method since shadowing does not affect
318
320
// implicit resolution in Scala 3.
319
321
320
- val paramssNoContextBounds =
322
+ def paramssNoContextBounds ( paramss : List [ ParamClause ]) : List [ ParamClause ] =
321
323
val iflag = paramss.lastOption.flatMap(_.headOption) match
322
324
case Some (param) if param.mods.isOneOf(GivenOrImplicit ) =>
323
325
param.mods.flags & GivenOrImplicit
@@ -329,15 +331,32 @@ object desugar {
329
331
tparam => desugarContextBounds(tparam, evidenceParamBuf, flags, freshName, paramss)
330
332
}(identity)
331
333
332
- rhs match
333
- case MacroTree (call) =>
334
- cpy.DefDef (meth)(rhs = call).withMods(meth.mods | Macro | Erased )
335
- case _ =>
336
- addEvidenceParams(
337
- cpy.DefDef (meth)(
338
- name = normalizeName(meth, tpt).asTermName,
339
- paramss = paramssNoContextBounds),
340
- evidenceParamBuf.toList)
334
+ meth match
335
+ case meth @ DefDef (_, paramss, tpt, rhs) =>
336
+ val newParamss = paramssNoContextBounds(paramss)
337
+ rhs match
338
+ case MacroTree (call) =>
339
+ cpy.DefDef (meth)(rhs = call).withMods(meth.mods | Macro | Erased )
340
+ case _ =>
341
+ addEvidenceParams(
342
+ cpy.DefDef (meth)(
343
+ name = normalizeName(meth, tpt).asTermName,
344
+ paramss = newParamss
345
+ ),
346
+ evidenceParamBuf.toList
347
+ )
348
+ case meth @ PolyFunction (tparams, fun) =>
349
+ val PolyFunction (tparams : List [untpd.TypeDef ] @ unchecked, fun) = meth : @ unchecked
350
+ val Function (vparams : List [untpd.ValDef ] @ unchecked, rhs) = fun : @ unchecked
351
+ val newParamss = paramssNoContextBounds(tparams :: vparams :: Nil )
352
+ val params = evidenceParamBuf.toList
353
+ if params.isEmpty then
354
+ meth
355
+ else
356
+ val boundNames = getBoundNames(params, newParamss)
357
+ val recur = fitEvidenceParams(params, nme.apply, boundNames)
358
+ val (paramsFst, paramsSnd) = recur(newParamss)
359
+ functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs)
341
360
end elimContextBounds
342
361
343
362
def addDefaultGetters (meth : DefDef )(using Context ): Tree =
@@ -465,6 +484,74 @@ object desugar {
465
484
case _ =>
466
485
(Nil , tree)
467
486
487
+ private def referencesName (vdef : ValDef , names : Set [TermName ])(using Context ): Boolean =
488
+ vdef.tpt.existsSubTree:
489
+ case Ident (name : TermName ) => names.contains(name)
490
+ case _ => false
491
+
492
+ /** Fit evidence `params` into the `mparamss` parameter lists, making sure
493
+ * that all parameters referencing `params` are after them.
494
+ * - for methods the final parameter lists are := result._1 ++ result._2
495
+ * - for poly functions, each element of the pair contains at most one term
496
+ * parameter list
497
+ *
498
+ * @param params the evidence parameters list that should fit into `mparamss`
499
+ * @param methName the name of the method that `mparamss` belongs to
500
+ * @param boundNames the names of the evidence parameters
501
+ * @param mparamss the original parameter lists of the method
502
+ * @return a pair of parameter lists containing all parameter lists in a
503
+ * reference-correct order; make sure that `params` is always at the
504
+ * intersection of the pair elements; this is relevant, for poly functions
505
+ * where `mparamss` is guaranteed to have exectly one term parameter list,
506
+ * then each pair element will have at most one term parameter list
507
+ */
508
+ private def fitEvidenceParams (
509
+ params : List [ValDef ],
510
+ methName : Name ,
511
+ boundNames : Set [TermName ]
512
+ )(mparamss : List [ParamClause ])(using Context ): (List [ParamClause ], List [ParamClause ]) = mparamss match
513
+ case ValDefs (mparams) :: _ if mparams.exists(referencesName(_, boundNames)) =>
514
+ (params :: Nil ) -> mparamss
515
+ case ValDefs (mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit ) =>
516
+ val normParams =
517
+ if params.head.mods.flags.is(Given ) != mparam.mods.flags.is(Given ) then
518
+ params.map: param =>
519
+ val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit ))
520
+ param.withMods(param.mods.withFlags(normFlags))
521
+ .showing(i " adapted param $result ${result.mods.flags} for ${methName}" , Printers .desugar)
522
+ else params
523
+ ((normParams ++ mparams) :: Nil ) -> Nil
524
+ case mparams :: mparamss1 =>
525
+ val (fst, snd) = fitEvidenceParams(params, methName, boundNames)(mparamss1)
526
+ (mparams :: fst) -> snd
527
+ case Nil =>
528
+ Nil -> (params :: Nil )
529
+
530
+ /** Create a chain of possibly contextual functions from the parameter lists */
531
+ private def functionsOf (paramss : List [ParamClause ], rhs : Tree )(using Context ): Tree = paramss match
532
+ case Nil => rhs
533
+ case ValDefs (head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit ) =>
534
+ val paramTpts = head.map(_.tpt)
535
+ val paramNames = head.map(_.name)
536
+ val paramsErased = head.map(_.mods.flags.is(Erased ))
537
+ makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span)
538
+ case ValDefs (head) :: rest =>
539
+ Function (head, functionsOf(rest, rhs))
540
+ case TypeDefs (head) :: rest =>
541
+ PolyFunction (head, functionsOf(rest, rhs))
542
+ case _ =>
543
+ assert(false , i " unexpected paramss $paramss" )
544
+ EmptyTree
545
+
546
+ private def getBoundNames (params : List [ValDef ], paramss : List [ParamClause ])(using Context ): Set [TermName ] =
547
+ var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
548
+ for mparams <- paramss; mparam <- mparams do
549
+ mparam match
550
+ case tparam : TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot .unapply(_).isDefined) =>
551
+ boundNames += tparam.name.toTermName
552
+ case _ =>
553
+ boundNames
554
+
468
555
/** Add all evidence parameters in `params` as implicit parameters to `meth`.
469
556
* The position of the added parameters is determined as follows:
470
557
*
@@ -479,36 +566,23 @@ object desugar {
479
566
private def addEvidenceParams (meth : DefDef , params : List [ValDef ])(using Context ): DefDef =
480
567
if params.isEmpty then return meth
481
568
482
- var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
483
- for mparams <- meth.paramss; mparam <- mparams do
484
- mparam match
485
- case tparam : TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot .unapply(_).isDefined) =>
486
- boundNames += tparam.name.toTermName
487
- case _ =>
569
+ val boundNames = getBoundNames(params, meth.paramss)
488
570
489
- def referencesBoundName (vdef : ValDef ): Boolean =
490
- vdef.tpt.existsSubTree:
491
- case Ident (name : TermName ) => boundNames.contains(name)
492
- case _ => false
571
+ val fitParams = fitEvidenceParams(params, meth.name, boundNames)
493
572
494
- def recur (mparamss : List [ParamClause ]): List [ParamClause ] = mparamss match
495
- case ValDefs (mparams) :: _ if mparams.exists(referencesBoundName) =>
496
- params :: mparamss
497
- case ValDefs (mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit ) =>
498
- val normParams =
499
- if params.head.mods.flags.is(Given ) != mparam.mods.flags.is(Given ) then
500
- params.map: param =>
501
- val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit ))
502
- param.withMods(param.mods.withFlags(normFlags))
503
- .showing(i " adapted param $result ${result.mods.flags} for ${meth.name}" , Printers .desugar)
504
- else params
505
- (normParams ++ mparams) :: Nil
506
- case mparams :: mparamss1 =>
507
- mparams :: recur(mparamss1)
508
- case Nil =>
509
- params :: Nil
510
-
511
- cpy.DefDef (meth)(paramss = recur(meth.paramss))
573
+ if meth.removeAttachment(PolyFunctionApply ).isDefined then
574
+ // for PolyFunctions we are limited to a single term param list, so we
575
+ // reuse the fitEvidenceParams logic to compute the new parameter lists
576
+ // and then we add the other parameter lists as function types to the
577
+ // return type
578
+ val (paramsFst, paramsSnd) = fitParams(meth.paramss)
579
+ if ctx.mode.is(Mode .Type ) then
580
+ cpy.DefDef (meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt))
581
+ else
582
+ cpy.DefDef (meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs))
583
+ else
584
+ val (paramsFst, paramsSnd) = fitParams(meth.paramss)
585
+ cpy.DefDef (meth)(paramss = paramsFst ++ paramsSnd)
512
586
end addEvidenceParams
513
587
514
588
/** The parameters generated from the contextual bounds of `meth`, as generated by `desugar.defDef` */
@@ -1224,27 +1298,29 @@ object desugar {
1224
1298
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
1225
1299
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1226
1300
*/
1227
- def makePolyFunctionType (tree : PolyFunction )(using Context ): RefinedTypeTree =
1228
- val PolyFunction (tparams : List [untpd.TypeDef ] @ unchecked, fun @ untpd.Function (vparamTypes, res)) = tree : @ unchecked
1229
- val paramFlags = fun match
1230
- case fun : FunctionWithMods =>
1231
- // TODO: make use of this in the desugaring when pureFuns is enabled.
1232
- // val isImpure = funFlags.is(Impure)
1233
-
1234
- // Function flags to be propagated to each parameter in the desugared method type.
1235
- val givenFlag = fun.mods.flags.toTermFlags & Given
1236
- fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1237
- case _ =>
1238
- vparamTypes.map(_ => EmptyFlags )
1239
-
1240
- val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1241
- case ((p : ValDef , paramFlags), n) => p.withAddedFlags(paramFlags)
1242
- case ((p, paramFlags), n) => makeSyntheticParameter(n + 1 , p).withAddedFlags(paramFlags)
1243
- }.toList
1244
-
1245
- RefinedTypeTree (ref(defn.PolyFunctionType ), List (
1246
- DefDef (nme.apply, tparams :: vparams :: Nil , res, EmptyTree ).withFlags(Synthetic )
1247
- )).withSpan(tree.span)
1301
+ def makePolyFunctionType (tree : PolyFunction )(using Context ): RefinedTypeTree = (tree : @ unchecked) match
1302
+ case PolyFunction (tparams : List [untpd.TypeDef ] @ unchecked, fun @ untpd.Function (vparamTypes, res)) =>
1303
+ val paramFlags = fun match
1304
+ case fun : FunctionWithMods =>
1305
+ // TODO: make use of this in the desugaring when pureFuns is enabled.
1306
+ // val isImpure = funFlags.is(Impure)
1307
+
1308
+ // Function flags to be propagated to each parameter in the desugared method type.
1309
+ val givenFlag = fun.mods.flags.toTermFlags & Given
1310
+ fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1311
+ case _ =>
1312
+ vparamTypes.map(_ => EmptyFlags )
1313
+
1314
+ val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1315
+ case ((p : ValDef , paramFlags), n) => p.withAddedFlags(paramFlags)
1316
+ case ((p, paramFlags), n) => makeSyntheticParameter(n + 1 , p).withAddedFlags(paramFlags)
1317
+ }.toList
1318
+
1319
+ RefinedTypeTree (ref(defn.PolyFunctionType ), List (
1320
+ DefDef (nme.apply, tparams :: vparams :: Nil , res, EmptyTree )
1321
+ .withFlags(Synthetic )
1322
+ .withAttachment(PolyFunctionApply , ())
1323
+ )).withSpan(tree.span)
1248
1324
end makePolyFunctionType
1249
1325
1250
1326
/** Invent a name for an anonympus given of type or template `impl`. */
0 commit comments