Skip to content

Commit 73168ee

Browse files
authored
Attempt to beta reduce only if parameters and arguments have same shape (#21970)
It's possible to define Functions with wrong apply methods by hand which will give an error but pass on a function that does fails beta reduction. Fixes #21952
2 parents 3a5e137 + 6896ea2 commit 73168ee

File tree

3 files changed

+43
-23
lines changed

3 files changed

+43
-23
lines changed

compiler/src/dotty/tools/dotc/transform/BetaReduce.scala

+37-20
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ object BetaReduce:
7676
val bindingsBuf = new ListBuffer[DefTree]
7777
def recur(fn: Tree, argss: List[List[Tree]]): Option[Tree] = fn match
7878
case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol =>
79-
Some(reduceApplication(ddef, argss, bindingsBuf))
79+
reduceApplication(ddef, argss, bindingsBuf)
8080
case Block((TypeDef(_, template: Template)) :: Nil, Typed(Apply(Select(New(_), _), _), _)) if template.constr.rhs.isEmpty =>
8181
template.body match
82-
case (ddef: DefDef) :: Nil => Some(reduceApplication(ddef, argss, bindingsBuf))
82+
case (ddef: DefDef) :: Nil => reduceApplication(ddef, argss, bindingsBuf)
8383
case _ => None
8484
case Block(stats, expr) if stats.forall(isPureBinding) =>
8585
recur(expr, argss).map(cpy.Block(fn)(stats, _))
@@ -106,12 +106,22 @@ object BetaReduce:
106106
case _ =>
107107
tree
108108

109-
/** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */
110-
def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree])(using Context): Tree =
109+
/** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings.
110+
* @return optionally, the expanded call, or none if the actual argument
111+
* lists do not match in shape the formal parameters
112+
*/
113+
def reduceApplication(ddef: DefDef, argss: List[List[Tree]], bindings: ListBuffer[DefTree])
114+
(using Context): Option[Tree] =
111115
val (targs, args) = argss.flatten.partition(_.isType)
112116
val tparams = ddef.leadingTypeParams
113117
val vparams = ddef.termParamss.flatten
114118

119+
def shapeMatch(paramss: List[ParamClause], argss: List[List[Tree]]): Boolean = (paramss, argss) match
120+
case (params :: paramss1, args :: argss1) if params.length == args.length =>
121+
shapeMatch(paramss1, argss1)
122+
case (Nil, Nil) => true
123+
case _ => false
124+
115125
val targSyms =
116126
for (targ, tparam) <- targs.zip(tparams) yield
117127
targ.tpe.dealias match
@@ -143,19 +153,26 @@ object BetaReduce:
143153
bindings += binding.withSpan(arg.span)
144154
bindingSymbol
145155

146-
val expansion = TreeTypeMap(
147-
oldOwners = ddef.symbol :: Nil,
148-
newOwners = ctx.owner :: Nil,
149-
substFrom = (tparams ::: vparams).map(_.symbol),
150-
substTo = targSyms ::: argSyms
151-
).transform(ddef.rhs)
152-
153-
val expansion1 = new TreeMap {
154-
override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match
155-
case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const)
156-
case tpe: TypeRef if tree.isTerm && tpe.derivesFrom(defn.UnitClass) && isPureExpr(tree) =>
157-
cpy.Literal(tree)(Constant(()))
158-
case _ => super.transform(tree)
159-
}.transform(expansion)
160-
161-
expansion1
156+
if shapeMatch(ddef.paramss, argss) then
157+
// We can't assume arguments always match. It's possible to construct a
158+
// function with wrong apply method by hand which causes `shapeMatch` to fail.
159+
// See neg/i21952.scala
160+
val expansion = TreeTypeMap(
161+
oldOwners = ddef.symbol :: Nil,
162+
newOwners = ctx.owner :: Nil,
163+
substFrom = (tparams ::: vparams).map(_.symbol),
164+
substTo = targSyms ::: argSyms
165+
).transform(ddef.rhs)
166+
167+
val expansion1 = new TreeMap {
168+
override def transform(tree: Tree)(using Context) = tree.tpe.widenTermRefExpr match
169+
case ConstantType(const) if isPureExpr(tree) => cpy.Literal(tree)(const)
170+
case tpe: TypeRef if tree.isTerm && tpe.derivesFrom(defn.UnitClass) && isPureExpr(tree) =>
171+
cpy.Literal(tree)(Constant(()))
172+
case _ => super.transform(tree)
173+
}.transform(expansion)
174+
175+
Some(expansion1)
176+
else None
177+
end reduceApplication
178+
end BetaReduce

compiler/src/dotty/tools/dotc/transform/InlinePatterns.scala

+5-3
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,11 @@ class InlinePatterns extends MiniPhase:
6060
template.body match
6161
case List(ddef @ DefDef(`name`, _, _, _)) =>
6262
val bindings = new ListBuffer[DefTree]()
63-
val expansion1 = BetaReduce.reduceApplication(ddef, argss, bindings)
64-
val bindings1 = bindings.result()
65-
seq(bindings1, expansion1)
63+
BetaReduce.reduceApplication(ddef, argss, bindings) match
64+
case Some(expansion1) =>
65+
val bindings1 = bindings.result()
66+
seq(bindings1, expansion1)
67+
case None => tree
6668
case _ => tree
6769
case _ => tree
6870

tests/neg/i21952.scala

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
val _ = (new Function[(Int, Int), Int] {def apply(a: Int, b: Int): Int = a * b})(2, 3) // error

0 commit comments

Comments
 (0)