Skip to content

Commit

Permalink
Safe Fiber races: ~combine and n_any
Browse files Browse the repository at this point in the history
  • Loading branch information
SGrondin committed Nov 4, 2023
1 parent bc1e231 commit 0b10025
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 21 deletions.
13 changes: 8 additions & 5 deletions lib_eio/core/eio__core.mli
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ module Fiber : sig
(** [all fs] is like [both], but for any number of fibers.
[all []] returns immediately. *)

val first : (unit -> 'a) -> (unit -> 'a) -> 'a
val first : ?combine:('a -> 'a -> 'a) -> (unit -> 'a) -> (unit -> 'a) -> 'a
(** [first f g] runs [f ()] and [g ()] concurrently.
They run in a new cancellation sub-context, and when one finishes the other is cancelled.
Expand All @@ -209,15 +209,18 @@ module Fiber : sig
If both fibers fail, {!Exn.combine} is used to combine the exceptions.
Warning: it is always possible that {i both} operations will succeed (and one result will be thrown away).
This is because there is a period of time after the first operation succeeds,
but before its fiber finishes, during which the other operation may also succeed. *)
If both fibers succeed simultaneously, [combine] is used to combine the results.
If [combine] is not provided, the first one to complete is returned. *)

val any : (unit -> 'a) list -> 'a
val any : ?combine:('a -> 'a -> 'a) -> (unit -> 'a) list -> 'a
(** [any fs] is like [first], but for any number of fibers.
[any []] just waits forever (or until cancelled). *)

val n_any : (unit -> 'a) list -> 'a list
(** [n_any fs] is like [any], the difference being that when multiple fibers are
completed simultaneously (and none are rejected), all the results are returned. *)

val await_cancel : unit -> 'a
(** [await_cancel ()] waits until cancelled.
@raise Cancel.Cancelled *)
Expand Down
49 changes: 34 additions & 15 deletions lib_eio/core/fiber.ml
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,22 @@ let await_cancel () =
Suspend.enter @@ fun fiber enqueue ->
Cancel.Fiber_context.set_cancel_fn fiber (fun ex -> enqueue (Error ex))

let any fs =
let r = ref `None in
type 'a any_status =
| New
| Ex of (exn * Printexc.raw_backtrace)
| OK of 'a list

let rev_any fs =
let r = ref New in
let parent_c =
Cancel.sub_unchecked (fun cc ->
let wrap h =
match h () with
| x ->
begin match !r with
| `None -> r := `Ok x; Cancel.cancel cc Not_first
| `Ex _ | `Ok _ -> ()
| New -> r := OK [x]; Cancel.cancel cc Not_first
| OK ll -> r := OK (x :: ll)
| Ex _ -> ()
end
| exception Cancel.Cancelled _ when not (Cancel.is_on cc) ->
(* If this is in response to us asking the fiber to cancel then we can just ignore it.
Expand All @@ -107,11 +113,11 @@ let any fs =
()
| exception ex ->
begin match !r with
| `None -> r := `Ex (ex, Printexc.get_raw_backtrace ()); Cancel.cancel cc ex
| `Ok _ -> r := `Ex (ex, Printexc.get_raw_backtrace ())
| `Ex prev ->
| New -> r := Ex (ex, Printexc.get_raw_backtrace ()); Cancel.cancel cc ex
| OK _ -> r := Ex (ex, Printexc.get_raw_backtrace ())
| Ex prev ->
let bt = Printexc.get_raw_backtrace () in
r := `Ex (Exn.combine prev (ex, bt))
r := Ex (Exn.combine prev (ex, bt))
end
in
let vars = Cancel.Fiber_context.get_vars () in
Expand All @@ -123,7 +129,7 @@ let any fs =
let p, r = Promise.create_with_id (Cancel.Fiber_context.tid new_fiber) in
fork_raw new_fiber (fun () ->
match wrap f with
| x -> Promise.resolve_ok r x
| () -> Promise.resolve_ok r ()
| exception ex -> Promise.resolve_error r ex
);
p :: aux fs
Expand All @@ -133,16 +139,29 @@ let any fs =
)
in
match !r, Cancel.get_error parent_c with
| `Ok r, None -> r
| (`Ok _ | `None), Some ex -> raise ex
| `Ex (ex, bt), None -> Printexc.raise_with_backtrace ex bt
| `Ex ex1, Some ex2 ->
| OK ll, None -> ll
| (OK _ | New), Some ex -> raise ex
| Ex (ex, bt), None -> Printexc.raise_with_backtrace ex bt
| Ex ex1, Some ex2 ->
let bt2 = Printexc.get_raw_backtrace () in
let ex, bt = Exn.combine ex1 (ex2, bt2) in
Printexc.raise_with_backtrace ex bt
| `None, None -> assert false
| New, None -> assert false

let n_any ll = List.rev (rev_any ll)

let any ?(combine = (fun x _ -> x)) ll =
(* The results are backwards *)
let rec reduce_right = function
| [y; x] -> combine x y
| x :: rest -> combine (reduce_right rest) x
| [] -> assert false
in
match rev_any ll with
| [x] -> x
| ll -> reduce_right ll

let first f g = any [f; g]
let first ?combine f g = any ?combine [f; g]

let is_cancelled () =
let ctx = Effect.perform Cancel.Get_context in
Expand Down
118 changes: 117 additions & 1 deletion tests/fiber.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Second finishes, first is cancelled:
- : unit = ()
```

If both succeed, we pick the first one:
If both succeed and no ~combine, we pick the first one by default:

```ocaml
# run @@ fun () ->
Expand All @@ -49,6 +49,73 @@ If both succeed, we pick the first one:
- : unit = ()
```

If both succeed we let ~combine decide:

```ocaml
# run @@ fun () ->
Fiber.first ~combine:(fun _ x -> x)
(fun () -> "a")
(fun () -> "b");;
+b
- : unit = ()
```

It allows for safe Stream.take races (both):

```ocaml
# run @@ fun () ->
let stream = Eio.Stream.create 1 in
Fiber.first ~combine:(fun x y -> x ^ y)
(fun () ->
Fiber.yield ();
Eio.Stream.add stream "b";
"a"
)
(fun () -> Eio.Stream.take stream);;
+ab
- : unit = ()
```

It allows for safe Stream.take races (f is first):

```ocaml
# run @@ fun () ->
let stream = Eio.Stream.create 1 in
let out =
Fiber.first ~combine:(fun x y -> x ^ y)
(fun () ->
Eio.Stream.add stream "b";
Fiber.yield ();
"a"
)
(fun () ->
Fiber.yield ();
Eio.Stream.take stream)
in
out ^ Int.to_string (Eio.Stream.length stream);;
+a1
- : unit = ()
```

It allows for safe Stream.take races (g is first):

```ocaml
# run @@ fun () ->
let stream = Eio.Stream.create 1 in
let out =
Fiber.first ~combine:(fun x y -> x ^ y)
(fun () ->
Eio.Stream.add stream "b";
Fiber.yield ();
"a"
)
(fun () -> Eio.Stream.take stream)
in
out ^ Int.to_string (Eio.Stream.length stream);;
+b0
- : unit = ()
```

One crashes - report it:

```ocaml
Expand Down Expand Up @@ -201,6 +268,55 @@ Exception: Stdlib.Exit.
- : unit = ()
```

`Fiber.any` with combine collects all results:

```ocaml
# run @@ fun () ->
Fiber.any
~combine:(fun x y -> x @ y)
(List.init 3 (fun x () -> traceln "%d" x; [x]))
|> List.map string_of_int
|> String.concat ","
;;
+0
+1
+2
+0,1,2
- : unit = ()
```

# Fiber.n_any

`Fiber.n_any` behaves just like `Fiber.any` when there's only one result:

```ocaml
# run @@ fun () ->
Fiber.n_any (List.init 3 (fun x () -> traceln "%d" x; Fiber.yield (); x))
|> List.map string_of_int
|> String.concat ","
;;
+0
+1
+2
+0
- : unit = ()
```

`Fiber.n_any` collects all results:

```ocaml
# run @@ fun () ->
(Fiber.n_any (List.init 3 (fun x () -> traceln "%d" x; x)))
|> List.map string_of_int
|> String.concat ","
;;
+0
+1
+2
+0,1,2
- : unit = ()
```

# Fiber.await_cancel

```ocaml
Expand Down

0 comments on commit 0b10025

Please sign in to comment.