Skip to content

Commit a4c771c

Browse files
committed
Add a Var abstraction in QuoteUtils
1 parent 0ebbcff commit a4c771c

File tree

4 files changed

+72
-69
lines changed

4 files changed

+72
-69
lines changed

compiler/test/dotty/tools/dotc/BootstrappedOnlyCompilationTests.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,7 @@ class BootstrappedOnlyCompilationTests extends ParallelTesting {
120120
implicit val testGroup: TestGroup = TestGroup("runWithCompiler")
121121
aggregateTests(
122122
compileFilesInDir("tests/run-with-compiler", withCompilerOptions),
123-
compileDir("tests/run-with-compiler-custom-args/tasty-interpreter", withCompilerOptions),
124-
compileFile("tests/run-with-compiler-custom-args/staged-streams_1.scala", withCompilerOptions without "-Yno-deep-subtypes")
123+
compileDir("tests/run-with-compiler-custom-args/tasty-interpreter", withCompilerOptions)
125124
).checkRuns()
126125
}
127126

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package scala.quoted
2+
package util
3+
4+
/** An abstraction for variable definition to use in a quoted program.
5+
* It decouples the operations of get and update, if needed to be spliced separately.
6+
*/
7+
sealed trait Var[T] {
8+
def get given QuoteContext: Expr[T]
9+
def update(x: Expr[T]) given QuoteContext: Expr[Unit]
10+
}
11+
12+
object Var {
13+
def apply[T: Type, U: Type](init: Expr[T])(body: Var[T] => Expr[U]) given QuoteContext: Expr[U] = '{
14+
var x = $init
15+
${
16+
body(
17+
new Var[T] {
18+
def get given QuoteContext: Expr[T] = 'x
19+
def update(e: Expr[T]) given QuoteContext: Expr[Unit] = '{ x = $e }
20+
}
21+
)
22+
}
23+
}
24+
}

tests/run-with-compiler-custom-args/staged-streams_1.scala renamed to tests/run-with-compiler/staged-streams_1.scala

Lines changed: 47 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,13 @@
11
import scala.quoted._
2+
import scala.quoted.util._
23
import given scala.quoted.autolift._
34

45
/**
56
* Port of the strymonas library as described in O. Kiselyov et al., Stream fusion, to completeness (POPL 2017)
67
*/
7-
88
object Test {
99

10-
// TODO: remove as it exists in Quoted Lib
11-
sealed trait Var[T] {
12-
def get given QuoteContext: Expr[T]
13-
def update(x: Expr[T]) given QuoteContext: Expr[Unit]
14-
}
15-
16-
object Var {
17-
def apply[T: Type, U: Type](init: Expr[T])(body: Var[T] => Expr[U]) given QuoteContext: Expr[U] = '{
18-
var x = $init
19-
${
20-
body(
21-
new Var[T] {
22-
def get given QuoteContext: Expr[T] = 'x
23-
def update(e: Expr[T]) given QuoteContext: Expr[Unit] = '{ x = $e }
24-
}
25-
)
26-
}
27-
}
28-
}
10+
type E[T] = given QuoteContext => Expr[T]
2911

3012
/*** Producer represents a linear production of values with a loop structure.
3113
*
@@ -61,27 +43,29 @@ object Test {
6143
* @param k the continuation that is invoked after the new state is defined in the body of `init`
6244
* @return expr value of unit per the CPS-encoding
6345
*/
64-
def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit]
46+
def init(k: St => Expr[Unit]): E[Unit]
6547

6648
/** Step method that defines the transformation of data.
6749
*
6850
* @param st the state needed for this iteration step
6951
* @param k the continuation that accepts each element and proceeds with the step-wise processing
7052
* @return expr value of unit per the CPS-encoding
7153
*/
72-
def step(st: St, k: (A => Expr[Unit])) given QuoteContext: Expr[Unit]
54+
def step(st: St, k: (A => Expr[Unit])): E[Unit]
7355

7456
/** The condition that checks for termination
7557
*
7658
* @param st the state needed for this iteration check
7759
* @return the expression for a boolean
7860
*/
79-
def hasNext(st: St) given QuoteContext: Expr[Boolean]
61+
def hasNext(st: St): E[Boolean]
8062
}
8163

82-
trait Cardinality
83-
case object AtMost1 extends Cardinality
84-
case object Many extends Cardinality
64+
enum Cardinality {
65+
case AtMost1
66+
case Many
67+
}
68+
import Cardinality._
8569

