Skip to content

Commit 82232ec

Browse files
committed
An overdue overhaul of macro internals.
- Avoid reset + retypecheck, instead hang onto the original types/symbols - Eliminated duplication between AsyncDefinitionUseAnalyzer and ExprBuilder - Instead, decide what do lift *after* running ExprBuilder - Account for transitive references local classes/objects and lift them as needed. - Make the execution context an regular implicit parameter of the macro - Fixes interaction with existential skolems and singleton types Fixes scala#6, scala#13, scala#16, scala#17, scala#19, scala#21.
1 parent d63b63f commit 82232ec

18 files changed

+1021
-998
lines changed

src/main/scala/scala/async/AnfTransform.scala

Lines changed: 213 additions & 237 deletions
Large diffs are not rendered by default.

src/main/scala/scala/async/Async.scala

Lines changed: 24 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ package scala.async
77
import scala.language.experimental.macros
88
import scala.reflect.macros.Context
99
import scala.reflect.internal.annotations.compileTimeOnly
10+
import scala.tools.nsc.Global
11+
import language.reflectiveCalls
12+
import scala.concurrent.ExecutionContext
1013

1114
object Async extends AsyncBase {
1215

@@ -15,18 +18,22 @@ object Async extends AsyncBase {
1518
lazy val futureSystem = ScalaConcurrentFutureSystem
1619
type FS = ScalaConcurrentFutureSystem.type
1720

18-
def async[T](body: T) = macro asyncImpl[T]
21+
def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T]
1922

20-
override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)
23+
override def asyncImpl[T: c.WeakTypeTag](c: Context)
24+
(body: c.Expr[T])
25+
(execContext: c.Expr[futureSystem.ExecContext]): c.Expr[Future[T]] = {
26+
super.asyncImpl[T](c)(body)(execContext)
27+
}
2128
}
2229

2330
object AsyncId extends AsyncBase {
2431
lazy val futureSystem = IdentityFutureSystem
2532
type FS = IdentityFutureSystem.type
2633

27-
def async[T](body: T) = macro asyncImpl[T]
34+
def async[T](body: T) = macro asyncIdImpl[T]
2835

29-
override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = super.asyncImpl[T](c)(body)
36+
def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit)
3037
}
3138

3239
/**
@@ -62,124 +69,26 @@ abstract class AsyncBase {
6269

6370
protected[async] def fallbackEnabled = false
6471

65-
def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
72+
def asyncImpl[T: c.WeakTypeTag](c: Context)
73+
(body: c.Expr[T])
74+
(execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
6675
import c.universe._
6776

68-
val analyzer = AsyncAnalysis[c.type](c, this)
69-
val utils = TransformUtils[c.type](c)
70-
import utils.{name, defn}
71-
72-
analyzer.reportUnsupportedAwaits(body.tree)
73-
74-
// Transform to A-normal form:
75-
// - no await calls in qualifiers or arguments,
76-
// - if/match only used in statement position.
77-
val anfTree: Block = {
78-
val anf = AnfTransform[c.type](c)
79-
val restored = utils.restorePatternMatchingFunctions(body.tree)
80-
val stats1 :+ expr1 = anf(restored)
81-
val block = Block(stats1, expr1)
82-
c.typeCheck(block).asInstanceOf[Block]
83-
}
84-
85-
// Analyze the block to find locals that will be accessed from multiple
86-
// states of our generated state machine, e.g. a value assigned before
87-
// an `await` and read afterwards.
88-
val renameMap: Map[Symbol, TermName] = {
89-
analyzer.defTreesUsedInSubsequentStates(anfTree).map {
90-
vd =>
91-
(vd.symbol, name.fresh(vd.name.toTermName))
92-
}.toMap
93-
}
94-
95-
val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem, anfTree)
96-
import builder.futureSystemOps
97-
val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap)
98-
import asyncBlock.asyncStates
99-
logDiagnostics(c)(anfTree, asyncStates.map(_.toString))
100-
101-
// Important to retain the original declaration order here!
102-
val localVarTrees = anfTree.collect {
103-
case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol =>
104-
utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol))
105-
case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) if renameMap contains dd.symbol =>
106-
DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap)))
107-
}
108-
109-
val onCompleteHandler = {
110-
Function(
111-
List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)),
112-
asyncBlock.onCompleteHandler)
113-
}
114-
val resumeFunTree = asyncBlock.resumeFunTree[T]
115-
116-
val stateMachineType = utils.applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType))
117-
118-
lazy val stateMachine: ClassDef = {
119-
val body: List[Tree] = {
120-
val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0)))
121-
val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree)
122-
val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree)
123-
val applyDefDef: DefDef = {
124-
val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)))
125-
val applyBody = asyncBlock.onCompleteHandler
126-
DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), applyBody)
127-
}
128-
val apply0DefDef: DefDef = {
129-
// We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
130-
// See SI-1247 for the the optimization that avoids creatio
131-
val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)))
132-
val applyBody = asyncBlock.onCompleteHandler
133-
DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil))
134-
}
135-
List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef)
136-
}
137-
val template = {
138-
Template(List(stateMachineType), emptyValDef, body)
139-
}
140-
ClassDef(NoMods, name.stateMachineT, Nil, template)
141-
}
142-
143-
def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection)
144-
145-
val code: c.Expr[futureSystem.Fut[T]] = {
146-
val isSimple = asyncStates.size == 1
147-
val tree =
148-
if (isSimple)
149-
Block(Nil, futureSystemOps.spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }`
150-
else {
151-
Block(List[Tree](
152-
stateMachine,
153-
ValDef(NoMods, name.stateMachine, stateMachineType, Apply(Select(New(Ident(name.stateMachineT)), nme.CONSTRUCTOR), Nil)),
154-
futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil))
155-
),
156-
futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree)
157-
}
158-
c.Expr[futureSystem.Fut[T]](tree)
159-
}
160-
161-
AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}")
162-
code
163-
}
77+
val asyncMacro = AsyncMacro(c, futureSystem)
78+
79+
val code = asyncMacro.asyncTransform[T](
80+
body.tree.asInstanceOf[asyncMacro.global.Tree],
81+
execContext.tree.asInstanceOf[asyncMacro.global.Tree],
82+
fallbackEnabled)(implicitly[c.WeakTypeTag[T]].asInstanceOf[asyncMacro.global.WeakTypeTag[T]])
16483

165-
def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) {
166-
def location = try {
167-
c.macroApplication.pos.source.path
168-
} catch {
169-
case _: UnsupportedOperationException =>
170-
c.macroApplication.pos.toString
171-
}
172-
173-
AsyncUtils.vprintln(s"In file '$location':")
174-
AsyncUtils.vprintln(s"${c.macroApplication}")
175-
AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree")
176-
states foreach (s => AsyncUtils.vprintln(s))
84+
AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}")
85+
c.Expr[futureSystem.Fut[T]](code.asInstanceOf[Tree])
17786
}
17887
}
17988

18089
/** Internal class used by the `async` macro; should not be manually extended by client code */
18190
abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) {
182-
def result$async: Result
91+
def result: Result
18392

184-
def execContext$async: EC
93+
def execContext: EC
18594
}

