Skip to content

Commit e13ca2f

Browse files
committed
feat: loopChoose wip, untested
1 parent e4a0603 commit e13ca2f

File tree

6 files changed

+270
-1
lines changed

6 files changed

+270
-1
lines changed

basis-library/schedulers/spork/ForkJoin.sml

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ sig
6161
val pareduceBreakExn: (int * int) -> 'a -> (('a -> exn) * int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a
6262
val reducem: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
6363
val parform: (int * int) -> (int -> unit) -> unit
64+
val seqLoop: (int * int) -> (int -> unit) -> unit
65+
val seqReduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
6466
end =
6567
struct
6668
type idx = LoopIndex.t
@@ -147,6 +149,26 @@ struct
147149

148150
fun __inline_always__ parform (lo: int, hi: int) (f: int -> unit) : unit =
149151
reducem (fn _ => ()) () (lo, hi) f
152+
153+
154+
(* Sequential loop: call `f` on each index from lo up to (not including) hi,
 * stepping with LoopIndex.increment.  The start index is clamped to
 * min (lo, hi), so a range with lo >= hi performs no calls. *)
fun __inline_always__ seqLoop (lo: int, hi: int) (f: int -> unit) : unit =
  let
    (* Bounds are converted once, start first, then the end. *)
    val start: idx = LoopIndex.fromInt (Int.min (lo, hi))
    val stop: idx = LoopIndex.fromInt hi

    fun go (cur: idx) : unit =
      if not (LoopIndex.equal (cur, stop)) then
        (__inline_always__ f (LoopIndex.toInt cur); go (LoopIndex.increment cur))
      else
        ()
  in
    go start
  end
162+
163+
164+
(* Sequential left-to-right reduction: starting from `zero`, fold
 * `combine (acc, f i)` over each index from lo up to (not including) hi.
 * The start index is clamped to min (lo, hi), so lo >= hi returns `zero`. *)
