Skip to content

Commit 6896ea2

Browse files
committed
Attempt to beta reduce only if parameter and argument lists have same shape
It's possible to define Functions with wrong apply methods by hand which will give an error but pass on a function that does fails beta reduction. Fixes #21952
1 parent 58f88a6 commit 6896ea2

File tree

3 files changed

+43
-23
lines changed

3 files changed

+43
-23
lines changed

compiler/src/dotty/tools/dotc/transform/BetaReduce.scala

+37-20
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ object BetaReduce:
7676
val bindingsBuf = new ListBuffer[DefTree]
7777
def recur(fn: Tree, argss: List[List[Tree]]): Option[Tree] = fn match
7878
case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol =>
79-
Some(reduceApplication(ddef, argss, bindingsBuf))
79+
reduceApplication(ddef, argss, bindingsBuf)
8080
case Block((TypeDef(_, template: Template)) :: Nil, Typed(Apply(Select(New(_), _), _), _)) if template.constr.rhs.isEmpty =>
8181
template.body match
82-
case (ddef: DefDef) :: Nil => Some(reduceApplication(ddef, argss, bindingsBuf))
82+
case (ddef: DefDef) :: Nil => reduceApplication(ddef, argss, bindingsBuf)
8383
case _ => None
8484
case Block(stats, expr) if stats.forall(isPureBinding) =>
8585
recur(expr, argss).map(cpy.Block(fn)(stats, _))
@@ -106,12 +106,22 @@ object BetaReduce:
106106
case _ =>
107107
tree
108108

109-
/** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */
110-
def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree])(using Context): Tree =
109+
/** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings.
110+
* @return optionally, the expanded call, or none if the actual argument
111+
* lists do not match in shape the formal parameters
112+
*/
113+
def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree])
114+
(using Context): Option[Tree] =
111115
val (targs, args) = argss.flatten.partition(_.isType)
112116
val tparams = ddef.leadingTypeParams
113117
val vparams = ddef.termParamss.flatten
114118

119+
def shapeMatch(paramss: List[ParamClause], argss: List[List[Tree]]): Boolean = (paramss, argss) match
120+
case (params :: paramss1, args :: argss1) if params.length == args.length =>
121+
shapeMatch(paramss1, argss1)
122+
case (Nil, Nil) => true
123+
case _ => false
124+
115125
val targSyms =
116126
for (targ, tparam) <- targs.zip(tparams) yield
117127
targ.tpe.dealias match
@@ -143,19 +153,26 @@ object BetaReduce:
143153
bindings += binding.withSpan(arg.span)
144154
bindingSymbol
145155

146-
val expansion = TreeTypeMap(
147-
oldOwners = ddef.symbol :: Nil,
148-
newOwners = ctx.owner :: Nil,
149-
substFrom = (tparams ::: vparams).map(_.symbol),
150-
substTo = targSyms ::: argSyms
151-
).transform(ddef.rhs)
152-
153-
val expansion1 = new TreeMap {
154-
override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match
155-
case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const)
156-
case tpe: TypeRef if tree.isTerm && tpe.derivesFrom(defn.UnitClass) && isPureExpr(tree) =>
157-
cpy.Literal(tree)(Constant(()))
158-
case _ => super.transform(tree)
159-
}.transform(expansion)
160-
161-
expansion1
156+
if shapeMatch(ddef.paramss, argss) then
157+
// We can't assume arguments always match. It's possible to construct a
158+
// function with wrong apply method by hand which causes `shapeMatch` to fail.
159+
// See neg/i21952.scala
160+
val expansion = TreeTypeMap(
161+
oldOwners = ddef.symbol :: Nil,
162+
newOwners = ctx.owner :: Nil,
163+
substFrom = (tparams ::: vparams).map(_.symbol),
164+
substTo = targSyms ::: argSyms
165+
).transform(ddef.rhs)
166+
167+
val expansion1 = new TreeMap {
168+
override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match
169+
case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const)
170+
case tpe: TypeRef if tree.isTerm && tpe.derivesFrom(defn.UnitClass) && isPureExpr(tree) =>
171+
cpy.Literal(tree)(Constant(()))
172+
case _ => super.transform(tree)
173+
}.transform(expansion)
174+
175+
Some(expansion1)
176+
else None
177+
end reduceApplication
178+
end BetaReduce

compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala

+5-3
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,11 @@ class InlinePatterns extends MiniPhase:
6060
template.body match
6161
case List(ddef @ DefDef(`name`, _, _, _)) =>
6262
val bindings = new ListBuffer[DefTree]()
63-
val expansion1 = BetaReduce.reduceApplication(ddef, argss, bindings)
64-
val bindings1 = bindings.result()
65-
seq(bindings1, expansion1)
63+
BetaReduce.reduceApplication(ddef, argss, bindings) match
64+
case Some(expansion1) =>
65+
val bindings1 = bindings.result()
66+
seq(bindings1, expansion1)
67+
case None => tree
6668
case _ => tree
6769
case _ => tree
6870

tests/neg/i21952.scala

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
val _ = (new Function[(Int, Int), Int] {def apply(a: Int, b: Int): Int = a * b})(2, 3) // error

0 commit comments

Comments
 (0)