src/main/scala/scala/async/AsyncAnalysis.scala

Lines changed: 16 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -7,60 +7,37 @@ package scala.async
77
import scala.reflect.macros.Context
88
import scala.collection.mutable
99

10-
private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: AsyncBase) {
11-
import c.universe._
10+
trait AsyncAnalysis {
11+
self: AsyncMacro =>
1212

13-
val utils = TransformUtils[c.type](c)
14-
15-
import utils._
13+
import global._
1614

1715
/**
1816
* Analyze the contents of an `async` block in order to:
1917
* - Report unsupported `await` calls under nested templates, functions, by-name arguments.
2018
*
2119
* Must be called on the original tree, not on the ANF transformed tree.
2220
*/
23-
def reportUnsupportedAwaits(tree: Tree): Boolean = {
24-
val analyzer = new UnsupportedAwaitAnalyzer
21+
def reportUnsupportedAwaits(tree: Tree, report: Boolean): Boolean = {
22+
val analyzer = new UnsupportedAwaitAnalyzer(report)
2523
analyzer.traverse(tree)
2624
analyzer.hasUnsupportedAwaits
2725
}
2826

29-
/**
30-
* Analyze the contents of an `async` block in order to:
31-
* - Find which local `ValDef`-s need to be lifted to fields of the state machine, based
32-
* on whether or not they are accessed only from a single state.
33-
*
34-
* Must be called on the ANF transformed tree.
35-
*/
36-
def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = {
37-
val analyzer = new AsyncDefinitionUseAnalyzer
38-
analyzer.traverse(tree)
39-
val liftable: List[DefTree] = (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct
40-
liftable
41-
}
42-
43-
private class UnsupportedAwaitAnalyzer extends AsyncTraverser {
27+
private class UnsupportedAwaitAnalyzer(report: Boolean) extends AsyncTraverser {
4428
var hasUnsupportedAwaits = false
4529

4630
override def nestedClass(classDef: ClassDef) {
47-
val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class"
48-
if (!reportUnsupportedAwait(classDef, s"nested $kind")) {
49-
// do not allow local class definitions, because of SI-5467 (specific to case classes, though)
50-
if (classDef.symbol.asClass.isCaseClass)
51-
c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block")
52-
}
31+
val kind = if (classDef.symbol.isTrait) "trait" else "class"
32+
reportUnsupportedAwait(classDef, s"nested ${kind}")
5333
}
5434

5535
override def nestedModule(module: ModuleDef) {
56-
if (!reportUnsupportedAwait(module, "nested object")) {
57-
// local object definitions lead to spurious type errors (because of resetAllAttrs?)
58-
c.error(module.pos, s"Local object ${module.name.decoded} illegal within `async` block")
59-
}
36+
reportUnsupportedAwait(module, "nested object")
6037
}
6138

62-
override def nestedMethod(module: DefDef) {
63-
reportUnsupportedAwait(module, "nested method")
39+
override def nestedMethod(defDef: DefDef) {
40+
reportUnsupportedAwait(defDef, "nested method")
6441
}
6542

6643
override def byNameArgument(arg: Tree) {
@@ -82,9 +59,10 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: Asy
8259
reportUnsupportedAwait(tree, "try/catch")
8360
super.traverse(tree)
8461
case Return(_) =>
85-
c.abort(tree.pos, "return is illegal within a async block")
62+
abort(tree.pos, "return is illegal within a async block")
8663
case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) =>
87-
c.abort(tree.pos, "lazy vals are illegal within an async block")
64+
// TODO lift this restriction
65+
abort(tree.pos, "lazy vals are illegal within an async block")
8866
case _ =>
8967
super.traverse(tree)
9068
}
@@ -106,87 +84,8 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: Asy
10684

10785
private def reportError(pos: Position, msg: String) {
10886
hasUnsupportedAwaits = true
109-
if (!asyncBase.fallbackEnabled)
110-
c.error(pos, msg)
87+
if (report)
88+
abort(pos, msg)
11189
}
11290
}
113-
114-
private class AsyncDefinitionUseAnalyzer extends AsyncTraverser {
115-
private var chunkId = 0
116-
117-
private def nextChunk() = chunkId += 1
118-
119-
private var valDefChunkId = Map[Symbol, (ValDef, Int)]()
120-
121-
val valDefsToLift : mutable.Set[ValDef] = collection.mutable.Set()
122-
val nestedMethodsToLift: mutable.Set[DefDef] = collection.mutable.Set()
123-
124-
override def nestedMethod(defDef: DefDef) {
125-
nestedMethodsToLift += defDef
126-
markReferencedVals(defDef)
127-
}
128-
129-
override def function(function: Function) {
130-
markReferencedVals(function)
131-
}
132-
133-
override def patMatFunction(tree: Match) {
134-
markReferencedVals(tree)
135-
}
136-
137-
private def markReferencedVals(tree: Tree) {
138-
tree foreach {
139-
case rt: RefTree =>
140-
valDefChunkId.get(rt.symbol) match {
141-
case Some((vd, defChunkId)) =>
142-
valDefsToLift += vd // lift all vals referred to by nested functions.
143-
case _ =>
144-
}
145-
case _ =>
146-
}
147-
}
148-
149-
override def traverse(tree: Tree) = {
150-
tree match {
151-
case If(cond, thenp, elsep) if tree exists isAwait =>
152-
traverseChunks(List(cond, thenp, elsep))
153-
case Match(selector, cases) if tree exists isAwait =>
154-
traverseChunks(selector :: cases)
155-
case LabelDef(name, params, rhs) if rhs exists isAwait =>
156-
traverseChunks(rhs :: Nil)
157-
case Apply(fun, args) if isAwait(fun) =>
158-
super.traverse(tree)
159-
nextChunk()
160-
case vd: ValDef =>
161-
super.traverse(tree)
162-
valDefChunkId += (vd.symbol -> (vd -> chunkId))
163-
val isPatternBinder = vd.name.toString.contains(name.bindSuffix)
164-
if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd
165-
case as: Assign =>
166-
if (isAwait(as.rhs)) {
167-
assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol)
168-
169-
// TODO test the orElse case, try to remove the restriction.
170-
val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}"))
171-
valDefsToLift += vd
172-
}
173-
super.traverse(tree)
174-
case rt: RefTree =>
175-
valDefChunkId.get(rt.symbol) match {
176-
case Some((vd, defChunkId)) if defChunkId != chunkId =>
177-
valDefsToLift += vd
178-
case _ =>
179-
}
180-
super.traverse(tree)
181-
case _ => super.traverse(tree)
182-
}
183-
}
184-
185-
private def traverseChunks(trees: List[Tree]) {
186-
trees.foreach {
187-
t => traverse(t); nextChunk()
188-
}
189-
}
190-
}
191-
19291
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package scala.async
2+
3+
import scala.tools.nsc.Global
4+
import scala.tools.nsc.transform.TypingTransformers
5+
6+
object AsyncMacro {
7+
def apply(c: reflect.macros.Context, futureSystem0: FutureSystem): AsyncMacro = {
8+
import language.reflectiveCalls
9+
val powerContext = c.asInstanceOf[c.type {val universe: Global; val callsiteTyper: universe.analyzer.Typer}]
10+
new AsyncMacro {
11+
val global: powerContext.universe.type = powerContext.universe
12+
val callSiteTyper: global.analyzer.Typer = powerContext.callsiteTyper
13+
val futureSystem: futureSystem0.type = futureSystem0
14+
val futureSystemOps: futureSystem.Ops {val universe: global.type} = futureSystem0.mkOps(global)
15+
val macroApplication: global.Tree = c.macroApplication.asInstanceOf[global.Tree]
16+
}
17+
}
18+
}
19+
20+
private[async] trait AsyncMacro
21+
extends TypingTransformers
22+
with AnfTransform with TransformUtils with Lifter
23+
with ExprBuilder with AsyncTransform with AsyncAnalysis {
24+
25+
val global: Global
26+
val callSiteTyper: global.analyzer.Typer
27+
val macroApplication: global.Tree
28+
29+
}

0 commit comments

Comments
 (0)