fun __inline_always__ seqReduce (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a =
  let
    val start: idx = LoopIndex.fromInt (Int.min (lo, hi))
    val stop: idx = LoopIndex.fromInt hi

    fun go (acc: 'a, cur: idx) : 'a =
      if LoopIndex.equal (cur, stop) then
        acc
      else
        let
          (* f runs before combine, and both run before the index advances. *)
          val x = __inline_always__ f (LoopIndex.toInt cur)
          val acc' = __inline_always__ combine (acc, x)
        in
          go (acc', LoopIndex.increment cur)
        end
  in
    go (zero, start)
  end
150172
end
151173

152174

@@ -170,6 +192,9 @@ sig
170192
val parfor: int -> (int * int) -> (int -> unit) -> unit
171193
val alloc: int -> 'a array
172194

195+
val seqLoop: (int * int) -> (int -> unit) -> unit
196+
val seqReduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
197+
173198
val idleTimeSoFar: unit -> Time.time
174199
val workTimeSoFar: unit -> Time.time
175200
val maxForkDepthSoFar: unit -> int
@@ -280,7 +305,47 @@ struct
280305
val fInt16 = Unrolled16.parform
281306
val fInt32 = Unrolled32.parform
282307
val fInt64 = Unrolled64.parform
283-
val fIntInf = Unrolled64.parform
308+
val fIntInf = Unrolled64.parform
309+
end)
310+
311+
(* Regular (non-unrolled) sequential loop.  One seqLoop is supplied per
 * Loops instantiation (8/16/32/64/Int); Int_ChooseFromInt exposes the
 * chosen one as SeqLoop.f. *)
structure SeqLoop =
  Int_ChooseFromInt
    (type 'a t = (int * int) -> (int -> unit) -> unit
     val fInt8 = Loops8.seqLoop
     val fInt16 = Loops16.seqLoop
     val fInt32 = Loops32.seqLoop
     val fInt64 = Loops64.seqLoop
     val fIntInf = LoopsInt.seqLoop)
320+
321+
(* Regular (non-unrolled) sequential reduction.  One seqReduce is supplied
 * per Loops instantiation (8/16/32/64/Int); Int_ChooseFromInt exposes the
 * chosen one as SeqReduce.f. *)
structure SeqReduce =
  Int_ChooseFromInt
    (type 'a t = ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
     val fInt8 = Loops8.seqReduce
     val fInt16 = Loops16.seqReduce
     val fInt32 = Loops32.seqReduce
     val fInt64 = Loops64.seqReduce
     val fIntInf = LoopsInt.seqReduce)
330+
331+
(* Unrolled sequential loop, exposed as UnrolledSeqLoop.f.
 * There is no dedicated IntInf variant: fIntInf reuses Unrolled64. *)
structure UnrolledSeqLoop =
  Int_ChooseFromInt
    (type 'a t = (int * int) -> (int -> unit) -> unit
     val fInt8 = Unrolled8.seqLoop
     val fInt16 = Unrolled16.seqLoop
     val fInt32 = Unrolled32.seqLoop
     val fInt64 = Unrolled64.seqLoop
     val fIntInf = Unrolled64.seqLoop)
340+
341+
(* Unrolled sequential reduction, exposed as UnrolledSeqReduce.f.
 * There is no dedicated IntInf variant: fIntInf reuses Unrolled64. *)
structure UnrolledSeqReduce =
  Int_ChooseFromInt
    (type 'a t = ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
     val fInt8 = Unrolled8.seqReduce
     val fInt16 = Unrolled16.seqReduce
     val fInt32 = Unrolled32.seqReduce
     val fInt64 = Unrolled64.seqReduce
     val fIntInf = Unrolled64.seqReduce)
285350

286351
local
@@ -314,6 +379,22 @@ struct
314379
in
315380
primSporkChoose (__inline_always__ loopBody, __inline_always__ unrolledImpl, __inline_always__ regularImpl)
316381
end
382+
383+
(* Sequential loop over (lo, hi) that hands the loop body `f` plus two
 * equivalent implementations (unrolled and regular) to
 * Scheduler.primLoopChoose, which returns the result of the loop. *)
fun __inline_always__ unifiedSeqLoop (lo: int, hi: int) (f: int -> unit) : unit =
  let
    val range = (lo, hi)

    (* Both thunks run the same loop; they differ only in the backing module. *)
    fun __inline_always__ viaUnrolled () = __inline_always__ UnrolledSeqLoop.f range f
    fun __inline_always__ viaRegular () = __inline_always__ SeqLoop.f range f
  in
    Scheduler.primLoopChoose
      (__inline_always__ f, __inline_always__ viaUnrolled, __inline_always__ viaRegular)
  end
390+
391+
(* Sequential reduction over (lo, hi) that hands `f` plus two equivalent
 * implementations (unrolled and regular) to Scheduler.primLoopChoose.
 * NOTE(review): only `f` is passed as the loop body; `combine` is not part
 * of what the primitive sees -- confirm that is intended. *)
fun __inline_always__ unifiedSeqReduce (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a =
  let
    val range = (lo, hi)

    (* Both thunks compute the same reduction; they differ only in the backing module. *)
    fun __inline_always__ viaUnrolled () = __inline_always__ UnrolledSeqReduce.f combine zero range f
    fun __inline_always__ viaRegular () = __inline_always__ SeqReduce.f combine zero range f
  in
    Scheduler.primLoopChoose
      (__inline_always__ f, __inline_always__ viaUnrolled, __inline_always__ viaRegular)
  end
317398
in
318399
val reducem = __inline_always__ unifiedReducem
319400
val reduce = __inline_always__ unifiedReducem
@@ -322,6 +403,8 @@ struct
322403
val parformDefault = __inline_always__ Parform.f
323404
val pareduce = __inline_always__ unifiedPareduce
324405
val parfor = __inline_always__ ForkJoin0.parfor
406+
val seqLoop = __inline_always__ unifiedSeqLoop
407+
val seqReduce = __inline_always__ unifiedSeqReduce
325408
end
326409

327410
val pareduceBreakExn = __inline_always__ PareduceBreakExn.f

basis-library/schedulers/spork/Scheduler.sml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ struct
109109
* (unit -> 'a) (* unrolled implementation *)
110110
* (unit -> 'a) (* regular implementation *)
111111
-> 'a;
112+
(* Raw binding of the "loop_choose" compiler primitive.  Its argument is a
 * triple; note that the loop body's result type is the same 'a that the
 * two implementation thunks (and the primitive itself) return. *)
val primLoopChoose' =
  _prim "loop_choose":
      ('u -> 'a)    (* loop body *)
    * (unit -> 'a)  (* unrolled implementation *)
    * (unit -> 'a)  (* regular implementation *)
    -> 'a;
112118

113119
fun __inline_always__ primSporkFair (body, spwn, seq, sync, exnseq, exnsync) =
114120
__inline_always__ primSporkFair' (body, (), spwn, (), seq, sync, exnseq, exnsync)
@@ -118,6 +124,8 @@ struct
118124
__inline_always__ primSporkGive' (body, (), spwn, (), seq, sync, exnseq, exnsync)
119125
fun __inline_always__ primSporkChoose (loopBody, unrolled, regular) =
120126
__inline_always__ primSporkChoose' (loopBody, unrolled, regular)
127+
(* Always-inlined wrapper that forwards (loop body, unrolled thunk,
 * regular thunk) unchanged to the raw primLoopChoose' primitive. *)
fun __inline_always__ primLoopChoose (body, unrolledThunk, regularThunk) =
  __inline_always__ primLoopChoose' (body, unrolledThunk, regularThunk)
121129

122130
val primForkThreadAndSetData = _prim "spork_forkThreadAndSetData": Thread.t * 'a -> Thread.p;
123131
val primForkThreadAndSetData_youngest = _prim "spork_forkThreadAndSetData_youngest": Thread.t * 'a -> Thread.p;

basis-library/schedulers/spork/UnrolledLoops.sml

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ sig
44
val pareduceBreakExn: (int * int) -> 'a -> (('a -> exn) * int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a
55
val reducem: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
66
val parform: (int * int) -> (int -> unit) -> unit
7+
val seqLoop: (int * int) -> (int -> unit) -> unit
8+
val seqReduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
79
end =
810
struct
911

@@ -251,4 +253,62 @@ struct
251253
fun __inline_always__ parform (lo: int, hi: int) (f: int -> unit) : unit =
252254
pareduce (lo, hi) () (fn (i, _) => f i) (fn _ => ())
253255

256+
257+
(* Sequential loop, manually unrolled by eight.  Bounds are converted to
 * words with i2w; `f` is called on eight consecutive indices per step while
 * cur + 8 <= stop, and the remainder is handled one index at a time while
 * cur < stop.
 * NOTE(review): the comparisons are WordImpl comparisons on i2w-converted
 * bounds, and cur + eight is computed before it is compared -- confirm the
 * intended behaviour for negative bounds and for stop close to the word
 * maximum. *)
fun __inline_always__ seqLoop (lo: int, hi: int) (f: int -> unit) : unit =
  let
    val start: word = i2w lo
    val stop: word = i2w hi

    (* Remainder: one index per step. *)
    fun tail (cur: word) : unit =
      if WordImpl.< (cur, stop) then
        (__inline_always__ f (w2i cur); tail (WordImpl.+ (cur, one)))
      else
        ()

    (* Main loop: eight indices per step, falling through to `tail`. *)
    fun blocks (cur: word) : unit =
      let
        val next = WordImpl.+ (cur, eight)
      in
        if WordImpl.<= (next, stop) then
          ( __inline_always__ f (w2i cur)
          ; __inline_always__ f (w2i (WordImpl.+ (cur, one)))
          ; __inline_always__ f (w2i (WordImpl.+ (cur, two)))
          ; __inline_always__ f (w2i (WordImpl.+ (cur, three)))
          ; __inline_always__ f (w2i (WordImpl.+ (cur, four)))
          ; __inline_always__ f (w2i (WordImpl.+ (cur, five)))
          ; __inline_always__ f (w2i (WordImpl.+ (cur, six)))
          ; __inline_always__ f (w2i (WordImpl.+ (cur, seven)))
          ; blocks next
          )
        else
          tail cur
      end
  in
    blocks start
  end
284+
285+
286+
(* Sequential left-to-right reduction, manually unrolled by eight.  Starting
 * from `zero`, the accumulator is updated with `combine (acc, f i)` for
 * eight consecutive indices per step while cur + 8 <= stop, then one index
 * at a time while cur < stop.
 * NOTE(review): the comparisons are WordImpl comparisons on i2w-converted
 * bounds, and cur + eight is computed before it is compared -- confirm the
 * intended behaviour for negative bounds and for stop close to the word
 * maximum. *)
fun __inline_always__ seqReduce (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a =
  let
    val start: word = i2w lo
    val stop: word = i2w hi

    (* Fold the element at word index `w` into the accumulator. *)
    fun __inline_always__ visit (acc: 'a, w: word) : 'a =
      __inline_always__ combine (acc, __inline_always__ f (w2i w))

    (* Remainder: one index per step. *)
    fun tail (acc: 'a, cur: word) : 'a =
      if WordImpl.< (cur, stop) then
        tail (__inline_always__ visit (acc, cur), WordImpl.+ (cur, one))
      else
        acc

    (* Main loop: eight indices per step, falling through to `tail`. *)
    fun blocks (acc: 'a, cur: word) : 'a =
      let
        val next = WordImpl.+ (cur, eight)
      in
        if WordImpl.<= (next, stop) then
          let
            val a0 = __inline_always__ visit (acc, cur)
            val a1 = __inline_always__ visit (a0, WordImpl.+ (cur, one))
            val a2 = __inline_always__ visit (a1, WordImpl.+ (cur, two))
            val a3 = __inline_always__ visit (a2, WordImpl.+ (cur, three))
            val a4 = __inline_always__ visit (a3, WordImpl.+ (cur, four))
            val a5 = __inline_always__ visit (a4, WordImpl.+ (cur, five))
            val a6 = __inline_always__ visit (a5, WordImpl.+ (cur, six))
            val a7 = __inline_always__ visit (a6, WordImpl.+ (cur, seven))
          in
            blocks (a7, next)
          end
        else
          tail (acc, cur)
      end
  in
    blocks (zero, start)
  end
313+
254314
end

mlton/atoms/prim.fun

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ datatype 'a t =
113113
| MLton_share (* to rssa (as nop or runtime C fn) *)
114114
| MLton_size (* to rssa (as runtime C fn) *)
115115
| MLton_touch (* to rssa (as nop) or backend (as nop) *)
116+
| Loop_choose (* closure convert *)
116117
(* Choose between unrolled and regular at compile time *)
117118
| Spork_choose (* closure convert *)
118119
| Spork of {tokenSplitPolicy: Word32.word} (* closure convert *)
@@ -296,6 +297,7 @@ fun toString (n: 'a t): string =
296297
| MLton_share => "MLton_share"
297298
| MLton_size => "MLton_size"
298299
| MLton_touch => "MLton_touch"
300+
| Loop_choose => "loop_choose"
299301
| Spork_choose => "spork_choose"
300302
| Spork {tokenSplitPolicy=0w0} => "spork_fair"
301303
| Spork {tokenSplitPolicy=0w1} => "spork_keep"
@@ -465,6 +467,7 @@ val equals: 'a t * 'a t -> bool =
465467
| (MLton_share, MLton_share) => true
466468
| (MLton_size, MLton_size) => true
467469
| (MLton_touch, MLton_touch) => true
470+
| (Loop_choose, Loop_choose) => true
468471
(* TODO: Check usage properly *)
469472
| (Spork_choose, Spork_choose) => true
470473
| (Spork {tokenSplitPolicy = tsp1}, Spork {tokenSplitPolicy = tsp2}) => tsp1 = tsp2
@@ -654,6 +657,7 @@ val map: 'a t * ('a -> 'b) -> 'b t =
654657
| MLton_touch => MLton_touch
655658
| Spork tsp => Spork tsp
656659
(* TODO: Check usage properly *)
660+
| Loop_choose => Loop_choose
657661
| Spork_choose => Spork_choose
658662
| Spork_forkThreadAndSetData z => Spork_forkThreadAndSetData z
659663
| Spork_getData spid => Spork_getData spid
@@ -872,6 +876,7 @@ val kind: 'a t -> Kind.t =
872876
| MLton_touch => SideEffect
873877
| Spork _ => SideEffect
874878
(* TODO: Check usage properly *)
879+
| Loop_choose => SideEffect
875880
| Spork_choose => SideEffect
876881
| Spork_forkThreadAndSetData _ => SideEffect
877882
| Spork_getData _ => DependsOnState
@@ -1087,6 +1092,7 @@ in
10871092
Spork {tokenSplitPolicy = 0w1},
10881093
Spork {tokenSplitPolicy = 0w2},
10891094
(* TODO: Check usage properly *)
1095+
Loop_choose,
10901096
Spork_choose,
10911097
Spork_forkThreadAndSetData {youngest=true},
10921098
Spork_forkThreadAndSetData {youngest=false},
@@ -1455,6 +1461,15 @@ fun 'a checkApp (prim: 'a t,
14551461
in
14561462
(eightArgs (cont, taa, spwn, tba, seq, sync, exnseq, exnsync), tc)
14571463
end)
1464+
| Loop_choose =>
1465+
(* TODO: Check usage properly *)
(* loop_choose: ('u -> 'a) * (unit -> 'a) * (unit -> 'a) -> 'a
 * Type arguments are taken in the order (ta, tu). *)
1466+
twoTargs (fn (ta, tu) =>
1467+
let
1468+
val loopBody = arrow (tu, ta) (* First arg: loop body function 'u -> 'a *)
1469+
val impl = arrow (unit, ta) (* Second and third args: thunks unit -> 'a *)
1470+
in
1471+
(* The same thunk type is used for both the unrolled and the regular implementation. *)
(threeArgs (loopBody, impl, impl), ta)
1472+
end)
14581473
| Spork_choose =>
14591474
(* TODO: Check usage properly *)
14601475
(* spork_choose: ('u -> 'v) -> (unit -> 'a) -> (unit -> 'a) -> 'a
@@ -1621,6 +1636,14 @@ fun ('a, 'b) extractTargs (prim: 'b t,
16211636
in
16221637
six (taa, tar, tba, tbr, td, tc)
16231638
end
1639+
| Loop_choose =>
1640+
(* TODO: Check usage properly *)
1641+
let
1642+
val ta = result (* Result type 'a *)
1643+
val (tu, _) = deArrow (arg 0) (* First arg: loop body ('u -> 'a); only its domain 'u is extracted *)
1644+
in
1645+
(* Order (ta, tu) matches the twoTargs (fn (ta, tu) => ...) pattern used for Loop_choose in checkApp. *)
Vector.new2 (ta, tu)
1646+
end
16241647
| Spork_choose =>
16251648
(* TODO: Check usage properly *)
16261649
(* spork_choose: ('u -> 'v) -> 'a -> 'a -> 'a *)

mlton/atoms/prim.sig

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ signature PRIM =
104104
| MLton_share (* to rssa (as nop or runtime C fn) *)
105105
| MLton_size (* to rssa (as runtime C fn) *)
106106
| MLton_touch (* to rssa (as nop) or backend (as nop) *)
107+
| Loop_choose (* closure convert *)
107108
| Spork_choose (* TODO: closure convert / SSA / SSA2 / RSSA ? *)
108109
| Spork of {tokenSplitPolicy: Word32.word} (* closure convert *)
109110
| Spork_forkThreadAndSetData of {youngest: bool} (* to rssa (as runtime C fn) *)

0 commit comments

Comments
 (0)