From cbb6ece7b242127bcca1a020b2e6457d7bdca68a Mon Sep 17 00:00:00 2001 From: Simon Grondin Date: Wed, 19 Jul 2023 17:10:32 -0500 Subject: [PATCH] Safe Fiber races: ~combine and n_any Co-authored-by: Thomas Leonard --- lib_eio/core/eio__core.mli | 19 ++++-- lib_eio/core/fiber.ml | 41 ++++++++----- tests/fiber.md | 117 ++++++++++++++++++++++++++++++++++++- 3 files changed, 156 insertions(+), 21 deletions(-) diff --git a/lib_eio/core/eio__core.mli b/lib_eio/core/eio__core.mli index 4a2fac252..08fe261c9 100644 --- a/lib_eio/core/eio__core.mli +++ b/lib_eio/core/eio__core.mli @@ -206,7 +206,7 @@ module Fiber : sig (** [all fs] is like [both], but for any number of fibers. [all []] returns immediately. *) - val first : (unit -> 'a) -> (unit -> 'a) -> 'a + val first : ?combine:('a -> 'a -> 'a) -> (unit -> 'a) -> (unit -> 'a) -> 'a (** [first f g] runs [f ()] and [g ()] concurrently. They run in a new cancellation sub-context, and when one finishes the other is cancelled. @@ -216,15 +216,24 @@ module Fiber : sig If both fibers fail, {!Exn.combine} is used to combine the exceptions. - Warning: it is always possible that {i both} operations will succeed (and one result will be thrown away). - This is because there is a period of time after the first operation succeeds, - but before its fiber finishes, during which the other operation may also succeed. *) + Warning: it is always possible that {i both} operations will succeed. + This is because there is a period of time after the first operation succeeds + when it is waiting in the run-queue to resume + during which the other operation may also succeed. - val any : (unit -> 'a) list -> 'a + If both fibers succeed, [combine a b] is used to combine the results + (where [a] is the result of the first fiber to return and [b] is the second result). + The default is [fun a _ -> a], which discards the later result. *) + + val any : ?combine:('a -> 'a -> 'a) -> (unit -> 'a) list -> 'a (** [any fs] is like [first], but for any number of fibers. [any []] just waits forever (or until cancelled). *) + val n_any : (unit -> 'a) list -> 'a list + (** [n_any fs] is like [any], expect that if multiple fibers return values + then they are all returned, in the order in which the fibers finished. *) + val await_cancel : unit -> 'a (** [await_cancel ()] waits until cancelled. @raise Cancel.Cancelled *) diff --git a/lib_eio/core/fiber.ml b/lib_eio/core/fiber.ml index feebec027..1035a1fe0 100644 --- a/lib_eio/core/fiber.ml +++ b/lib_eio/core/fiber.ml @@ -87,16 +87,22 @@ let await_cancel () = Suspend.enter "await_cancel" @@ fun fiber enqueue -> Cancel.Fiber_context.set_cancel_fn fiber (fun ex -> enqueue (Error ex)) -let any fs = - let r = ref `None in +type 'a any_status = + | New + | Ex of (exn * Printexc.raw_backtrace) + | OK of 'a + +let any_gen ~return ~combine fs = + let r = ref New in let parent_c = Cancel.sub_unchecked Any (fun cc -> let wrap h = match h () with | x -> begin match !r with - | `None -> r := `Ok x; Cancel.cancel cc Not_first - | `Ex _ | `Ok _ -> () + | New -> r := OK (return x); Cancel.cancel cc Not_first + | OK prev -> r := OK (combine prev x) + | Ex _ -> () end | exception Cancel.Cancelled _ when not (Cancel.is_on cc) -> (* If this is in response to us asking the fiber to cancel then we can just ignore it. @@ -105,11 +111,11 @@ let any fs = () | exception ex -> begin match !r with - | `None -> r := `Ex (ex, Printexc.get_raw_backtrace ()); Cancel.cancel cc ex - | `Ok _ -> r := `Ex (ex, Printexc.get_raw_backtrace ()) - | `Ex prev -> + | New -> r := Ex (ex, Printexc.get_raw_backtrace ()); Cancel.cancel cc ex + | OK _ -> r := Ex (ex, Printexc.get_raw_backtrace ()) + | Ex prev -> let bt = Printexc.get_raw_backtrace () in - r := `Ex (Exn.combine prev (ex, bt)) + r := Ex (Exn.combine prev (ex, bt)) end in let vars = Cancel.Fiber_context.get_vars () in @@ -121,7 +127,7 @@ let any fs = let p, r = Promise.create_with_id (Cancel.Fiber_context.tid new_fiber) in fork_raw new_fiber (fun () -> match wrap f with - | x -> Promise.resolve_ok r x + | () -> Promise.resolve_ok r () | exception ex -> Promise.resolve_error r ex ); p :: aux fs @@ -131,16 +137,21 @@ let any fs = ) in match !r, Cancel.get_error parent_c with - | `Ok r, None -> r - | (`Ok _ | `None), Some ex -> raise ex - | `Ex (ex, bt), None -> Printexc.raise_with_backtrace ex bt - | `Ex ex1, Some ex2 -> + | OK r, None -> r + | (OK _ | New), Some ex -> raise ex + | Ex (ex, bt), None -> Printexc.raise_with_backtrace ex bt + | Ex ex1, Some ex2 -> let bt2 = Printexc.get_raw_backtrace () in let ex, bt = Exn.combine ex1 (ex2, bt2) in Printexc.raise_with_backtrace ex bt - | `None, None -> assert false + | New, None -> assert false + +let n_any fs = + List.rev (any_gen fs ~return:(fun x -> [x]) ~combine:(fun xs x -> x :: xs)) + +let any ?(combine=(fun x _ -> x)) fs = any_gen fs ~return:Fun.id ~combine -let first f g = any [f; g] +let first ?combine f g = any ?combine [f; g] let is_cancelled () = let ctx = Effect.perform Cancel.Get_context in diff --git a/tests/fiber.md b/tests/fiber.md index 53cdf7662..fe06a9a93 100644 --- a/tests/fiber.md +++ b/tests/fiber.md @@ -38,7 +38,7 @@ Second finishes, first is cancelled: - : unit = () ``` -If both succeed, we pick the first one: +If both succeed and no ~combine, we pick the first one by default: ```ocaml # run @@ fun () -> @@ -49,6 +49,73 @@ If both succeed, we pick the first one: - : unit = () ``` +If both succeed we let ~combine decide: + +```ocaml +# run @@ fun () -> + Fiber.first ~combine:(fun _ x -> x) + (fun () -> "a") + (fun () -> "b");; ++b +- : unit = () +``` + +It allows for safe Stream.take races (both): + +```ocaml +# run @@ fun () -> + let stream = Eio.Stream.create 1 in + Fiber.first ~combine:(fun x y -> x ^ y) + (fun () -> + Fiber.yield (); + Eio.Stream.add stream "b"; + "a" + ) + (fun () -> Eio.Stream.take stream);; ++ab +- : unit = () +``` + +It allows for safe Stream.take races (f is first): + +```ocaml +# run @@ fun () -> + let stream = Eio.Stream.create 1 in + let out = + Fiber.first ~combine:(fun x y -> x ^ y) + (fun () -> + Eio.Stream.add stream "b"; + Fiber.yield (); + "a" + ) + (fun () -> + Fiber.yield (); + Eio.Stream.take stream) + in + out ^ Int.to_string (Eio.Stream.length stream);; ++a1 +- : unit = () +``` + +It allows for safe Stream.take races (g is first): + +```ocaml +# run @@ fun () -> + let stream = Eio.Stream.create 1 in + let out = + Fiber.first ~combine:(fun x y -> x ^ y) + (fun () -> + Eio.Stream.add stream "b"; + Fiber.yield (); + "a" + ) + (fun () -> Eio.Stream.take stream) + in + out ^ Int.to_string (Eio.Stream.length stream);; ++b0 +- : unit = () +``` + One crashes - report it: ```ocaml @@ -201,6 +268,54 @@ Exception: Stdlib.Exit. - : unit = () ``` +`Fiber.any` with combine collects all results: + +```ocaml +# run @@ fun () -> + Fiber.any + ~combine:(fun x y -> x @ y) + (List.init 3 (fun x () -> traceln "%d" x; [x])) + |> Fmt.(str "%a" (Dump.list int));; ++0 ++1 ++2 ++[0; 1; 2] +- : unit = () +``` + +# Fiber.n_any + +`Fiber.n_any` behaves just like `Fiber.any` when there's only one result: + +```ocaml +# run @@ fun () -> + Fiber.n_any (List.init 3 (fun x () -> traceln "%d" x; Fiber.yield (); x)) + |> Fmt.(str "%a" (Dump.list int));; ++0 ++1 ++2 ++[0] +- : unit = () +``` + +`Fiber.n_any` collects all results: + +```ocaml +# run @@ fun () -> + (Fiber.n_any (List.init 4 (fun x () -> + traceln "%d" x; + if x = 1 then Fiber.yield (); + x + ))) + |> Fmt.(str "%a" (Dump.list int));; ++0 ++1 ++2 ++3 ++[0; 2; 3] +- : unit = () +``` + # Fiber.await_cancel ```ocaml