Skip to content

Commit 253cc63

Browse files
committed
Rework dead state elimination to happen as part of expr builder
If we're in the expr position of a block, the nested state generation can reuse the successor state. Also, output .dot graph of state machine in verbose mode. Sample: https://gist.github.com/88225478b11c118609b9348d61e13630 View with a local Graphviz install or http://graphviz.it/#/gallery/unix.gv Sample generated from late expansion of: val result = run( """ import scala.async.run.late.{autoawait,lateasync} case class FixedFoo(foo: Int) class Foobar(val foo: Int, val bar: Double) { def guard: Boolean = true @autoawait @lateasync def getValue = 4.2 @autoawait @lateasync def func(f: Any) = { ("": Any) match { case (x1, y1) if guard => x1.toString; y1.toString case (x2, y2) if guard => x2.toString; y2.toString case (x3, y3) if guard => x3.toString; y3.toString case (x4, y4) => getValue; x4.toString; y4.toString } } } object Test { @lateasync def test() = new Foobar(0, 0).func(4) } """)
1 parent 337e8ec commit 253cc63

File tree

2 files changed

+97
-59
lines changed

2 files changed

+97
-59
lines changed

src/main/scala/scala/async/internal/AsyncTransform.scala

+7-4
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,6 @@ trait AsyncTransform {
7070
buildAsyncBlock(anfTree, symLookup)
7171
}
7272

73-
if(AsyncUtils.verbose)
74-
logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString))
75-
7673
val liftedFields: List[Tree] = liftables(asyncBlock.asyncStates)
7774