8670
trait StagedStream[A]
8771
case class Linear[A](producer: Producer[A]) extends StagedStream[A]
@@ -98,19 +82,17 @@ object Test {
9882
* @tparam W the type of the accumulator
9983
* @return
10084
*/
101-
def fold[W: Type](z: Expr[W], f: ((Expr[W], Expr[A]) => Expr[W])) given QuoteContext: Expr[W] = {
102-
Var(z) { s: Var[W] => '{
103-
${
104-
foldRaw[Expr[A]]((a: Expr[A]) => '{
105-
${ s.update(f(s.get, a)) }
106-
}, stream)
107-
}
85+
def fold[W: Type](z: Expr[W], f: ((Expr[W], Expr[A]) => Expr[W])): E[W] = {
86+
Var(z) { s =>
87+
'{
88+
${ foldRaw[Expr[A]]((a: Expr[A]) => s.update(f(s.get, a)), stream) }
89+
10890
${ s.get }
10991
}
11092
}
11193
}
11294

113-
private def foldRaw[A](consumer: A => Expr[Unit], stream: StagedStream[A]) given QuoteContext: Expr[Unit] = {
95+
private def foldRaw[A](consumer: A => Expr[Unit], stream: StagedStream[A]): E[Unit] = {
11496
stream match {
11597
case Linear(producer) => {
11698
producer.card match {
@@ -166,15 +148,15 @@ object Test {
166148
type St = producer.St
167149
val card = producer.card
168150

169-
def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] = {
151+
def init(k: St => Expr[Unit]): E[Unit] = {
170152
producer.init(k)
171153
}
172154

173-
def step(st: St, k: (B => Expr[Unit])) given QuoteContext: Expr[Unit] = {
155+
def step(st: St, k: (B => Expr[Unit])): E[Unit] = {
174156
producer.step(st, el => f(el)(k))
175157
}
176158

177-
def hasNext(st: St) given QuoteContext: Expr[Boolean] = {
159+
def hasNext(st: St): E[Boolean] = {
178160
producer.hasNext(st)
179161
}
180162
}
@@ -229,13 +211,13 @@ object Test {
229211
type St = Expr[A]
230212
val card = AtMost1
231213

232-
def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] =
214+
def init(k: St => Expr[Unit]): E[Unit] =
233215
k(a)
234216

235-
def step(st: St, k: (Expr[A] => Expr[Unit])) given QuoteContext: Expr[Unit] =
217+
def step(st: St, k: (Expr[A] => Expr[Unit])): E[Unit] =
236218
k(st)
237219

238-
def hasNext(st: St) given QuoteContext: Expr[Boolean] =
220+
def hasNext(st: St): E[Boolean] =
239221
pred(st)
240222
}
241223

@@ -259,13 +241,13 @@ object Test {
259241
type St = producer.St
260242
val card = producer.card
261243

262-
def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] =
244+
def init(k: St => Expr[Unit]): E[Unit] =
263245
producer.init(k)
264246

265-
def step(st: St, k: (A => Expr[Unit])) given QuoteContext: Expr[Unit] =
247+
def step(st: St, k: (A => Expr[Unit])): E[Unit] =
266248
producer.step(st, el => k(el))
267249

268-
def hasNext(st: St) given QuoteContext: Expr[Boolean] =
250+
def hasNext(st: St): E[Boolean] =
269251
f(producer.hasNext(st))
270252
}
271253
case AtMost1 => producer
@@ -292,22 +274,20 @@ object Test {
292274
type St = (Var[Int], producer.St)
293275
val card = producer.card
294276

295-
def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] = {
277+
def init(k: St => Expr[Unit]): E[Unit] = {
296278
producer.init(st => {
297279
Var(n) { counter =>
298280
k(counter, st)
299281
}
300282
})
301283
}
302284

303-
def step(st: St, k: (((Var[Int], A)) => Expr[Unit])) given QuoteContext: Expr[Unit] = {
285+
def step(st: St, k: (((Var[Int], A)) => Expr[Unit])): E[Unit] = {
304286
val (counter, currentState) = st
305-
producer.step(currentState, el => '{
306-
${k((counter, el))}
307-
})
287+
producer.step(currentState, el => k((counter, el)))
308288
}
309289

310-
def hasNext(st: St) given QuoteContext: Expr[Boolean] = {
290+
def hasNext(st: St): E[Boolean] = {
311291
val (counter, currentState) = st
312292
producer.card match {
313293
case Many => '{ ${counter.get} > 0 && ${producer.hasNext(currentState)} }
@@ -365,7 +345,7 @@ object Test {
365345
pushLinear[A = Expr[A], C = B](producer1, producer2, nestf2)
366346

367347
case (Nested(producer1, nestf1), Linear(producer2)) =>
368-
mapRaw[(B, Expr[A]), (Expr[A], B)]((t => k => '{ ${k((t._2, t._1))} }), pushLinear[A = B, C = Expr[A]](producer2, producer1, nestf1))
348+
mapRaw[(B, Expr[A]), (Expr[A], B)]((t => k => k((t._2, t._1))), pushLinear[A = B, C = Expr[A]](producer2, producer1, nestf1))
369349

370350
case (Nested(producer1, nestf1), Nested(producer2, nestf2)) =>
371351
zipRaw[A, B](Linear(makeLinear(stream1)), stream2)
@@ -441,7 +421,7 @@ object Test {
441421
* @param k the continuation that consumes a variable.
442422
* @return the quote of the orchestrated code that will be executed as
443423
*/
444-
def makeAdvanceFunction[A](nadv: Var[Unit => Unit], k: A => Expr[Unit], stream: StagedStream[A]) given QuoteContext: Expr[Unit] = {
424+
def makeAdvanceFunction[A](nadv: Var[Unit => Unit], k: A => Expr[Unit], stream: StagedStream[A]): E[Unit] = {
445425
stream match {
446426
case Linear(producer) =>
447427
producer.card match {
@@ -482,7 +462,7 @@ object Test {
482462
type St = (Var[Boolean], Var[A], Var[Unit => Unit])
483463
val card: Cardinality = Many
484464

485-
def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] = {
465+
def init(k: St => Expr[Unit]): E[Unit] = {
486466
producer.init(st =>
487467
Var('{ (_: Unit) => ()}){ nadv => {
488468
Var('{ true }) { hasNext => {
@@ -506,7 +486,7 @@ object Test {
506486
}})
507487
}
508488

509-
def step(st: St, k: Expr[A] => Expr[Unit]) given QuoteContext: Expr[Unit] = {
489+
def step(st: St, k: Expr[A] => Expr[Unit]): E[Unit] = {
510490
val (flag, current, nadv) = st
511491
'{
512492
var el = ${current.get}
@@ -517,7 +497,7 @@ object Test {
517497

518498
}
519499

520-
def hasNext(st: St) given QuoteContext: Expr[Boolean] = {
500+
def hasNext(st: St): E[Boolean] = {
521501
val (flag, _, _) = st
522502
flag.get
523503
}
@@ -532,19 +512,19 @@ object Test {
532512
type St = (Var[Boolean], producer.St, nestedProducer.St)
533513
val card: Cardinality = Many
534514

535-
def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] = {
536-
producer.init(s1 => '{ ${nestedProducer.init(s2 =>
515+
def init(k: St => Expr[Unit]): E[Unit] = {
516+
producer.init(s1 => nestedProducer.init(s2 =>
537517
Var(producer.hasNext(s1)) { flag =>
538518
k((flag, s1, s2))
539-
})}})
519+
}))
540520
}
541521

542-
def step(st: St, k: ((Var[Boolean], producer.St, B)) => Expr[Unit]) given QuoteContext: Expr[Unit] = {
522+
def step(st: St, k: ((Var[Boolean], producer.St, B)) => Expr[Unit]): E[Unit] = {
543523
val (flag, s1, s2) = st
544-
nestedProducer.step(s2, b => '{ ${k((flag, s1, b))} })
524+
nestedProducer.step(s2, b => k((flag, s1, b)))
545525
}
546526

547-
def hasNext(st: St) given QuoteContext: Expr[Boolean] = {
527+
def hasNext(st: St): E[Boolean] = {
548528
val (flag, s1, s2) = st
549529
'{ ${flag.get} && ${nestedProducer.hasNext(s2)} }
550530
}
@@ -554,7 +534,7 @@ object Test {
554534
val (flag, s1, b) = t
555535

556536
mapRaw[C, (A, C)]((c => k => '{
557-
${producer.step(s1, a => '{ ${k((a, c))} })}
537+
${producer.step(s1, a => k((a, c)))}
558538
${flag.update(producer.hasNext(s1))}
559539
}), addTerminationCondition((b_flag: Expr[Boolean]) => '{ ${flag.get} && $b_flag }, nestedf(b)))
560540
})
@@ -567,16 +547,16 @@ object Test {
567547
type St = (producer1.St, producer2.St)
568548
val card: Cardinality = Many
569549

570-
def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] = {
550+
def init(k: St => Expr[Unit]): E[Unit] = {
571551
producer1.init(s1 => producer2.init(s2 => k((s1, s2)) ))
572552
}
573553

574-
def step(st: St, k: ((A, B)) => Expr[Unit]) given QuoteContext: Expr[Unit] = {
554+
def step(st: St, k: ((A, B)) => Expr[Unit]): E[Unit] = {
575555
val (s1, s2) = st
576556
producer1.step(s1, el1 => producer2.step(s2, el2 => k((el1, el2)) ))
577557
}
578558

579-
def hasNext(st: St) given QuoteContext: Expr[Boolean] = {
559+
def hasNext(st: St): E[Boolean] = {
580560
val (s1, s2) = st
581561
'{ ${producer1.hasNext(s1)} && ${producer2.hasNext(s2)} }
582562
}
@@ -597,15 +577,15 @@ object Test {
597577

598578
val card = Many
599579

600-
def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] = {
580+
def init(k: St => Expr[Unit]): E[Unit] = {
601581
Var('{($arr).length}) { n =>
602582
Var(0){ i =>
603583
k((i, n, arr))
604584
}
605585
}
606586
}
607587

608-
def step(st: St, k: (Expr[A] => Expr[Unit])) given QuoteContext: Expr[Unit] = {
588+
def step(st: St, k: (Expr[A] => Expr[Unit])): E[Unit] = {
609589
val (i, _, arr) = st
610590
'{
611591
val el = ($arr).apply(${i.get})
@@ -614,7 +594,7 @@ object Test {
614594
}
615595
}
616596

617-
def hasNext(st: St) given QuoteContext: Expr[Boolean] = {
597+
def hasNext(st: St): E[Boolean] = {
618598
val (i, n, _) = st
619599
'{
620600
(${i.get} < ${n.get})

0 commit comments

Comments
 (0)