Skip to content

Commit c3bc662

Browse files
Backport "Some fixes for AnnotatedTypes mapping" to LTS (#22124)
Backports #19957 to the 3.3.5. PR submitted by the release tooling. [skip ci]
2 parents 4a27365 + 8b26184 commit c3bc662

23 files changed

+242
-18
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package dotty.tools.benchmarks
2+
3+
import org.openjdk.jmh.annotations.{Benchmark, BenchmarkMode, Fork, Level, Measurement, Mode as JMHMode, Param, Scope, Setup, State, Warmup}
4+
import java.util.concurrent.TimeUnit.SECONDS
5+
6+
import dotty.tools.dotc.{Driver, Run, Compiler}
7+
import dotty.tools.dotc.ast.{tpd, TreeTypeMap}, tpd.{Apply, Block, Tree, TreeAccumulator, TypeApply}
8+
import dotty.tools.dotc.core.Annotations.{Annotation, ConcreteAnnotation, EmptyAnnotation}
9+
import dotty.tools.dotc.core.Contexts.{ContextBase, Context, ctx, withMode}
10+
import dotty.tools.dotc.core.Mode
11+
import dotty.tools.dotc.core.Phases.Phase
12+
import dotty.tools.dotc.core.Symbols.{defn, mapSymbols, Symbol}
13+
import dotty.tools.dotc.core.Types.{AnnotatedType, NoType, SkolemType, TermRef, Type, TypeMap}
14+
import dotty.tools.dotc.parsing.Parser
15+
import dotty.tools.dotc.typer.TyperPhase
16+
17+
/** Measures the performance of mapping over annotated types.
18+
*
19+
* Run with: scala3-bench-micro / Jmh / run AnnotationsMappingBenchmark
20+
*/
21+
@Fork(value = 4)
22+
@Warmup(iterations = 4, time = 1, timeUnit = SECONDS)
23+
@Measurement(iterations = 4, time = 1, timeUnit = SECONDS)
24+
@BenchmarkMode(Array(JMHMode.Throughput))
25+
@State(Scope.Thread)
26+
class AnnotationsMappingBenchmark:
27+
var tp: Type = null
28+
var specialIntTp: Type = null
29+
var context: Context = null
30+
var typeFunction: Context ?=> Type => Type = null
31+
var typeMap: TypeMap = null
32+
33+
@Param(Array("v1", "v2", "v3", "v4"))
34+
var valName: String = null
35+
36+
@Param(Array("id", "mapInts"))
37+
var typeFunctionName: String = null
38+
39+
@Setup(Level.Iteration)
40+
def setup(): Unit =
41+
val testPhase =
42+
new Phase:
43+
final override def phaseName = "testPhase"
44+
final override def run(using ctx: Context): Unit =
45+
val pkg = ctx.compilationUnit.tpdTree.symbol
46+
tp = pkg.requiredClass("Test").requiredValueRef(valName).underlying
47+
specialIntTp = pkg.requiredClass("Test").requiredType("SpecialInt").typeRef
48+
context = ctx
49+
50+
val compiler =
51+
new Compiler:
52+
private final val baseCompiler = new Compiler()
53+
final override def phases = List(List(Parser()), List(TyperPhase()), List(testPhase))
54+
55+
val driver =
56+
new Driver:
57+
final override def newCompiler(using Context): Compiler = compiler
58+
59+
driver.process(Array("-classpath", System.getProperty("BENCH_CLASS_PATH"), "tests/someAnnotatedTypes.scala"))
60+
61+
typeFunction =
62+
typeFunctionName match
63+
case "id" => tp => tp
64+
case "mapInts" => tp => (if tp frozen_=:= defn.IntType then specialIntTp else tp)
65+
case _ => throw new IllegalArgumentException(s"Unknown type function: $typeFunctionName")
66+
67+
typeMap =
68+
new TypeMap(using context):
69+
final override def apply(tp: Type): Type = typeFunction(mapOver(tp))
70+
71+
@Benchmark def applyTypeMap() = typeMap.apply(tp)

bench-micro/src/main/scala/dotty/tools/benchmarks/lazyvals/ContendedInitialization.scala

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package dotty.tools.benchmarks.lazyvals
22

3+
import compiletime.uninitialized
34
import org.openjdk.jmh.annotations._
45
import LazyVals.LazyHolder
56
import org.openjdk.jmh.infra.Blackhole
@@ -16,12 +17,12 @@ import java.util.concurrent.{Executors, ExecutorService}
1617
class ContendedInitialization {
1718

1819
@Param(Array("2000000", "5000000"))
19-
var size: Int = _
20+
var size: Int = uninitialized
2021

2122
@Param(Array("2", "4", "8"))
22-
var nThreads: Int = _
23+
var nThreads: Int = uninitialized
2324

24-
var executor: ExecutorService = _
25+
var executor: ExecutorService = uninitialized
2526

2627
@Setup
2728
def prepare: Unit = {

bench-micro/src/main/scala/dotty/tools/benchmarks/lazyvals/InitializedAccess.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package dotty.tools.benchmarks.lazyvals
22

3+
import compiletime.uninitialized
34
import org.openjdk.jmh.annotations._
45
import LazyVals.LazyHolder
56
import org.openjdk.jmh.infra.Blackhole
@@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
1415
@State(Scope.Benchmark)
1516
class InitializedAccess {
1617

17-
var holder: LazyHolder = _
18+
var holder: LazyHolder = uninitialized
1819

1920
@Setup
2021
def prepare: Unit = {

bench-micro/src/main/scala/dotty/tools/benchmarks/lazyvals/InitializedAccessAny.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package dotty.tools.benchmarks.lazyvals
22

3+
import compiletime.uninitialized
34
import org.openjdk.jmh.annotations._
45
import LazyVals.LazyAnyHolder
56
import org.openjdk.jmh.infra.Blackhole
@@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
1415
@State(Scope.Benchmark)
1516
class InitializedAccessAny {
1617

17-
var holder: LazyAnyHolder = _
18+
var holder: LazyAnyHolder = uninitialized
1819

1920
@Setup
2021
def prepare: Unit = {

bench-micro/src/main/scala/dotty/tools/benchmarks/lazyvals/InitializedAccessGeneric.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package dotty.tools.benchmarks.lazyvals
22

3+
import compiletime.uninitialized
34
import org.openjdk.jmh.annotations._
45
import LazyVals.LazyGenericHolder
56
import org.openjdk.jmh.infra.Blackhole
@@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
1415
@State(Scope.Benchmark)
1516
class InitializedAccessGeneric {
1617

17-
var holder: LazyGenericHolder[String] = _
18+
var holder: LazyGenericHolder[String] = uninitialized
1819

1920
@Setup
2021
def prepare: Unit = {

bench-micro/src/main/scala/dotty/tools/benchmarks/lazyvals/InitializedAccessInt.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package dotty.tools.benchmarks.lazyvals
22

3+
import compiletime.uninitialized
34
import org.openjdk.jmh.annotations.*
45
import org.openjdk.jmh.infra.Blackhole
56
import LazyVals.LazyIntHolder
@@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
1415
@State(Scope.Benchmark)
1516
class InitializedAccessInt {
1617

17-
var holder: LazyIntHolder = _
18+
var holder: LazyIntHolder = uninitialized
1819

1920
@Setup
2021
def prepare: Unit = {

bench-micro/src/main/scala/dotty/tools/benchmarks/lazyvals/InitializedAccessMultiple.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package dotty.tools.benchmarks.lazyvals
22

3+
import compiletime.uninitialized
34
import org.openjdk.jmh.annotations._
45
import LazyVals.LazyHolder
56
import org.openjdk.jmh.infra.Blackhole
@@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
1415
@State(Scope.Benchmark)
1516
class InitializedAccessMultiple {
1617

17-
var holders: Array[LazyHolder] = _
18+
var holders: Array[LazyHolder] = uninitialized
1819

1920
@Setup
2021
def prepare: Unit = {

bench-micro/src/main/scala/dotty/tools/benchmarks/lazyvals/InitializedAccessString.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package dotty.tools.benchmarks.lazyvals
22

3+
import compiletime.uninitialized
34
import org.openjdk.jmh.annotations._
45
import LazyVals.LazyStringHolder
56
import org.openjdk.jmh.infra.Blackhole
@@ -14,7 +15,7 @@ import java.util.concurrent.TimeUnit
1415
@State(Scope.Benchmark)
1516
class InitializedAccessString {
1617

17-
var holder: LazyStringHolder = _
18+
var holder: LazyStringHolder = uninitialized
1819

1920
@Setup
2021
def prepare: Unit = {
+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
class Test:
2+
class FlagAnnot extends annotation.StaticAnnotation
3+
class StringAnnot(val s: String) extends annotation.StaticAnnotation
4+
class LambdaAnnot(val f: Int => Boolean) extends annotation.StaticAnnotation
5+
6+
type SpecialInt <: Int
7+
8+
val v1: Int @FlagAnnot = 42
9+
10+
val v2: Int @StringAnnot("hello") = 42
11+
12+
val v3: Int @LambdaAnnot(it => it == 42) = 42
13+
14+
val v4: Int @LambdaAnnot(it => {
15+
def g(x: Int, y: Int) = x - y + 5
16+
g(it, 7) * 2 == 80
17+
}) = 42
18+
19+
/*val v5: Int @LambdaAnnot(it => {
20+
class Foo(x: Int):
21+
def xPlus10 = x + 10
22+
def xPlus20 = x + 20
23+
def xPlus(y: Int) = x + y
24+
val foo = Foo(it)
25+
foo.xPlus10 - foo.xPlus20 + foo.xPlus(30) == 62
26+
}) = 42*/
27+
28+
def main(args: Array[String]): Unit = ???

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

+9-1
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,17 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
135135
loop(tree, Nil)
136136

137137
/** All term arguments of an application in a single flattened list */
138+
def allTermArguments(tree: Tree): List[Tree] = unsplice(tree) match {
139+
case Apply(fn, args) => allArguments(fn) ::: args
140+
case TypeApply(fn, args) => allArguments(fn)
141+
case Block(_, expr) => allArguments(expr)
142+
case _ => Nil
143+
}
144+
145+
/** All type and term arguments of an application in a single flattened list */
138146
def allArguments(tree: Tree): List[Tree] = unsplice(tree) match {
139147
case Apply(fn, args) => allArguments(fn) ::: args
140-
case TypeApply(fn, _) => allArguments(fn)
148+
case TypeApply(fn, args) => allArguments(fn) ::: args
141149
case Block(_, expr) => allArguments(expr)
142150
case _ => Nil
143151
}

compiler/src/dotty/tools/dotc/core/Annotations.scala

+7-4
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ object Annotations {
3030
def derivedAnnotation(tree: Tree)(using Context): Annotation =
3131
if (tree eq this.tree) this else Annotation(tree)
3232

33-
/** All arguments to this annotation in a single flat list */
34-
def arguments(using Context): List[Tree] = tpd.allArguments(tree)
33+
/** All term arguments of this annotation in a single flat list */
34+
def arguments(using Context): List[Tree] = tpd.allTermArguments(tree)
3535

3636
def argument(i: Int)(using Context): Option[Tree] = {
3737
val args = arguments
@@ -54,15 +54,18 @@ object Annotations {
5454
* type, since ranges cannot be types of trees.
5555
*/
5656
def mapWith(tm: TypeMap)(using Context) =
57-
val args = arguments
57+
val args = tpd.allArguments(tree)
5858
if args.isEmpty then this
5959
else
60+
// Checks if `tm` would result in any change by applying it to types
61+
// inside the annotations' arguments and checking if the resulting types
62+
// are different.
6063
val findDiff = new TreeAccumulator[Type]:
6164
def apply(x: Type, tree: Tree)(using Context): Type =
6265
if tm.isRange(x) then x
6366
else
6467
val tp1 = tm(tree.tpe)
65-
foldOver(if tp1 frozen_=:= tree.tpe then x else tp1, tree)
68+
foldOver(if !tp1.exists || (tp1 frozen_=:= tree.tpe) then x else tp1, tree)
6669
val diff = findDiff(NoType, args)
6770
if tm.isRange(diff) then EmptyAnnotation
6871
else if diff.exists then derivedAnnotation(tm.mapOver(tree))

compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ object PositionPickler:
3333
pickler: TastyPickler,
3434
addrOfTree: TreeToAddr,
3535
treeAnnots: untpd.MemberDef => List[tpd.Tree],
36+
typeAnnots: List[tpd.Tree],
3637
relativePathReference: String,
3738
source: SourceFile,
3839
roots: List[Tree],
@@ -136,6 +137,9 @@ object PositionPickler:
136137
}
137138
for (root <- roots)
138139
traverse(root, NoSource)
140+
141+
for annotTree <- typeAnnots do
142+
traverse(annotTree, NoSource)
139143
end picklePositions
140144
end PositionPickler
141145

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

+7
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ class TreePickler(pickler: TastyPickler) {
3939
*/
4040
private val annotTrees = util.EqHashMap[untpd.MemberDef, mutable.ListBuffer[Tree]]()
4141

42+
/** A set of annotation trees appearing in annotated types.
43+
*/
44+
private val annotatedTypeTrees = mutable.ListBuffer[Tree]()
45+
4246
/** A map from member definitions to their doc comments, so that later
4347
* parallel comment pickling does not need to access symbols of trees (which
4448
* would involve accessing symbols of named types and possibly changing phases
@@ -52,6 +56,8 @@ class TreePickler(pickler: TastyPickler) {
5256
val ts = annotTrees.lookup(tree)
5357
if ts == null then Nil else ts.toList
5458

59+
def typeAnnots: List[Tree] = annotatedTypeTrees.toList
60+
5561
def docString(tree: untpd.MemberDef): Option[Comment] =
5662
Option(docStrings.lookup(tree))
5763

@@ -262,6 +268,7 @@ class TreePickler(pickler: TastyPickler) {
262268
case tpe: AnnotatedType =>
263269
writeByte(ANNOTATEDtype)
264270
withLength { pickleType(tpe.parent, richTypes); pickleTree(tpe.annot.tree) }
271+
annotatedTypeTrees += tpe.annot.tree
265272
case tpe: AndType =>
266273
writeByte(ANDtype)
267274
withLength { pickleType(tpe.tp1, richTypes); pickleType(tpe.tp2, richTypes) }

compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ object PickledQuotes {
223223
if tree.span.exists then
224224
val positionWarnings = new mutable.ListBuffer[Message]()
225225
val reference = ctx.settings.sourceroot.value
226-
PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference,
226+
PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference,
227227
ctx.compilationUnit.source, tree :: Nil, positionWarnings)
228228
positionWarnings.foreach(report.warning(_))
229229

compiler/src/dotty/tools/dotc/transform/Pickler.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ class Pickler extends Phase {
9999
if tree.span.exists then
100100
val reference = ctx.settings.sourceroot.value
101101
PositionPickler.picklePositions(
102-
pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference,
102+
pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference,
103103
unit.source, tree :: Nil, positionWarnings,
104104
scratch.positionBuffer, scratch.pickledIndices)
105105

compiler/test/dotty/tools/dotc/printing/PrintingTest.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import scala.language.unsafeNulls
66

77
import vulpix.FileDiff
88
import vulpix.TestConfiguration
9-
import vulpix.TestConfiguration
9+
import vulpix.ParallelTesting
1010
import reporting.TestReporter
1111

1212
import java.io._
@@ -25,7 +25,9 @@ import java.io.File
2525
class PrintingTest {
2626

2727
def options(phase: String, flags: List[String]) =
28-
List(s"-Xprint:$phase", "-color:never", "-nowarn", "-classpath", TestConfiguration.basicClasspath) ::: flags
28+
val outDir = ParallelTesting.defaultOutputDir + "printing" + File.pathSeparator
29+
File(outDir).mkdirs()
30+
List(s"-Xprint:$phase", "-color:never", "-nowarn", "-d", outDir, "-classpath", TestConfiguration.basicClasspath) ::: flags
2931

3032
private def compileFile(path: JPath, phase: String): Boolean = {
3133
val baseFilePath = path.toString.stripSuffix(".scala")

tests/pos/annot-17939b.scala

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import scala.annotation.Annotation
2+
class myRefined(f: ? => Boolean) extends Annotation
3+
4+
def test(axes: Int) = true
5+
6+
trait Tensor:
7+
def mean(axes: Int): Int @myRefined(_ => test(axes))
8+
9+
class TensorImpl() extends Tensor:
10+
def mean(axes: Int) = ???

tests/pos/annot-18064.scala

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//> using options "-Xprint:typer"
2+
3+
class myAnnot[T]() extends annotation.Annotation
4+
5+
trait Tensor[T]:
6+
def add: Tensor[T] @myAnnot[T]()
7+
8+
class TensorImpl[A]() extends Tensor[A]:
9+
def add /* : Tensor[A] @myAnnot[A] */ = this

tests/pos/annot-5789.scala

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class Annot[T] extends scala.annotation.Annotation
2+
3+
class D[T](val f: Int@Annot[T])
4+
5+
object A{
6+
def main(a:Array[String]) = {
7+
val c = new D[Int](1)
8+
c.f
9+
}
10+
}

0 commit comments

Comments
 (0)