diff --git a/lib_eio/core/eio__core.mli b/lib_eio/core/eio__core.mli index 830680a86..3bbafa555 100644 --- a/lib_eio/core/eio__core.mli +++ b/lib_eio/core/eio__core.mli @@ -199,7 +199,7 @@ module Fiber : sig (** [all fs] is like [both], but for any number of fibers. [all []] returns immediately. *) - val first : (unit -> 'a) -> (unit -> 'a) -> 'a + val first : ?combine:('a -> 'a -> 'a) -> (unit -> 'a) -> (unit -> 'a) -> 'a (** [first f g] runs [f ()] and [g ()] concurrently. They run in a new cancellation sub-context, and when one finishes the other is cancelled. @@ -209,15 +209,18 @@ module Fiber : sig If both fibers fail, {!Exn.combine} is used to combine the exceptions. - Warning: it is always possible that {i both} operations will succeed (and one result will be thrown away). - This is because there is a period of time after the first operation succeeds, - but before its fiber finishes, during which the other operation may also succeed. *) + If both fibers succeed simultaneously, [combine] is used to combine the results. + If [combine] is not provided, the first one to complete is returned. *) - val any : (unit -> 'a) list -> 'a + val any : ?combine:('a -> 'a -> 'a) -> (unit -> 'a) list -> 'a (** [any fs] is like [first], but for any number of fibers. [any []] just waits forever (or until cancelled). *) + val n_any : (unit -> 'a) list -> 'a list + (** [n_any fs] is like [any], the difference being that when multiple fibers are + completed simultaneously (and none are rejected), all the results are returned. *) + val await_cancel : unit -> 'a (** [await_cancel ()] waits until cancelled. @raise Cancel.Cancelled *) diff --git a/lib_eio/core/fiber.ml b/lib_eio/core/fiber.ml index 14e40f44d..ae08f51d9 100644 --- a/lib_eio/core/fiber.ml +++ b/lib_eio/core/fiber.ml @@ -89,16 +89,22 @@ let await_cancel () = Suspend.enter @@ fun fiber enqueue -> Cancel.Fiber_context.set_cancel_fn fiber (fun ex -> enqueue (Error ex)) -let any fs = - let r = ref `None in +type 'a any_status = + | New + | Ex of (exn * Printexc.raw_backtrace) + | OK of 'a list + +let rev_any fs = + let r = ref New in let parent_c = Cancel.sub_unchecked (fun cc -> let wrap h = match h () with | x -> begin match !r with - | `None -> r := `Ok x; Cancel.cancel cc Not_first - | `Ex _ | `Ok _ -> () + | New -> r := OK [x]; Cancel.cancel cc Not_first + | OK ll -> r := OK (x :: ll) + | Ex _ -> () end | exception Cancel.Cancelled _ when not (Cancel.is_on cc) -> (* If this is in response to us asking the fiber to cancel then we can just ignore it. @@ -107,11 +113,11 @@ let any fs = () | exception ex -> begin match !r with - | `None -> r := `Ex (ex, Printexc.get_raw_backtrace ()); Cancel.cancel cc ex - | `Ok _ -> r := `Ex (ex, Printexc.get_raw_backtrace ()) - | `Ex prev -> + | New -> r := Ex (ex, Printexc.get_raw_backtrace ()); Cancel.cancel cc ex + | OK _ -> r := Ex (ex, Printexc.get_raw_backtrace ()) + | Ex prev -> let bt = Printexc.get_raw_backtrace () in - r := `Ex (Exn.combine prev (ex, bt)) + r := Ex (Exn.combine prev (ex, bt)) end in let vars = Cancel.Fiber_context.get_vars () in @@ -123,7 +129,7 @@ let any fs = let p, r = Promise.create_with_id (Cancel.Fiber_context.tid new_fiber) in fork_raw new_fiber (fun () -> match wrap f with - | x -> Promise.resolve_ok r x + | () -> Promise.resolve_ok r () | exception ex -> Promise.resolve_error r ex ); p :: aux fs @@ -133,16 +139,29 @@ let any fs = ) in match !r, Cancel.get_error parent_c with - | `Ok r, None -> r - | (`Ok _ | `None), Some ex -> raise ex - | `Ex (ex, bt), None -> Printexc.raise_with_backtrace ex bt - | `Ex ex1, Some ex2 -> + | OK ll, None -> ll + | (OK _ | New), Some ex -> raise ex + | Ex (ex, bt), None -> Printexc.raise_with_backtrace ex bt + | Ex ex1, Some ex2 -> let bt2 = Printexc.get_raw_backtrace () in let ex, bt = Exn.combine ex1 (ex2, bt2) in Printexc.raise_with_backtrace ex bt - | `None, None -> assert false + | New, None -> assert false + +let n_any ll = List.rev (rev_any ll) + +let any ?(combine = (fun x _ -> x)) ll = + (* The results are backwards *) + let rec reduce_right = function + | [y; x] -> combine x y + | x :: rest -> combine (reduce_right rest) x + | [] -> assert false + in + match rev_any ll with + | [x] -> x + | ll -> reduce_right ll -let first f g = any [f; g] +let first ?combine f g = any ?combine [f; g] let check () = let ctx = Effect.perform Cancel.Get_context in diff --git a/tests/fiber.md b/tests/fiber.md index 53cdf7662..d2d52356a 100644 --- a/tests/fiber.md +++ b/tests/fiber.md @@ -38,7 +38,7 @@ Second finishes, first is cancelled: - : unit = () ``` -If both succeed, we pick the first one: +If both succeed and no ~combine, we pick the first one by default: ```ocaml # run @@ fun () -> @@ -49,6 +49,73 @@ If both succeed, we pick the first one: - : unit = () ``` +If both succeed we let ~combine decide: + +```ocaml +# run @@ fun () -> + Fiber.first ~combine:(fun _ x -> x) + (fun () -> "a") + (fun () -> "b");; ++b +- : unit = () +``` + +It allows for safe Stream.take races (both): + +```ocaml +# run @@ fun () -> + let stream = Eio.Stream.create 1 in + Fiber.first ~combine:(fun x y -> x ^ y) + (fun () -> + Fiber.yield (); + Eio.Stream.add stream "b"; + "a" + ) + (fun () -> Eio.Stream.take stream);; ++ab +- : unit = () +``` + +It allows for safe Stream.take races (f is first): + +```ocaml +# run @@ fun () -> + let stream = Eio.Stream.create 1 in + let out = + Fiber.first ~combine:(fun x y -> x ^ y) + (fun () -> + Eio.Stream.add stream "b"; + Fiber.yield (); + "a" + ) + (fun () -> + Fiber.yield (); + Eio.Stream.take stream) + in + out ^ Int.to_string (Eio.Stream.length stream);; ++a1 +- : unit = () +``` + +It allows for safe Stream.take races (g is first): + +```ocaml +# run @@ fun () -> + let stream = Eio.Stream.create 1 in + let out = + Fiber.first ~combine:(fun x y -> x ^ y) + (fun () -> + Eio.Stream.add stream "b"; + Fiber.yield (); + "a" + ) + (fun () -> Eio.Stream.take stream) + in + out ^ Int.to_string (Eio.Stream.length stream);; ++b0 +- : unit = () +``` + One crashes - report it: ```ocaml @@ -201,6 +268,55 @@ Exception: Stdlib.Exit. - : unit = () ``` +`Fiber.any` with combine collects all results: + +```ocaml +# run @@ fun () -> + Fiber.any + ~combine:(fun x y -> x @ y) + (List.init 3 (fun x () -> traceln "%d" x; [x])) + |> List.map string_of_int + |> String.concat "," + ;; ++0 ++1 ++2 ++0,1,2 +- : unit = () +``` + +# Fiber.n_any + +`Fiber.n_any` behaves just like `Fiber.any` when there's only one result: + +```ocaml +# run @@ fun () -> + Fiber.n_any (List.init 3 (fun x () -> traceln "%d" x; Fiber.yield (); x)) + |> List.map string_of_int + |> String.concat "," + ;; ++0 ++1 ++2 ++0 +- : unit = () +``` + +`Fiber.n_any` collects all results: + +```ocaml +# run @@ fun () -> + (Fiber.n_any (List.init 3 (fun x () -> traceln "%d" x; x))) + |> List.map string_of_int + |> String.concat "," + ;; ++0 ++1 ++2 ++0,1,2 +- : unit = () +``` + # Fiber.await_cancel ```ocaml