Skip to content

Fix #502: Optimize Array.apply([...]) to [...] #6821

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class Compiler {
List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements.
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations
new ArrayApply, // Optimize `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]`
new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses
new TailRec, // Rewrite tail recursion to loops
new Mixin, // Expand trait fields and trait initializers
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,8 @@ class Definitions {
@threadUnsafe lazy val ClassTagType: TypeRef = ctx.requiredClassRef("scala.reflect.ClassTag")
def ClassTagClass(implicit ctx: Context): ClassSymbol = ClassTagType.symbol.asClass
def ClassTagModule(implicit ctx: Context): Symbol = ClassTagClass.companionModule
@threadUnsafe lazy val ClassTagModule_applyR: TermRef = ClassTagModule.requiredMethodRef(nme.apply)
def ClassTagModule_apply(implicit ctx: Context): Symbol = ClassTagModule_applyR.symbol

@threadUnsafe lazy val QuotedExprType: TypeRef = ctx.requiredClassRef("scala.quoted.Expr")
def QuotedExprClass(implicit ctx: Context): ClassSymbol = QuotedExprType.symbol.asClass
Expand Down
69 changes: 69 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/ArrayApply.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package dotty.tools.dotc
package transform

import core._
import MegaPhase._
import Contexts.Context
import Symbols._
import Types._
import StdNames._
import ast.Trees._
import dotty.tools.dotc.ast.tpd

import scala.reflect.ClassTag


/** This phase rewrites calls to `Array.apply` to a direct instantiation of the array in the bytecode.
*
* Transforms `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]`
*/
class ArrayApply extends MiniPhase {
import tpd._

override def phaseName: String = "arrayApply"

override def transformApply(tree: tpd.Apply)(implicit ctx: Context): tpd.Tree = {
if (tree.symbol.name == nme.apply && tree.symbol.owner == defn.ArrayModule) { // Is `Array.apply`
tree.args match {
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) =>
seqLit

case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: Nil
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) =>
tpd.JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)

case _ =>
tree
}

} else tree
}

/** Only optimize when classtag if it is one of
* - `ClassTag.apply(classOf[XYZ])`
* - `ClassTag.apply(java.lang.XYZ.Type)` for boxed primitives `XYZ``
* - `ClassTag.XYZ` for primitive types
*/
private def elideClassTag(ct: Tree)(implicit ctx: Context): Boolean = ct match {
case Apply(_, rc :: Nil) if ct.symbol == defn.ClassTagModule_apply =>
rc match {
case _: Literal => true // ClassTag.apply(classOf[XYZ])
case rc: RefTree if rc.name == nme.TYPE_ =>
// ClassTag.apply(java.lang.XYZ.Type)
defn.ScalaBoxedClasses().contains(rc.symbol.maybeOwner.companionClass)
case _ => false
}
case Apply(ctm: RefTree, _) if ctm.symbol.maybeOwner.companionModule == defn.ClassTagModule =>
// ClassTag.XYZ
nme.ScalaValueNames.contains(ctm.name)
case _ => false
}

object StripAscription {
def unapply(tree: Tree)(implicit ctx: Context): Some[Tree] = tree match {
case Typed(expr, _) => unapply(expr)
case _ => Some(tree)
}
}
}
109 changes: 109 additions & 0 deletions compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package dotty.tools.backend.jvm

import org.junit.Test
import org.junit.Assert._

import scala.tools.asm.Opcodes._

class ArrayApplyOptTest extends DottyBytecodeTest {
import ASMConverters._

@Test def testArrayEmptyGenericApply= {
test("Array[String]()", List(Op(ICONST_0), TypeOp(ANEWARRAY, "java/lang/String"), Op(POP), Op(RETURN)))
test("Array[Unit]()", List(Op(ICONST_0), TypeOp(ANEWARRAY, "scala/runtime/BoxedUnit"), Op(POP), Op(RETURN)))
test("Array[Object]()", List(Op(ICONST_0), TypeOp(ANEWARRAY, "java/lang/Object"), Op(POP), Op(RETURN)))
test("Array[Boolean]()", newArray0Opcodes(T_BOOLEAN))
test("Array[Byte]()", newArray0Opcodes(T_BYTE))
test("Array[Short]()", newArray0Opcodes(T_SHORT))
test("Array[Int]()", newArray0Opcodes(T_INT))
test("Array[Long]()", newArray0Opcodes(T_LONG))
test("Array[Float]()", newArray0Opcodes(T_FLOAT))
test("Array[Double]()", newArray0Opcodes(T_DOUBLE))
test("Array[Char]()", newArray0Opcodes(T_CHAR))
test("Array[T]()", newArray0Opcodes(T_INT))
}

@Test def testArrayGenericApply= {
def opCodes(tpe: String) =
List(Op(ICONST_2), TypeOp(ANEWARRAY, tpe), Op(DUP), Op(ICONST_0), Ldc(LDC, "a"), Op(AASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, "b"), Op(AASTORE), Op(POP), Op(RETURN))
test("""Array("a", "b")""", opCodes("java/lang/String"))
test("""Array[Object]("a", "b")""", opCodes("java/lang/Object"))
}

@Test def testArrayApplyBoolean =
test("Array(true, false)", newArray2Opcodes(T_BOOLEAN, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(BASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_0), Op(BASTORE))))

