Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Safe Fiber races: ~combine and Fiber.n_any #587

Merged
merged 1 commit into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions lib_eio/core/eio__core.mli
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ module Fiber : sig
(** [all fs] is like [both], but for any number of fibers.
[all []] returns immediately. *)

val first : (unit -> 'a) -> (unit -> 'a) -> 'a
val first : ?combine:('a -> 'a -> 'a) -> (unit -> 'a) -> (unit -> 'a) -> 'a
(** [first f g] runs [f ()] and [g ()] concurrently.

They run in a new cancellation sub-context, and when one finishes the other is cancelled.
Expand All @@ -216,15 +216,24 @@ module Fiber : sig

If both fibers fail, {!Exn.combine} is used to combine the exceptions.

Warning: it is always possible that {i both} operations will succeed (and one result will be thrown away).
This is because there is a period of time after the first operation succeeds,
but before its fiber finishes, during which the other operation may also succeed. *)
SGrondin marked this conversation as resolved.
Show resolved Hide resolved
Warning: it is always possible that {i both} operations will succeed.
This is because there is a period of time after the first operation succeeds
when it is waiting in the run-queue to resume
during which the other operation may also succeed.

val any : (unit -> 'a) list -> 'a
If both fibers succeed, [combine a b] is used to combine the results
(where [a] is the result of the first fiber to return and [b] is the second result).
The default is [fun a _ -> a], which discards the later result. *)

val any : ?combine:('a -> 'a -> 'a) -> (unit -> 'a) list -> 'a
(** [any fs] is like [first], but for any number of fibers.

[any []] just waits forever (or until cancelled). *)

val n_any : (unit -> 'a) list -> 'a list
(** [n_any fs] is like [any], expect that if multiple fibers return values
then they are all returned, in the order in which the fibers finished. *)

val await_cancel : unit -> 'a
(** [await_cancel ()] waits until cancelled.
@raise Cancel.Cancelled *)
Expand Down
41 changes: 26 additions & 15 deletions lib_eio/core/fiber.ml
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,22 @@ let await_cancel () =
Suspend.enter "await_cancel" @@ fun fiber enqueue ->
Cancel.Fiber_context.set_cancel_fn fiber (fun ex -> enqueue (Error ex))

let any fs =
let r = ref `None in
type 'a any_status =
| New
| Ex of (exn * Printexc.raw_backtrace)
| OK of 'a

let any_gen ~return ~combine fs =
let r = ref New in
let parent_c =
Cancel.sub_unchecked Any (fun cc ->
let wrap h =
match h () with
| x ->
begin match !r with
| `None -> r := `Ok x; Cancel.cancel cc Not_first
| `Ex _ | `Ok _ -> ()
| New -> r := OK (return x); Cancel.cancel cc Not_first
| OK prev -> r := OK (combine prev x)
| Ex _ -> ()
end
| exception Cancel.Cancelled _ when not (Cancel.is_on cc) ->
(* If this is in response to us asking the fiber to cancel then we can just ignore it.
Expand All @@ -105,11 +111,11 @@ let any fs =
()
| exception ex ->
begin match !r with
| `None -> r := `Ex (ex, Printexc.get_raw_backtrace ()); Cancel.cancel cc ex
| `Ok _ -> r := `Ex (ex, Printexc.get_raw_backtrace ())
| `Ex prev ->
| New -> r := Ex (ex, Printexc.get_raw_backtrace ()); Cancel.cancel cc ex
| OK _ -> r := Ex (ex, Printexc.get_raw_backtrace ())
| Ex prev ->
let bt = Printexc.get_raw_backtrace () in
r := `Ex (Exn.combine prev (ex, bt))
r := Ex (Exn.combine prev (ex, bt))
end
in
let vars = Cancel.Fiber_context.get_vars () in
Expand All @@ -121,7 +127,7 @@ let any fs =
let p, r = Promise.create_with_id (Cancel.Fiber_context.tid new_fiber) in
fork_raw new_fiber (fun () ->
match wrap f with
| x -> Promise.resolve_ok r x
| () -> Promise.resolve_ok r ()
| exception ex -> Promise.resolve_error r ex
);
p :: aux fs
Expand All @@ -131,16 +137,21 @@ let any fs =
)
in
match !r, Cancel.get_error parent_c with
| `Ok r, None -> r
| (`Ok _ | `None), Some ex -> raise ex
| `Ex (ex, bt), None -> Printexc.raise_with_backtrace ex bt
| `Ex ex1, Some ex2 ->
| OK r, None -> r
| (OK _ | New), Some ex -> raise ex
| Ex (ex, bt), None -> Printexc.raise_with_backtrace ex bt
| Ex ex1, Some ex2 ->
let bt2 = Printexc.get_raw_backtrace () in
let ex, bt = Exn.combine ex1 (ex2, bt2) in
Printexc.raise_with_backtrace ex bt
| `None, None -> assert false
| New, None -> assert false

let n_any fs =
List.rev (any_gen fs ~return:(fun x -> [x]) ~combine:(fun xs x -> x :: xs))

let any ?(combine=(fun x _ -> x)) fs = any_gen fs ~return:Fun.id ~combine

let first f g = any [f; g]
let first ?combine f g = any ?combine [f; g]

let is_cancelled () =
let ctx = Effect.perform Cancel.Get_context in
Expand Down
117 changes: 116 additions & 1 deletion tests/fiber.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Second finishes, first is cancelled:
- : unit = ()
```

If both succeed, we pick the first one:
If both succeed and no ~combine, we pick the first one by default:

```ocaml
# run @@ fun () ->
Expand All @@ -49,6 +49,73 @@ If both succeed, we pick the first one:
- : unit = ()
```

If both succeed we let ~combine decide:

```ocaml
# run @@ fun () ->
Fiber.first ~combine:(fun _ x -> x)
(fun () -> "a")
(fun () -> "b");;
+b
- : unit = ()
```

It allows for safe Stream.take races (both):

```ocaml
# run @@ fun () ->
let stream = Eio.Stream.create 1 in
Fiber.first ~combine:(fun x y -> x ^ y)
(fun () ->
Fiber.yield ();
Eio.Stream.add stream "b";
"a"
)
(fun () -> Eio.Stream.take stream);;
+ab
- : unit = ()
```

It allows for safe Stream.take races (f is first):

```ocaml
# run @@ fun () ->
let stream = Eio.Stream.create 1 in
let out =
Fiber.first ~combine:(fun x y -> x ^ y)
(fun () ->
Eio.Stream.add stream "b";
Fiber.yield ();
"a"
)
(fun () ->
Fiber.yield ();
Eio.Stream.take stream)
in
out ^ Int.to_string (Eio.Stream.length stream);;
+a1
- : unit = ()
```

It allows for safe Stream.take races (g is first):

```ocaml
# run @@ fun () ->
let stream = Eio.Stream.create 1 in
let out =
Fiber.first ~combine:(fun x y -> x ^ y)
(fun () ->
Eio.Stream.add stream "b";
Fiber.yield ();
"a"
)
(fun () -> Eio.Stream.take stream)
in
out ^ Int.to_string (Eio.Stream.length stream);;
+b0
- : unit = ()
```

One crashes - report it:

```ocaml
Expand Down Expand Up @@ -201,6 +268,54 @@ Exception: Stdlib.Exit.
- : unit = ()
```

`Fiber.any` with combine collects all results:

```ocaml
# run @@ fun () ->
Fiber.any
~combine:(fun x y -> x @ y)
(List.init 3 (fun x () -> traceln "%d" x; [x]))
|> Fmt.(str "%a" (Dump.list int));;
+0
+1
+2
+[0; 1; 2]
- : unit = ()
```

# Fiber.n_any

`Fiber.n_any` behaves just like `Fiber.any` when there's only one result:

```ocaml
# run @@ fun () ->
Fiber.n_any (List.init 3 (fun x () -> traceln "%d" x; Fiber.yield (); x))
|> Fmt.(str "%a" (Dump.list int));;
+0
+1
+2
+[0]
- : unit = ()
```

`Fiber.n_any` collects all results:

```ocaml
# run @@ fun () ->
(Fiber.n_any (List.init 4 (fun x () ->
traceln "%d" x;
if x = 1 then Fiber.yield ();
x
)))
|> Fmt.(str "%a" (Dump.list int));;
+0
+1
+2
+3
+[0; 2; 3]
- : unit = ()
```

# Fiber.await_cancel

```ocaml
Expand Down
Loading