Skip to content

Commit 675b2ac

Browse files
authored
Merge pull request #4949 from dotty-staging/fix-#4947
Fix #4947: Do not replace positions of inlined arguments
2 parents dab03fd + 4c51bc3 commit 675b2ac

File tree

12 files changed

+451
-4
lines changed

12 files changed

+451
-4
lines changed

compiler/src/dotty/tools/dotc/typer/Inliner.scala

+29-4
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,37 @@ object Inliner {
149149

150150
/** Replace `Inlined` node by a block that contains its bindings and expansion */
151151
def dropInlined(inlined: tpd.Inlined)(implicit ctx: Context): Tree = {
152-
val reposition = new TreeMap {
153-
override def transform(tree: Tree)(implicit ctx: Context): Tree = {
154-
super.transform(tree).withPos(inlined.call.pos)
152+
if (enclosingInlineds.nonEmpty) inlined // Remove in the outer most inlined call
153+
else {
154+
val inlinedAtPos = inlined.call.pos
155+
val callSourceFile = ctx.source.file
156+
157+
/** Removes all Inlined trees, replacing them with blocks.
158+
* Repositions all trees directly inside an inlined expansion of a non empty call to the position of the call.
159+
* Any tree directly inside an empty call (inlined in the inlined code) retains their position.
160+
*/
161+
class Reposition extends TreeMap {
162+
override def transform(tree: Tree)(implicit ctx: Context): Tree = {
163+
tree match {
164+
case tree: Inlined => transformInline(tree)
165+
case _ =>
166+
val transformed = super.transform(tree)
167+
enclosingInlineds match {
168+
case call :: _ if call.symbol.sourceFile != callSourceFile =>
169+
// Until we implement JSR-45, we cannot represent in output positions in other source files.
170+
// So, reposition inlined code from other files with the call position:
171+
transformed.withPos(inlinedAtPos)
172+
case _ => transformed
173+
}
174+
}
175+
}
176+
def transformInline(tree: tpd.Inlined)(implicit ctx: Context): Tree = {
177+
tpd.seq(transformSub(tree.bindings), transform(tree.expansion)(inlineContext(tree.call)))
178+
}
155179
}
180+
181+
(new Reposition).transformInline(inlined)
156182
}
157-
tpd.seq(inlined.bindings, reposition.transform(inlined.expansion))
158183
}
159184
}
160185

compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala

+1
Original file line numberDiff line numberDiff line change
@@ -335,4 +335,5 @@ class TestBCode extends DottyBytecodeTest {
335335
assert(!fooInvoke, "foo should not be called\n")
336336
}
337337
}
338+
338339
}

compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala

+244
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ package dotty.tools.backend.jvm
33
import org.junit.Assert._
44
import org.junit.Test
55