@Test def testArrayApplyByte =
test("Array[Byte](1, 2)", newArray2Opcodes(T_BYTE, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(BASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(BASTORE))))

@Test def testArrayApplyShort =
test("Array[Short](1, 2)", newArray2Opcodes(T_SHORT, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(SASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(SASTORE))))

@Test def testArrayApplyInt = {
test("Array(1, 2)", newArray2Opcodes(T_INT, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(IASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(IASTORE))))
test("""Array[T](t, t)""", newArray2Opcodes(T_INT, List(Op(DUP), Op(ICONST_0), Field(GETSTATIC, "Foo$", "MODULE$", "LFoo$;"), Invoke(INVOKEVIRTUAL, "Foo$", "t", "()I", false), Op(IASTORE), Op(DUP), Op(ICONST_1), Field(GETSTATIC, "Foo$", "MODULE$", "LFoo$;"), Invoke(INVOKEVIRTUAL, "Foo$", "t", "()I", false), Op(IASTORE))))
}

@Test def testArrayApplyLong =
test("Array(2L, 3L)", newArray2Opcodes(T_LONG, List(Op(DUP), Op(ICONST_0), Ldc(LDC, 2), Op(LASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, 3), Op(LASTORE))))

@Test def testArrayApplyFloat =
test("Array(2.1f, 3.1f)", newArray2Opcodes(T_FLOAT, List(Op(DUP), Op(ICONST_0), Ldc(LDC, 2.1f), Op(FASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, 3.1f), Op(FASTORE))))

@Test def testArrayApplyDouble =
test("Array(2.2d, 3.2d)", newArray2Opcodes(T_DOUBLE, List(Op(DUP), Op(ICONST_0), Ldc(LDC, 2.2d), Op(DASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, 3.2d), Op(DASTORE))))

@Test def testArrayApplyChar =
test("Array('x', 'y')", newArray2Opcodes(T_CHAR, List(Op(DUP), Op(ICONST_0), IntOp(BIPUSH, 120), Op(CASTORE), Op(DUP), Op(ICONST_1), IntOp(BIPUSH, 121), Op(CASTORE))))

@Test def testArrayApplyUnit =
test("Array[Unit]((), ())", List(Op(ICONST_2), TypeOp(ANEWARRAY, "scala/runtime/BoxedUnit"), Op(DUP),
Op(ICONST_0), Field(GETSTATIC, "scala/runtime/BoxedUnit", "UNIT", "Lscala/runtime/BoxedUnit;"), Op(AASTORE), Op(DUP),
Op(ICONST_1), Field(GETSTATIC, "scala/runtime/BoxedUnit", "UNIT", "Lscala/runtime/BoxedUnit;"), Op(AASTORE), Op(POP), Op(RETURN)))

@Test def testArrayInlined = test(
"""{
| inline def array(xs: =>Int*): Array[Int] = Array(xs: _*)
| array(1, 2)
|}""".stripMargin,
newArray2Opcodes(T_INT, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(IASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(IASTORE), TypeOp(CHECKCAST, "[I")))
)

@Test def testArrayInlined2 = test(
"""{
| inline def array(x: =>Int, xs: =>Int*): Array[Int] = Array(x, xs: _*)
| array(1, 2)
|}""".stripMargin,
newArray2Opcodes(T_INT, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(IASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(IASTORE)))
)

private def newArray0Opcodes(tpe: Int, init: List[Any] = Nil): List[Any] =
Op(ICONST_0) :: IntOp(NEWARRAY, tpe) :: init ::: Op(POP) :: Op(RETURN) :: Nil

private def newArray2Opcodes(tpe: Int, init: List[Any] = Nil): List[Any] =
Op(ICONST_2) :: IntOp(NEWARRAY, tpe) :: init ::: Op(POP) :: Op(RETURN) :: Nil

private def test(code: String, expectedInstructions: List[Any])= {
val source =
s"""class Foo {
| import Foo._
| def test: Unit = $code
|}
|object Foo {
| opaque type T = Int
| def t: T = 1
|}
""".stripMargin

checkBCode(source) { dir =>
val clsIn = dir.lookupName("Foo.class", directory = false).input
val clsNode = loadClassNode(clsIn)
val meth = getMethod(clsNode, "test")

val instructions = instructionsFromMethod(meth)

assertEquals(expectedInstructions, instructions)
}
}

}
3 changes: 3 additions & 0 deletions tests/run/i502.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Ok
foo
bar
16 changes: 16 additions & 0 deletions tests/run/i502.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import scala.reflect.ClassTag

object Test extends App {
Array[Int](1, 2)

try {
Array[Int](1, 2)(null)
???
} catch {
case _: NullPointerException => println("Ok")
}

Array[Int](1, 2)({println("foo"); the[ClassTag[Int]]})

Array[Int](1, 2)(ClassTag.apply({ println("bar"); classOf[Int]}))
}
6 changes: 6 additions & 0 deletions tests/run/t6611b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
object Test extends App {
val a = Array("1")
val a2 = Array(a: _*)
a2(0) = "2"
assert(a(0) == "1")
}