Skip to content

Commit 2ff38fd

Browse files
SGrondintalex5
andcommitted
Safe Fiber races: ~combine and n_any
Co-authored-by: Thomas Leonard <[email protected]>
1 parent c9db164 commit 2ff38fd

File tree

3 files changed

+156
-21
lines changed

3 files changed

+156
-21
lines changed

lib_eio/core/eio__core.mli

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ module Fiber : sig
206206
(** [all fs] is like [both], but for any number of fibers.
207207
[all []] returns immediately. *)
208208

209-
val first : (unit -> 'a) -> (unit -> 'a) -> 'a
209+
val first : ?combine:('a -> 'a -> 'a) -> (unit -> 'a) -> (unit -> 'a) -> 'a
210210
(** [first f g] runs [f ()] and [g ()] concurrently.
211211
212212
They run in a new cancellation sub-context, and when one finishes the other is cancelled.
@@ -216,15 +216,24 @@ module Fiber : sig
216216
217217
If both fibers fail, {!Exn.combine} is used to combine the exceptions.
218218
219-
Warning: it is always possible that {i both} operations will succeed (and one result will be thrown away).
220-
This is because there is a period of time after the first operation succeeds,
221-
but before its fiber finishes, during which the other operation may also succeed. *)
219+
Warning: it is always possible that {i both} operations will succeed.
220+
This is because there is a period of time after the first operation succeeds
221+
when it is waiting in the run-queue to resume
222+
during which the other operation may also succeed.
222223
223-
val any : (unit -> 'a) list -> 'a
224+
If both fibers succeed, [combine a b] is used to combine the results
225+
(where [a] is the result of the first fiber to return and [b] is the second result).
226+
The default is [fun a _ -> a] which discards the later result. *)
227+
228+
val any : ?combine:('a -> 'a -> 'a) -> (unit -> 'a) list -> 'a
224229
(** [any fs] is like [first], but for any number of fibers.
225230
226231
[any []] just waits forever (or until cancelled). *)
227232

233+
val n_any : (unit -> 'a) list -> 'a list
234+
(** [n_any fs] is like [any], expect that if multiple fibers return values
235+
then thay are all returned, in the order in which the fibers finished. *)
236+
228237
val await_cancel : unit -> 'a
229238
(** [await_cancel ()] waits until cancelled.
230239
@raise Cancel.Cancelled *)

lib_eio/core/fiber.ml

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,22 @@ let await_cancel () =
8787
Suspend.enter "await_cancel" @@ fun fiber enqueue ->
8888
Cancel.Fiber_context.set_cancel_fn fiber (fun ex -> enqueue (Error ex))
8989

90-
let any fs =
91-
let r = ref `None in
90+
type 'a any_status =
91+
| New
92+
| Ex of (exn * Printexc.raw_backtrace)
93+
| OK of 'a
94+
95+
let any_gen ~return ~combine fs =
96+
let r = ref New in
9297
let parent_c =
9398
Cancel.sub_unchecked Any (fun cc ->
9499
let wrap h =
95100
match h () with
96101
| x ->
97102
begin match !r with
98-
| `None -> r := `Ok x; Cancel.cancel cc Not_first
99-
| `Ex _ | `Ok _ -> ()
103+
| New -> r := OK (return x); Cancel.cancel cc Not_first
104+
| OK prev -> r := OK (combine prev x)
105+
| Ex _ -> ()
100106
end
101107
| exception Cancel.Cancelled _ when not (Cancel.is_on cc) ->
102108
(* If this is in response to us asking the fiber to cancel then we can just ignore it.
@@ -105,11 +111,11 @@ let any fs =
105111
()
106112
| exception ex ->
107113
begin match !r with
108-
| `None -> r := `Ex (ex, Printexc.get_raw_backtrace ()); Cancel.cancel cc ex
109-
| `Ok _ -> r := `Ex (ex, Printexc.get_raw_backtrace ())
110-
| `Ex prev ->
114+
| New -> r := Ex (ex, Printexc.get_raw_backtrace ()); Cancel.cancel cc ex
115+
| OK _ -> r := Ex (ex, Printexc.get_raw_backtrace ())
116+
| Ex prev ->
111117
let bt = Printexc.get_raw_backtrace () in
112-
r := `Ex (Exn.combine prev (ex, bt))
118+
r := Ex (Exn.combine prev (ex, bt))
113119
end
114120
in
115121
let vars = Cancel.Fiber_context.get_vars () in
@@ -121,7 +127,7 @@ let any fs =
121127
let p, r = Promise.create_with_id (Cancel.Fiber_context.tid new_fiber) in
122128
fork_raw new_fiber (fun () ->
123129
match wrap f with
124-
| x -> Promise.resolve_ok r x
130+
| () -> Promise.resolve_ok r ()
125131
| exception ex -> Promise.resolve_error r ex
126132
);
127133
p :: aux fs
@@ -131,16 +137,21 @@ let any fs =
131137
)
132138
in
133139
match !r, Cancel.get_error parent_c with
134-
| `Ok r, None -> r
135-
| (`Ok _ | `None), Some ex -> raise ex
136-
| `Ex (ex, bt), None -> Printexc.raise_with_backtrace ex bt
137-
| `Ex ex1, Some ex2 ->
140+
| OK r, None -> r
141+
| (OK _ | New), Some ex -> raise ex
142+
| Ex (ex, bt), None -> Printexc.raise_with_backtrace ex bt
143+
| Ex ex1, Some ex2 ->
138144
let bt2 = Printexc.get_raw_backtrace () in
139145
let ex, bt = Exn.combine ex1 (ex2, bt2) in
140146
Printexc.raise_with_backtrace ex bt
141-
| `None, None -> assert false
147+
| New, None -> assert false
148+
149+
let n_any fs =
150+
List.rev (any_gen fs ~return:(fun x -> [x]) ~combine:(fun xs x -> x :: xs))
151+
152+
let any ?(combine=(fun x _ -> x)) fs = any_gen fs ~return:Fun.id ~combine
142153

143-
let first f g = any [f; g]
154+
let first ?combine f g = any ?combine [f; g]
144155

145156
let is_cancelled () =
146157
let ctx = Effect.perform Cancel.Get_context in

tests/fiber.md

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Second finishes, first is cancelled:
3838
- : unit = ()
3939
```
4040

41-
If both succeed, we pick the first one:
41+
If both succeed and no ~combine, we pick the first one by default:
4242

4343
```ocaml
4444
# run @@ fun () ->
@@ -49,6 +49,73 @@ If both succeed, we pick the first one:
4949
- : unit = ()
5050
```
5151

52+
If both succeed we let ~combine decide:
53+
54+
```ocaml
55+
# run @@ fun () ->
56+
Fiber.first ~combine:(fun _ x -> x)
57+
(fun () -> "a")
58+
(fun () -> "b");;
59+
+b
60+
- : unit = ()
61+
```
62+
63+
It allows for safe Stream.take races (both):
64+
65+
```ocaml
66+
# run @@ fun () ->
67+
let stream = Eio.Stream.create 1 in
68+
Fiber.first ~combine:(fun x y -> x ^ y)
69+
(fun () ->
70+
Fiber.yield ();
71+
Eio.Stream.add stream "b";
72+
"a"
73+
)
74+
(fun () -> Eio.Stream.take stream);;
75+
+ab
76+
- : unit = ()
77+
```
78+
79+
It allows for safe Stream.take races (f is first):
80+
81+
```ocaml
82+
# run @@ fun () ->
83+
let stream = Eio.Stream.create 1 in
84+
let out =
85+
Fiber.first ~combine:(fun x y -> x ^ y)
86+
(fun () ->
87+
Eio.Stream.add stream "b";
88+
Fiber.yield ();
89+
"a"
90+
)
91+
(fun () ->
92+
Fiber.yield ();
93+
Eio.Stream.take stream)
94+
in
95+
out ^ Int.to_string (Eio.Stream.length stream);;
96+
+a1
97+
- : unit = ()
98+
```
99+
100+
It allows for safe Stream.take races (g is first):
101+
102+
```ocaml
103+
# run @@ fun () ->
104+
let stream = Eio.Stream.create 1 in
105+
let out =
106+
Fiber.first ~combine:(fun x y -> x ^ y)
107+
(fun () ->
108+
Eio.Stream.add stream "b";
109+
Fiber.yield ();
110+
"a"
111+
)
112+
(fun () -> Eio.Stream.take stream)
113+
in
114+
out ^ Int.to_string (Eio.Stream.length stream);;
115+
+b0
116+
- : unit = ()
117+
```
118+
52119
One crashes - report it:
53120

54121
```ocaml
@@ -201,6 +268,54 @@ Exception: Stdlib.Exit.
201268
- : unit = ()
202269
```
203270

271+
`Fiber.any` with combine collects all results:
272+
273+
```ocaml
274+
# run @@ fun () ->
275+
Fiber.any
276+
~combine:(fun x y -> x @ y)
277+
(List.init 3 (fun x () -> traceln "%d" x; [x]))
278+
|> Fmt.(str "%a" (Dump.list int));;
279+
+0
280+
+1
281+
+2
282+
+[0; 1; 2]
283+
- : unit = ()
284+
```
285+
286+
# Fiber.n_any
287+
288+
`Fiber.n_any` behaves just like `Fiber.any` when there's only one result:
289+
290+
```ocaml
291+
# run @@ fun () ->
292+
Fiber.n_any (List.init 3 (fun x () -> traceln "%d" x; Fiber.yield (); x))
293+
|> Fmt.(str "%a" (Dump.list int));;
294+
+0
295+
+1
296+
+2
297+
+[0]
298+
- : unit = ()
299+
```
300+
301+
`Fiber.n_any` collects all results:
302+
303+
```ocaml
304+
# run @@ fun () ->
305+
(Fiber.n_any (List.init 4 (fun x () ->
306+
traceln "%d" x;
307+
if x = 1 then Fiber.yield ();
308+
x
309+
)))
310+
|> Fmt.(str "%a" (Dump.list int));;
311+
+0
312+
+1
313+
+2
314+
+3
315+
+[0; 2; 3]
316+
- : unit = ()
317+
```
318+
204319
# Fiber.await_cancel
205320

206321
```ocaml

0 commit comments

Comments
 (0)