6+
import scala.tools.asm.Opcodes._
7+
8+
import scala.collection.JavaConverters._
9+
610
class InlineBytecodeTests extends DottyBytecodeTest {
711
import ASMConverters._
812
@Test def inlineUnit = {
@@ -37,4 +41,244 @@ class InlineBytecodeTests extends DottyBytecodeTest {
3741
diffInstructions(instructions2, instructions3))
3842
}
3943
}
44+
45+
@Test def i4947 = {
46+
val source = """class Foo {
47+
| transparent def track[T](f: => T): T = {
48+
| foo("tracking") // line 3
49+
| f // line 4
50+
| }
51+
| def main(args: Array[String]): Unit = { // line 6
52+
| track { // line 7
53+
| foo("abc") // line 8
54+
| track { // line 9
55+
| foo("inner") // line 10
56+
| }
57+
| } // line 11
58+
| }
59+
| def foo(str: String): Unit = ()
60+
|}
61+
""".stripMargin
62+
63+
checkBCode(source) { dir =>
64+
val clsIn = dir.lookupName("Foo.class", directory = false).input
65+
val clsNode = loadClassNode(clsIn, skipDebugInfo = false)
66+
67+
val track = clsNode.methods.asScala.find(_.name == "track")
68+
assert(track.isEmpty, "method `track` should have been erased")
69+
70+
val main = getMethod(clsNode, "main")
71+
val instructions = instructionsFromMethod(main)
72+
val expected =
73+
List(
74+
Label(0),
75+
LineNumber(6, Label(0)),
76+
LineNumber(3, Label(0)),
77+
VarOp(ALOAD, 0),
78+
Ldc(LDC, "tracking"),
79+
Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false),
80+
Label(6),
81+
LineNumber(8, Label(6)),
82+
VarOp(ALOAD, 0),
83+
Ldc(LDC, "abc"),
84+
Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false),
85+
Label(11),
86+
LineNumber(3, Label(11)),
87+
VarOp(ALOAD, 0),
88+
Ldc(LDC, "tracking"),
89+
Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false),
90+
Label(16),
91+
LineNumber(10, Label(16)),
92+
VarOp(ALOAD, 0),
93+
Ldc(LDC, "inner"),
94+
Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false),
95+
Op(RETURN),
96+
Label(22)
97+
)
98+
assert(instructions == expected,
99+
"`track` was not properly inlined in `main`\n" + diffInstructions(instructions, expected))
100+
101+
}
102+
}
103+
104+
@Test def i4947b = {
105+
val source = """class Foo {
106+
| transparent def track2[T](f: => T): T = {
107+
| foo("tracking2") // line 3
108+
| f // line 4
109+
| }
110+
| transparent def track[T](f: => T): T = {
111+
| foo("tracking") // line 7
112+
| track2 { // line 8
113+
| f // line 9
114+
| }
115+
| }
116+
| def main(args: Array[String]): Unit = { // line 12
117+
| track { // line 13
118+
| foo("abc") // line 14
119+
| }
120+
| }
121+
| def foo(str: String): Unit = ()
122+
|}
123+
""".stripMargin
124+
125+
checkBCode(source) { dir =>
126+
val clsIn = dir.lookupName("Foo.class", directory = false).input
127+
val clsNode = loadClassNode(clsIn, skipDebugInfo = false)
128+
129+
val track = clsNode.methods.asScala.find(_.name == "track")
130+
assert(track.isEmpty, "method `track` should have been erased")
131+
132+
val track2 = clsNode.methods.asScala.find(_.name == "track2")
133+
assert(track2.isEmpty, "method `track2` should have been erased")
134+
135+
val main = getMethod(clsNode, "main")
136+
val instructions = instructionsFromMethod(main)
137+
val expected =
138+
List(
139+
Label(0),
140+
LineNumber(12, Label(0)),
141+
LineNumber(7, Label(0)),
142+
VarOp(ALOAD, 0),
143+
Ldc(LDC, "tracking"),
144+
Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false),
145+
Label(6),
146+
LineNumber(3, Label(6)),
147+
VarOp(ALOAD, 0),
148+
Ldc(LDC, "tracking2"),
149+
Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false),
150+
Label(11),
151+
LineNumber(14, Label(11)),
152+
VarOp(ALOAD, 0),
153+
Ldc(LDC, "abc"),
154+
Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false),
155+
Op(RETURN),
156+
Label(17)
157+
)
158+
assert(instructions == expected,
159+
"`track` was not properly inlined in `main`\n" + diffInstructions(instructions, expected))
160+
161+
}
162+
}
163+
164+
@Test def i4947c = {
165+
val source = """class Foo {
166+
| transparent def track2[T](f: => T): T = {
167+
| foo("tracking2") // line 3
168+
| f // line 4
169+
| }
170+
| transparent def track[T](f: => T): T = {
171+
| track2 { // line 7
172+
| foo("fgh") // line 8
173+
| f // line 9
174+
| }
175+
| }
176+
| def main(args: Array[String]): Unit = { // line 12
177+
| track { // line 13
178+
| foo("abc") // line 14
179+
| }
180+
| }
181+
| def foo(str: String): Unit = ()
182+
|}
183+
""".stripMargin
184+
185+
checkBCode(source) { dir =>
186+
val clsIn = dir.lookupName("Foo.class", directory = false).input
187+
val clsNode = loadClassNode(clsIn, skipDebugInfo = false)
188+
189+
val track = clsNode.methods.asScala.find(_.name == "track")
190+
assert(track.isEmpty, "method `track` should have been erased")
191+
192+
val track2 = clsNode.methods.asScala.find(_.name == "track2")
193+
assert(track2.isEmpty, "method `track2` should have been erased")
194+
195+
val main = getMethod(clsNode, "main")
196+
val instructions = instructionsFromMethod(main)
197+
val expected =
198+
List(
199+
Label(0),
200+
LineNumber(12, Label(0)),
201+
LineNumber(3, Label(0)),
202+
VarOp(ALOAD, 0),
203+
Ldc(LDC, "tracking2"),
204+
Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false),
205+
Label(6),
206+
LineNumber(8, Label(6)),
207+
VarOp(ALOAD, 0),
208+
Ldc(LDC, "fgh"),
209+
Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false),
210+
Label(11),
211+
LineNumber(14, Label(11)),
212+
VarOp(ALOAD, 0),
213+
Ldc(LDC, "abc"),
214+
Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false),
215+
Op(RETURN),
216+
Label(17)
217+
)
218+
assert(instructions == expected,
219+
"`track` was not properly inlined in `main`\n" + diffInstructions(instructions, expected))
220+
221+
}
222+
}
223+
224+
@Test def i4947d = {
225+
val source = """class Foo {
226+
| transparent def track2[T](f: => T): T = {
227+
| foo("tracking2") // line 3
228+
| f // line 4
229+
| }
230+
| transparent def track[T](f: => T): T = {
231+
| track2 { // line 7
232+
| track2 { // line 8
233+
| f // line 9
234+
| }
235+
| }
236+
| }
237+
| def main(args: Array[String]): Unit = { // line 13
238+
| track { // line 14
239+
| foo("abc") // line 15
240+
| }
241+
| }
242+
| def foo(str: String): Unit = ()
243+
|}
244+
""".stripMargin
245+
246+
checkBCode(source) { dir =>
247+
val clsIn = dir.lookupName("Foo.class", directory = false).input
248+
val clsNode = loadClassNode(clsIn, skipDebugInfo = false)
249+
250+
val track = clsNode.methods.asScala.find(_.name == "track")
251+
assert(track.isEmpty, "method `track` should have been erased")
252+
253+
val track2 = clsNode.methods.asScala.find(_.name == "track2")
254+
assert(track2.isEmpty, "method `track2` should have been erased")
255+
256+
val main = getMethod(clsNode, "main")
257+
val instructions = instructionsFromMethod(main)
258+
val expected =
259+
List(
260+
Label(0),
261+
LineNumber(13, Label(0)),
262+
LineNumber(3, Label(0)),
263+
VarOp(ALOAD, 0),
264+
Ldc(LDC, "tracking2"),
265+
Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false),
266+
Label(6),
267+
LineNumber(3, Label(6)),
268+
VarOp(ALOAD, 0),
269+
Ldc(LDC, "tracking2"),
270+
Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false),
271+
Label(11),
272+
LineNumber(15, Label(11)),
273+
VarOp(ALOAD, 0),
274+
Ldc(LDC, "abc"),
275+
Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false),
276+
Op(RETURN),
277+
Label(17)
278+
)
279+
assert(instructions == expected,
280+
"`track` was not properly inlined in `main`\n" + diffInstructions(instructions, expected))
281+
282+
}
283+
}
40284
}

