Skip to content

Fix #2903: Reduce the depth of trees generated in PatternMatcher #3575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 23, 2017

Conversation

nicolasstucki
Copy link
Contributor

  • Extract all match arguments before checking conditions on them like scalac does.
    This avoids an extra nested block for each match variable.
  • Merge conditions of nested if expressions if their else branch is the same.
    This optimization combined with the previous removes most of the nested ifs
    created to check the matched args.

@nicolasstucki
Copy link
Contributor Author

nicolasstucki commented Nov 28, 2017

This is the code that goes into the backend.

Current dotc

 def test(x: Object): Unit = 
      {
        case val x117: Object = x
        if x117.isInstanceOf[Foo01] then 
          {
            case val x154: Foo01 = x117.asInstanceOf[Foo01]
            case val x155: Foo01 = Foo01.unapply(x154)
            case val x156: Int = x155._1()
            if 1.==(x156) then 
              {
                case val x157: Int = x155._2()
                if 2.==(x157) then 
                  {
                    case val x158: Int = x155._3()
                    if 3.==(x158) then 
                      {
                        case val x159: Int = x155._4()
                        if 4.==(x159) then 
                          {
                            case val x160: Int = x155._5()
                            if 5.==(x160) then 
                              {
                                case val x161: Int = x155._6()
                                if 6.==(x161) then 
                                  {
                                    case val x162: Int = x155._7()
                                    if 7.==(x162) then 
                                      {
                                        case val x163: Int = x155._8()
                                        if 8.==(x163) then 
                                          {
                                            case val x164: Int = x155._9()
                                            if 9.==(x164) then 
                                              {
                                                case val x165: Int = x155._10()
                                                if 10.==(x165) then 
                                                  {
                                                    Test.stuff()
                                                  }
                                                 else {
                                                    <label> def case381() =  .... // Here are all other cases nested
                                                    case381()
                                                  }
                                              }
                                             else case381()
                                          }
                                         else case381()
                                      }
                                     else case381()
                                  }
                                 else case381()
                              }
                             else case381()
                          }
                         else case381()
                      }
                     else case381()
                  }
                 else case381()
              }
             else case381()
          }
         else case381()
      }
  }