7875
// live variables analysis
@@ -114,10 +111,14 @@ trait AsyncTransform {
114111
futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }`
115112
else
116113
startStateMachine
114+
115+
if(AsyncUtils.verbose) {
116+
logDiagnostics(anfTree, asyncBlock, asyncBlock.asyncStates.map(_.toString))
117+
}
117118
cleanupContainsAwaitAttachments(result)
118119
}
119120

120-
def logDiagnostics(anfTree: Tree, states: Seq[String]): Unit = {
121+
def logDiagnostics(anfTree: Tree, block: AsyncBlock, states: Seq[String]): Unit = {
121122
def location = try {
122123
macroPos.source.path
123124
} catch {
@@ -129,6 +130,8 @@ trait AsyncTransform {
129130
AsyncUtils.vprintln(s"${c.macroApplication}")
130131
AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree")
131132
states foreach (s => AsyncUtils.vprintln(s))
133+
AsyncUtils.vprintln("===== DOT =====")
134+
AsyncUtils.vprintln(block.toDot)
132135
}
133136

134137
/**

src/main/scala/scala/async/internal/ExprBuilder.scala

+90-55
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,13 @@ trait ExprBuilder {
7676
mkHandlerCase(state, stats)
7777

7878
override val toString: String =
79-
s"AsyncStateWithoutAwait #$state, nextStates = $nextStates"
79+
s"AsyncStateWithoutAwait #$state, nextStates = ${nextStates.toList}"
8080
}
8181

8282
/** A sequence of statements that concludes with an `await` call. The `onComplete`
8383
* handler will unconditionally transition to `nextState`.
8484
*/
85-
final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, onCompleteState: Int, nextState: Int,
85+
final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, val onCompleteState: Int, nextState: Int,
8686
val awaitable: Awaitable, symLookup: SymLookup)
8787
extends AsyncState {
8888

@@ -268,11 +268,11 @@ trait ExprBuilder {
268268
}
269269

270270
// populate asyncStates
271-
def add(stat: Tree): Unit = stat match {
271+
def add(stat: Tree, afterState: Option[Int] = None): Unit = stat match {
272272
// the val name = await(..) pattern
273273
case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
274274
val onCompleteState = nextState()
275-
val afterAwaitState = nextState()
275+
val afterAwaitState = afterState.getOrElse(nextState())
276276
val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd)
277277
asyncStates += stateBuilder.resultWithAwait(awaitable, onCompleteState, afterAwaitState) // complete with await
278278
currState = afterAwaitState
@@ -283,7 +283,7 @@ trait ExprBuilder {
283283

284284
val thenStartState = nextState()
285285
val elseStartState = nextState()
286-
val afterIfState = nextState()
286+
val afterIfState = afterState.getOrElse(nextState())
287287

288288
asyncStates +=
289289
// the two Int arguments are the start state of the then branch and the else branch, respectively
@@ -305,7 +305,7 @@ trait ExprBuilder {
305305
java.util.Arrays.setAll(caseStates, new IntUnaryOperator {
306306
override def applyAsInt(operand: Int): Int = nextState()
307307
})
308-
val afterMatchState = nextState()
308+
val afterMatchState = afterState.getOrElse(nextState())
309309

310310
asyncStates +=
311311
stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup)
@@ -323,15 +323,16 @@ trait ExprBuilder {
323323
if containsAwait(rhs) || directlyAdjacentLabelDefs(ld).exists(containsAwait) =>
324324

325325
val startLabelState = stateIdForLabel(ld.symbol)
326-
val afterLabelState = nextState()
326+
val afterLabelState = afterState.getOrElse(nextState())
327327
asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup)
328328
labelDefStates(ld.symbol) = startLabelState
329329
val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState)
330330
asyncStates ++= builder.asyncStates
331331
currState = afterLabelState
332332
stateBuilder = new AsyncStateBuilder(currState, symLookup)
333333
case b @ Block(stats, expr) =>
334-
(stats :+ expr) foreach (add)
334+
for (stat <- stats) add(stat)
335+
add(expr, afterState = Some(endState))
335336
case _ =>
336337
checkForUnsupportedAwait(stat)
337338
stateBuilder += stat
@@ -345,6 +346,8 @@ trait ExprBuilder {
345346
def asyncStates: List[AsyncState]
346347

347348
def onCompleteHandler[T: WeakTypeTag]: Tree
349+
350+
def toDot: String
348351
}
349352

350353
case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) {
@@ -369,7 +372,78 @@ trait ExprBuilder {
369372
val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup)
370373

371374
new AsyncBlock {
372-
def asyncStates = blockBuilder.asyncStates.toList
375+
val switchIds = mutable.AnyRefMap[Integer, Integer]()
376+
377+
// render with http://graphviz.it/#/new
378+
def toDot: String = {
379+
val states = asyncStates
380+
def toHtmlLabel(label: String, preText: String, builder: StringBuilder): Unit = {
381+
builder.append("<b>").append(label).append("</b>").append("<br/>")
382+
builder.append("<font face=\"Courier\">")
383+
preText.split("\n").foreach {
384+
(line: String) =>
385+
builder.append("<br/>")
386+
builder.append(line.replaceAllLiterally("\"", "&quot;").replaceAllLiterally("<", "&lt;").replaceAllLiterally(">", "&gt;"))
387+
}
388+
builder.append("</font>")
389+
}
390+
val dotBuilder = new StringBuilder()
391+
dotBuilder.append("digraph {\n")
392+
def stateLabel(s: Int) = {
393+
if (s == 0) "INITIAL" else if (s == Int.MaxValue) "TERMINAL" else switchIds.getOrElse(s, s).toString
394+
}
395+
val length = asyncStates.size
396+
for ((state, i) <- asyncStates.zipWithIndex) {
397+
dotBuilder.append(s"""${stateLabel(state.state)} [label=""").append("<")
398+
if (i != length - 1) {
399+
val CaseDef(_, _, body) = state.mkHandlerCaseForState
400+
toHtmlLabel(stateLabel(state.state), showCode(body), dotBuilder)
401+
} else {
402+
toHtmlLabel(stateLabel(state.state), state.allStats.map(showCode(_)).mkString("\n"), dotBuilder)
403+
}
404+
dotBuilder.append("> ]\n")
405+
}
406+
for (state <- states; succ <- state.nextStates) {
407+
dotBuilder.append(s"""${stateLabel(state.state)} -> ${stateLabel(succ)}""")
408+
dotBuilder.append("\n")
409+
}
410+
dotBuilder.append("}\n")
411+
dotBuilder.toString
412+
}
413+
414+
lazy val asyncStates: List[AsyncState] = filterStates
415+
416+
def filterStates = {
417+
val all = blockBuilder.asyncStates.toList
418+
val (initial :: rest) = all
419+
val map = all.iterator.map(x => (x.state, x)).toMap
420+
var seen = mutable.HashSet[Int]()
421+
def loop(state: AsyncState): Unit = {
422+
seen.add(state.state)
423+
for (i <- state.nextStates) {
424+
if (i != Int.MaxValue && !seen.contains(i)) {
425+
loop(map(i))
426+
}
427+
}
428+
}
429+
loop(initial)
430+
val live = rest.filter(state => seen(state.state))
431+
var nextSwitchId = 0
432+
(initial :: live).foreach { state =>
433+
val switchId = nextSwitchId
434+
switchIds(state.state) = switchId
435+
nextSwitchId += 1
436+
state match {
437+
case state: AsyncStateWithAwait =>
438+
val switchId = nextSwitchId
439+
switchIds(state.onCompleteState) = switchId
440+
nextSwitchId += 1
441+
case _ =>
442+
}
443+
}
444+
initial :: live
445+
446+
}
373447

374448
def mkCombinedHandlerCases[T: WeakTypeTag]: List[CaseDef] = {
375449
val caseForLastState: CaseDef = {
@@ -413,7 +487,7 @@ trait ExprBuilder {
413487
val stateMemberSymbol = symLookup.stateMachineMember(name.state)
414488
val stateMemberRef = symLookup.memberRef(name.state)
415489
val body = Match(stateMemberRef, mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ++ List(CaseDef(Ident(nme.WILDCARD), EmptyTree, Throw(Apply(Select(New(Ident(defn.IllegalStateExceptionClass)), termNames.CONSTRUCTOR), List())))))
416-
val body1 = eliminateDeadStates(body)
490+
val body1 = compactStates(body)
417491

418492
maybeTry(
419493
body1,
@@ -432,48 +506,12 @@ trait ExprBuilder {
432506
})), EmptyTree)
433507
}
434508

435-
// Identify dead states: `case <id> => { state = nextId; (); (); ... }, eliminated, and compact state ids to
436-
// enable emission of a tableswitch.
437-
private def eliminateDeadStates(m: Match): Tree = {
438-
object DeadState {
439-
private val liveStates = mutable.AnyRefMap[Integer, Integer]()
440-
private val deadStates = mutable.AnyRefMap[Integer, Integer]()
441-
private var compactedStateId = 1
442-
for (CaseDef(Literal(Constant(stateId: Integer)), EmptyTree, body) <- m.cases) {
443-
body match {
444-
case _ if (stateId == 0) => liveStates(stateId) = stateId
445-
case Block(Assign(_, Literal(Constant(nextState: Integer))) :: rest, expr) if (expr :: rest).forall(t => isLiteralUnit(t)) =>
446-
deadStates(stateId) = nextState
447-
case _ =>
448-
liveStates(stateId) = compactedStateId
449-
compactedStateId += 1
450-
}
451-
}
452-
if (deadStates.nonEmpty)
453-
AsyncUtils.vprintln(s"${deadStates.size} dead states eliminated")
454-
def isDead(i: Integer) = deadStates.contains(i)
455-
def translatedStateId(i: Integer, tree: Tree): Integer = {
456-
def chaseDead(i: Integer): Integer = {
457-
val replacement = deadStates.getOrNull(i)
458-
if (replacement == null) i
459-
else chaseDead(replacement)
460-
}
461-
462-
val live = chaseDead(i)
463-
liveStates.get(live) match {
464-
case Some(x) => x
465-
case None => sys.error(s"$live, $liveStates \n$deadStates\n$m\n\n====\n$tree")
466-
}
467-
}
468-
}
509+
private def compactStates(m: Match): Tree = {
469510
val stateMemberSymbol = symLookup.stateMachineMember(name.state)
470-
// - remove CaseDef-s for dead states
471-
// - rewrite state transitions to dead states to instead transition to the
472-
// non-dead successor.
473-
val elimDeadStateTransform = new Transformer {
511+
val compactStateTransform = new Transformer {
474512
override def transform(tree: Tree): Tree = tree match {
475513
case as @ Assign(lhs, Literal(Constant(i: Integer))) if lhs.symbol == stateMemberSymbol =>
476-
val replacement = DeadState.translatedStateId(i, as)
514+
val replacement = switchIds(i)
477515
treeCopy.Assign(tree, lhs, Literal(Constant(replacement)))
478516
case _: Match | _: CaseDef | _: Block | _: If =>
479517
super.transform(tree)
@@ -482,12 +520,9 @@ trait ExprBuilder {
482520
}
483521
val cases1 = m.cases.flatMap {
484522
case cd @ CaseDef(Literal(Constant(i: Integer)), EmptyTree, rhs) =>
485-
if (DeadState.isDead(i)) Nil
486-
else {
487-
val replacement = DeadState.translatedStateId(i, cd)
488-
val rhs1 = elimDeadStateTransform.transform(rhs)
489-
treeCopy.CaseDef(cd, Literal(Constant(replacement)), EmptyTree, rhs1) :: Nil
490-
}
523+
val replacement = switchIds(i)
524+
val rhs1 = compactStateTransform.transform(rhs)
525+
treeCopy.CaseDef(cd, Literal(Constant(replacement)), EmptyTree, rhs1) :: Nil
491526
case x => x :: Nil
492527
}
493528
treeCopy.Match(m, m.selector, cases1)

0 commit comments

Comments
 (0)