tests/run/i4947.check

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
track: Test$.main(i4947.scala:4)
2+
track: Test$.main(i4947.scala:5)
3+
main1: Test$.main(i4947.scala:15)
4+
main2: Test$.main(i4947.scala:16)

tests/run/i4947.scala

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
object Test {
2+
3+
transparent def track[T](f: => T): T = {
4+
printStack("track")
5+
printStack("track")
6+
f
7+
}
8+
9+
def printStack(tag: String): Unit = {
10+
println(tag + ": "+ new Exception().getStackTrace().apply(1))
11+
}
12+
13+
def main(args: Array[String]): Unit = {
14+
track {
15+
printStack("main1")
16+
printStack("main2")
17+
}
18+
}
19+
20+
}

tests/run/i4947a.check

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
track (i = 0): Test$.main(i4947a.scala:4)
2+
track (i = 0): Test$.main(i4947a.scala:5)
3+
track (i = 2): Test$.main(i4947a.scala:4)
4+
track (i = 2): Test$.main(i4947a.scala:5)
5+
main1 (i = -1): Test$.main(i4947a.scala:21)
6+
main2 (i = -1): Test$.main(i4947a.scala:22)
7+
track (i = 1): Test$.main(i4947a.scala:4)
8+
track (i = 1): Test$.main(i4947a.scala:5)
9+
main1 (i = -1): Test$.main(i4947a.scala:21)
10+
main2 (i = -1): Test$.main(i4947a.scala:22)
11+
track (i = 0): Test$.main(i4947a.scala:4)
12+
track (i = 0): Test$.main(i4947a.scala:5)
13+
main1 (i = -1): Test$.main(i4947a.scala:21)
14+
main2 (i = -1): Test$.main(i4947a.scala:22)

tests/run/i4947a.scala

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
object Test {
2+
3+
transparent def fact[T](transparent i: Int)(f: => T): Int = {
4+
printStack("track", i)
5+
printStack("track", i)
6+
f
7+
if (i == 0)
8+
1
9+
else {
10+
i * fact(i-1)(f)
11+
}
12+
}
13+
14+
def printStack(tag: String, i: Int): Unit = {
15+
println(s"$tag (i = $i): ${new Exception().getStackTrace().apply(1)}")
16+
}
17+
18+
def main(args: Array[String]): Unit = {
19+
fact(0) {
20+
fact(2) {
21+
printStack("main1", -1)
22+
printStack("main2", -1)
23+
}
24+
}
25+
}
26+
27+
}

0 commit comments

Comments
 (0)