def test(x: Object): Unit = {
      case <synthetic> val x1: Object = x;
      case615(){
        if (x1.$isInstanceOf[Foo01]())
          {
            <synthetic> val x2: Foo01 = (x1.$asInstanceOf[Foo01](): Foo01);
            {
              <synthetic> val p3: Int = x2.x1();
              <synthetic> val p4: Int = x2.x2();
              <synthetic> val p5: Int = x2.x3();
              <synthetic> val p6: Int = x2.x4();
              <synthetic> val p7: Int = x2.x5();
              <synthetic> val p8: Int = x2.x6();
              <synthetic> val p9: Int = x2.x7();
              <synthetic> val p10: Int = x2.x8();
              <synthetic> val p11: Int = x2.x9();
              <synthetic> val p12: Int = x2.x10();
                  if (1.==(p3))
                        if (2.==(p4))
                          if (3.==(p5))
                            if (4.==(p6))
                              if (5.==(p7))
                                if (6.==(p8))
                                  if (7.==(p9))
                                    if (8.==(p10))
                                      if (9.==(p11))
                                        if (10.==(p12))
                                          matchEnd614({
                                            Test.this.stuff();
                                            scala.runtime.BoxedUnit.UNIT
                                          })
                                        else
                                          case616()
                                      else
                                        case616()
                                    else
                                      case616()
                                  else
                                    case616()
                                else
                                  case616()
                              else
                                case616()
                            else
                              case616()
                          else
                            case616()
                        else
                          case616()
                  else
                    case616()
          }
        else
          case616()
      };
      case616(){ // Other cases follow below
      .........

dotc with this change

def test(x: Object): Unit = 
      {
        case val x161: Object = x
        if x161.isInstanceOf[Foo1] then 
          {
            case val x162: Foo1 = x161.asInstanceOf[Foo1]
            case val x163: Foo1 = Foo1.unapply(x162)
            case val x164: Int = x163._1()
            case val x165: Int = x163._2()
            case val x166: Int = x163._3()
            case val x167: Int = x163._4()
            case val x168: Int = x163._5()
            case val x169: Int = x163._6()
            case val x170: Int = x163._7()
            case val x171: Int = x163._8()
            case val x172: Int = x163._9()
            case val x173: Int = x163._10()
            if 
              1.==(x164).&&(2.==(x165)).&&(3.==(x166)).&&(4.==(x167)).&&(
                5.==(x168)
              ).&&(6.==(x169)).&&(7.==(x170)).&&(8.==(x171)).&&(9.==(x172)).&&(
                10.==(x173)
              )
             then 
              {
                Test.stuff()
              }
             else 
              {
                <label> def case521(): Unit = ... // other cases go in here
                case521()
              }
          }
         else case521()
      }

Note that it labels may still be deeply nested case521. The depth of this tree will be proportional to the number of cases as long as ifs can be merged. Currently, nested unapplies will still generate nested case vals which cannot be fully optimised by this. For example Foo(A(1), A(2), A(3), ...) will generate a similar nesting as before.

@nicolasstucki
Copy link
Contributor Author

test performance please

@dottybot
Copy link
Member

performance test scheduled: 1 job(s) in queue, 1 running.

@dottybot
Copy link
Member

Performance test finished successfully:

Visit http://dotty-bench.epfl.ch/3575/ to see the changes.

Benchmarks is based on merging with master (2aa5950)

case selector :: selectors1 => letAbstract(selector)(sym => matchArgsSelectorsPlan(selectors1, sym :: syms))
case Nil => matchArgsPaternPlan(args, syms.reverse)
}
def matchArgsPaternPlan(args: List[Tree], syms: List[Symbol]): Plan =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: Patern -> Pattern

case plan2: TestPlan if plan.onFailure == plan2.onFailure =>
emmitWithMashedConditions(plan2 :: plans)
case _ =>
def emmitCondWithPos(plan: TestPlan) = emitCondition(plan).withPos(plan.pos)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: emmit -> emit

@nicolasstucki
Copy link
Contributor Author

test performance please

@dottybot
Copy link
Member

performance test scheduled: 1 job(s) in queue, 0 running.

@dottybot
Copy link
Member

performance test scheduled: 1 job(s) in queue, 1 running.

@dottybot
Copy link
Member

Performance test finished successfully:

Visit http://dotty-bench.epfl.ch/3575/ to see the changes.

Benchmarks is based on merging with master (52279cb)

@dottybot
Copy link
Member

Performance test finished successfully:

Visit http://dotty-bench.epfl.ch/3575/ to see the changes.

Benchmarks is based on merging with master (6ade9be)

def matchArgsPatternPlan(args: List[Tree], syms: List[Symbol]): Plan =
((args, syms): @unchecked) match { // both lists should always have the same size
case (arg :: args1, sym :: syms1) => patternPlan(sym, arg, matchArgsPatternPlan(args1, syms1), onFailure)
case (Nil, Nil) => onSuccess
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you can speed this up by doing like it was done initially:

args match {
  case arg :: args1 =>
    val sym :: sym = syms
    patternPlan(sym, arg, matchArgsPatternPlan(args1, syms1), onFailure)
  case Nil => onSuccess
}

You don't need to construct and deconstruct a tuple anymore

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would need to add an assertion in the Nil case to ensure failure. At some point, we should make that optimization work, there is no good reason for it to create the in those cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually with -optimize we already remove the deconstruction of the tuple. It currently becomes:

        case val x1: Tuple2 = new Tuple2(xs, ys)
        if 
          xs.isInstanceOf[scala.collection.immutable.::].&&(
            ys.isInstanceOf[scala.collection.immutable.::]
          )
         then 
          patternPlan(sym, arg, matchArgsPatternPlan(args1, syms1), onFailure)
         else 
          {
             def case2(case x16: scala.collection.immutable.List): String = 
              if Nil().==(x16).&&(Nil().==(ys)) then onSuccess else 
                {
                   def case1(): String = throw new MatchError(x1)
                  case1()
                }
            case2(xs)
          }

There is still a spurious new Tuple2.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if the -optimise option removed it, this is not something we want to enable by default since it adds a significant compilation overhead

case Nil => onSuccess
}
def matchArgsPlan(selectors: List[Tree], args: List[Tree], onSuccess: Plan): Plan = {
def matchArgsSelectorsPlan(selectors: List[Tree], syms: List[Symbol]): Plan =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document this, ideally giving an example.

else {
/** Merge nested `if`s that have the same `else` branch into a single `if`.
* This optimization targets calls to label defs for case failure jumps to next case.
*/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you give an example here as well?

* Extract all match arguments before checking conditions on them like scalac does.
  This avoids an extra nested block for each match variable.
* Merge conditions of nested `if` expressions if their `else` branch is the same.
  This optimization combined with the previous removes most of the nested `if`s
  created to check the matched args.
@nicolasstucki
Copy link
Contributor Author

@odersky. Added examples in the docs.

@nicolasstucki nicolasstucki requested review from odersky and removed request for smarter December 22, 2017 15:11
@nicolasstucki nicolasstucki assigned odersky and unassigned smarter Dec 22, 2017
@odersky odersky merged commit 64021f9 into scala:master Dec 23, 2017
@allanrenucci allanrenucci deleted the fix-#2903 branch December 23, 2017 18:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants