Skip to content

Commit b4b02b4

Browse files
committed
Handle return properly
1 parent 8281500 commit b4b02b4

File tree

3 files changed

+60
-3
lines changed

3 files changed

+60
-3
lines changed

compiler/src/dotty/tools/dotc/transform/init/Objects.scala

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ object Objects:
212212
given Trace = Trace.empty.add(classSym.defTree)
213213
given Env.Data = Env.emptyEnv(tpl.constr.symbol)
214214
given Heap.MutableData = Heap.empty()
215+
given returns: Returns.Data = Returns.empty()
215216
given regions: Regions.Data = Regions.empty // explicit name to avoid naming conflict
216217

217218
val obj = ObjectRef(classSym)
@@ -487,7 +488,33 @@ object Objects:
487488

488489
inline def cache(using c: Cache.Data): Cache.Data = c
489490

490-
type Contextual[T] = (Context, State.Data, Env.Data, Cache.Data, Heap.MutableData, Regions.Data, Trace) ?=> T
491+
492+
/**
493+
* Handle return statements in methods and non-local returns in functions.
494+
*/
495+
object Returns:
496+
private class ReturnData(val method: Symbol, val values: mutable.ArrayBuffer[Value])
497+
opaque type Data = mutable.ArrayBuffer[ReturnData]
498+
499+
def empty(): Data = mutable.ArrayBuffer()
500+
501+
def installHandler(meth: Symbol)(using data: Data): Unit =
502+
data.addOne(ReturnData(meth, mutable.ArrayBuffer()))
503+
504+
def popHandler(meth: Symbol)(using data: Data): Value =
505+
val returnData = data.remove(data.size - 1)
506+
assert(returnData.method == meth, "Symbol mismatch in return handlers, expect = " + meth + ", found = " + returnData.method)
507+
returnData.values.join
508+
509+
def handle(meth: Symbol, value: Value)(using data: Data, trace: Trace, ctx: Context): Unit =
510+
data.findLast(_.method == meth) match
511+
case Some(returnData) =>
512+
returnData.values.addOne(value)
513+
514+
case None =>
515+
report.error("[Internal error] Unhandled return for method " + meth + " in " + meth.owner.show + ". Trace:\n" + Trace.show, Trace.position)
516+
517+
type Contextual[T] = (Context, State.Data, Env.Data, Cache.Data, Heap.MutableData, Regions.Data, Returns.Data, Trace) ?=> T
491518

492519
// --------------------------- domain operations -----------------------------
493520

@@ -595,7 +622,13 @@ object Objects:
595622
val env2 = Env.of(ddef, args.map(_.value), outerEnv)
596623
extendTrace(ddef) {
597624
given Env.Data = env2
598-
eval(ddef.rhs, ref, cls, cacheResult = true)
625+
// eval(ddef.rhs, ref, cls, cacheResult = true)
626+
cache.cachedEval(ref, ddef.rhs, cacheResult = true) { expr =>
627+
Returns.installHandler(meth)
628+
val res = cases(expr, thisV, cls)
629+
val returns = Returns.popHandler(meth)
630+
res.join(returns)
631+
}
599632
}
600633
else
601634
Bottom
@@ -1079,7 +1112,8 @@ object Objects:
10791112
evalExprs(cases.map(_.body), thisV, klass).join
10801113

10811114
case Return(expr, from) =>
1082-
eval(expr, thisV, klass)
1115+
Returns.handle(from.symbol, eval(expr, thisV, klass))
1116+
Bottom
10831117

10841118
case WhileDo(cond, body) =>
10851119
evalExprs(cond :: body :: Nil, thisV, klass)

tests/init-global/neg/return.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object A:
2+
def foo(x: Int): Int => Int =
3+
if x <= 0 then
4+
return (a: Int) => a + B.n // error
5+
6+
(a: Int) => a * a + x
7+
8+
object B:
9+
val n = A.foo(-10)(20)
10+

tests/init-global/neg/return2.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
object A:
2+
def foo(x: Int): Int => Int =
3+
val f = (a: Int) => a + B.n // error
4+
var i = 0
5+
6+
val g = () => return f
7+
8+
if x <= 0 then g()
9+
10+
(a: Int) => a * a + x
11+
12+
object B:
13+
val n = A.foo(-10)(20)

0 commit comments

Comments
 (0)