From f27467ac96a606a0415a03b74232af029753381d Mon Sep 17 00:00:00 2001
From: Guillaume Endignoux
Date: Sun, 1 Dec 2024 18:05:02 +0100
Subject: [PATCH] Add find_first() adaptor.

---
 src/core/range.rs       | 328 +++++++++++++++++++++++++++++++++++++---
 src/core/thread_pool.rs | 120 ++++++++++++++-
 src/iter/mod.rs         | 169 +++++++++++++++++++++
 src/iter/source/mod.rs  |  17 +++
 src/lib.rs              |  59 ++++++++
 5 files changed, 672 insertions(+), 21 deletions(-)

diff --git a/src/core/range.rs b/src/core/range.rs
index f58010b..351408b 100644
--- a/src/core/range.rs
+++ b/src/core/range.rs
@@ -48,14 +48,28 @@ pub trait RangeOrchestrator {
 /// A range of items similar to [`std::ops::Range`], but that can steal from or
 /// be stolen by other threads.
 pub trait Range {
+    /// Type of iterator returned by [`iter()`](Self::iter).
     type Iter<'a>: Iterator<Item = usize>
     where
         Self: 'a;
+
+    /// Type of iterator returned by
+    /// [`upper_bounded_iter()`](Self::upper_bounded_iter).
+    type UpperBoundedIter<'a, 'bound>: Iterator<Item = usize>
+    where
+        Self: 'a;
 
     /// Returns an iterator over the items in this range. The item can be
     /// dynamically stolen from/by other threads, but the iterator provides
     /// a safe abstraction over that.
     fn iter(&self) -> Self::Iter<'_>;
+
+    /// Returns an iterator over the items in this range. Items larger than the
+    /// (dynamic) bound are skipped.
+    fn upper_bounded_iter<'a, 'bound>(
+        &'a self,
+        bound: &'bound AtomicUsize,
+    ) -> Self::UpperBoundedIter<'a, 'bound>;
 }
 
 /// A factory that hands out a fixed range to each thread, without any stealing.
@@ -114,10 +128,8 @@ pub struct FixedRange {
     num_elements: Arc<AtomicUsize>,
 }
 
-impl Range for FixedRange {
-    type Iter<'a> = std::ops::Range<usize>;
-
-    fn iter(&self) -> Self::Iter<'_> {
+impl FixedRange {
+    fn range(&self) -> std::ops::Range<usize> {
         let num_elements = self.num_elements.load(Ordering::Relaxed);
         let start = (self.id * num_elements) / self.num_threads;
         let end = ((self.id + 1) * num_elements) / self.num_threads;
@@ -125,6 +137,49 @@ impl Range for FixedRange {
     }
 }
 
+impl Range for FixedRange {
+    type Iter<'a> = std::ops::Range<usize>;
+    type UpperBoundedIter<'a, 'bound> = UpperBoundedRange<'bound>;
+
+    fn iter(&self) -> Self::Iter<'_> {
+        self.range()
+    }
+
+    fn upper_bounded_iter<'a, 'bound>(
+        &'a self,
+        bound: &'bound AtomicUsize,
+    ) -> Self::UpperBoundedIter<'a, 'bound> {
+        UpperBoundedRange {
+            range: self.range(),
+            bound,
+        }
+    }
+}
+
+/// An upper-bounded iterator for a [`FixedRange`].
+pub struct UpperBoundedRange<'bound> {
+    /// Underlying contiguous range.
+    range: std::ops::Range<usize>,
+    /// Dynamic upper bound.
+    bound: &'bound AtomicUsize,
+}
+
+impl Iterator for UpperBoundedRange<'_> {
+    type Item = usize;
+
+    fn next(&mut self) -> Option<usize> {
+        self.range.next().and_then(|x| {
+            if x <= self.bound.load(Ordering::Relaxed) {
+                Some(x)
+            } else {
+                // The upper bound can only decrease, so once it's reached the iterator is
+                // exhausted.
+                None
+            }
+        })
+    }
+}
+
 /// A factory for ranges that implement work stealing among threads.
/// /// Whenever a thread finishes processing its range, it looks for another range @@ -184,7 +239,7 @@ pub struct WorkStealingRangeOrchestrator { impl RangeOrchestrator for WorkStealingRangeOrchestrator { fn reset_ranges(&self, num_elements: usize) { - log_debug!("Resetting ranges."); + log_debug!("Resetting ranges"); let num_threads = self.ranges.len() as u64; let num_elements = u32::try_from(num_elements).unwrap_or_else(|_| { panic!( @@ -232,6 +287,7 @@ pub struct WorkStealingRange { impl Range for WorkStealingRange { type Iter<'a> = WorkStealingRangeIterator<'a>; + type UpperBoundedIter<'a, 'bound> = UpperBoundedWorkStealingRangeIterator<'a, 'bound>; fn iter(&self) -> Self::Iter<'_> { WorkStealingRangeIterator { @@ -243,6 +299,21 @@ impl Range for WorkStealingRange { global_stats: self.stats.clone(), } } + + fn upper_bounded_iter<'a, 'bound>( + &'a self, + bound: &'bound AtomicUsize, + ) -> Self::UpperBoundedIter<'a, 'bound> { + UpperBoundedWorkStealingRangeIterator { + id: self.id, + ranges: &self.ranges, + bound, + #[cfg(feature = "log_parallelism")] + stats: WorkStealingStats::default(), + #[cfg(feature = "log_parallelism")] + global_stats: self.stats.clone(), + } + } } /// A [start, end) pair that can atomically be modified. @@ -318,6 +389,20 @@ impl PackedRange { self.end() - self.start() } + /// Checks if the range is empty. + #[inline(always)] + fn is_empty(self) -> bool { + self.start() == self.end() + } + + /// Upper bound this range by the given maximum. + #[inline(always)] + fn upper_bound(self, bound: usize) -> Self { + let start = (self.start() as usize).min(bound) as u32; + let end = (self.end() as usize).min(bound) as u32; + Self::new(start, end) + } + /// Increments the start of the range. #[inline(always)] fn increment_start(self) -> (u32, Self) { @@ -333,16 +418,7 @@ impl PackedRange { let end = self.end(); // The result fits in u32 because the inputs fit in u32. let middle = ((start as u64 + end as u64) / 2) as u32; - ( - PackedRange::new(start, middle), - PackedRange::new(middle, end), - ) - } - - /// Checks if the range is empty. 
- #[inline(always)] - fn is_empty(self) -> bool { - self.start() == self.end() + (Self::new(start, middle), Self::new(middle, end)) } } @@ -392,6 +468,13 @@ pub struct WorkStealingRangeIterator<'a> { global_stats: Arc>, } +#[cfg(feature = "log_parallelism")] +impl Drop for WorkStealingRangeIterator<'_> { + fn drop(&mut self) { + *self.global_stats.lock().unwrap() += &self.stats; + } +} + impl Iterator for WorkStealingRangeIterator<'_> { type Item = usize; @@ -410,7 +493,7 @@ impl Iterator for WorkStealingRangeIterator<'_> { { self.stats.increments += 1; log_trace!( - "[thread {}] Incremented range to {}..{}.", + "[thread {}] Incremented range to {}..{}", self.id, my_new_range.start(), my_new_range.end() @@ -425,7 +508,7 @@ impl Iterator for WorkStealingRangeIterator<'_> { { self.stats.failed_increments += 1; log_debug!( - "[thread {}] Failed to increment range, new range is {}..{}.", + "[thread {}] Failed to increment range, new range is {}..{}", self.id, range.start(), range.end() @@ -455,7 +538,7 @@ impl WorkStealingRangeIterator<'_> { #[cfg(feature = "log_parallelism")] log_debug!( - "[thread {}] Range {}..{} is empty, scanning other threads.", + "[thread {}] Range {}..{} is empty, scanning other threads", self.id, my_range.start(), my_range.end() @@ -506,6 +589,14 @@ impl WorkStealingRangeIterator<'_> { #[cfg(feature = "log_parallelism")] { self.stats.thefts += 1; + log_trace!( + "[thread {}] Stole range {}:{}..{} from thread {}", + self.id, + taken, + my_new_range.start(), + my_new_range.end(), + max_index + ); } return Some(taken as usize); } @@ -534,10 +625,207 @@ impl WorkStealingRangeIterator<'_> { // Didn't manage to steal anything: exit the iterator. #[cfg(feature = "log_parallelism")] + log_debug!("[thread {}] Didn't find anything to steal", self.id); + None + } +} + +/// A upper-bounded iterator for a [`WorkStealingRange`]. +pub struct UpperBoundedWorkStealingRangeIterator<'a, 'bound> { + /// Index of the thread that owns this range. + id: usize, + /// Handle to the ranges of all the threads. + ranges: &'a [AtomicRange], + /// Dynamic upper bound. + bound: &'bound AtomicUsize, + /// Local work-stealing statistics. + #[cfg(feature = "log_parallelism")] + #[cfg_attr(docsrs, doc(cfg(feature = "log_parallelism")))] + stats: WorkStealingStats, + /// Handle to the global work-stealing statistics. + #[cfg(feature = "log_parallelism")] + #[cfg_attr(docsrs, doc(cfg(feature = "log_parallelism")))] + global_stats: Arc>, +} + +#[cfg(feature = "log_parallelism")] +impl Drop for UpperBoundedWorkStealingRangeIterator<'_, '_> { + fn drop(&mut self) { + *self.global_stats.lock().unwrap() += &self.stats; + } +} + +impl Iterator for UpperBoundedWorkStealingRangeIterator<'_, '_> { + type Item = usize; + + fn next(&mut self) -> Option { + let bound = self.bound.load(Ordering::Relaxed); + let my_atomic_range: &AtomicRange = &self.ranges[self.id]; + let mut my_loaded_range: PackedRange = my_atomic_range.load(); + let mut my_bounded_range = my_loaded_range.upper_bound(bound); + + #[cfg(feature = "log_parallelism")] + { + log_trace!("[thread {}] Loaded upper bound = {}", self.id, bound); + } + + // First phase: try to increment this thread's own range. Retries are needed in + // case another thread stole part of the range. + while !my_bounded_range.is_empty() { + let (taken, my_new_range) = my_bounded_range.increment_start(); + match my_atomic_range.compare_exchange(my_loaded_range, my_new_range) { + // Increment succeeded. 
+ Ok(()) => { + #[cfg(feature = "log_parallelism")] + { + self.stats.increments += 1; + log_trace!( + "[thread {}] Incremented range to {}..{}", + self.id, + my_new_range.start(), + my_new_range.end() + ); + } + return Some(taken as usize); + } + // Increment failed: retry with an updated range. + Err(range) => { + my_loaded_range = range; + my_bounded_range = my_loaded_range.upper_bound(bound); + #[cfg(feature = "log_parallelism")] + { + self.stats.failed_increments += 1; + log_debug!( + "[thread {}] Failed to increment range, new range is {}..{}", + self.id, + range.start(), + range.end() + ); + } + continue; + } + } + } + + // Second phase: the range is empty, try to steal a range from another thread. + self.steal( + bound, + #[cfg(feature = "log_parallelism")] + my_bounded_range, + ) + } +} + +#[derive(Clone, Copy, Default)] +struct OtherRange { + loaded: PackedRange, + bounded: PackedRange, +} + +impl UpperBoundedWorkStealingRangeIterator<'_, '_> { + /// Helper function for the iterator implementation, to steal a range from + /// another thread when this thread's range is empty. + fn steal( + &mut self, + bound: usize, + #[cfg(feature = "log_parallelism")] my_bounded_range: PackedRange, + ) -> Option { + let my_atomic_range: &AtomicRange = &self.ranges[self.id]; + + #[cfg(feature = "log_parallelism")] + log_debug!( + "[thread {}] Range {}..{} is empty, scanning other threads", + self.id, + my_bounded_range.start(), + my_bounded_range.end() + ); + let range_count = self.ranges.len(); + + // Read a snapshot of the other threads' ranges, to identify the best one to + // steal (the largest one). This is only used as a hint, and therefore it's fine + // that the underlying values may be concurrently modified by the other threads + // and that the snapshot becomes (slightly) out-of-date. + let mut other_ranges = vec![OtherRange::default(); range_count]; + for (i, range) in other_ranges.iter_mut().enumerate() { + if i == self.id { + continue; + } + let loaded = self.ranges[i].load(); + let bounded = loaded.upper_bound(bound); + *range = OtherRange { loaded, bounded }; + } + #[cfg(feature = "log_parallelism")] { - log_debug!("[thread {}] Didn't find anything to steal", self.id); - *self.global_stats.lock().unwrap() += &self.stats; + self.stats.other_loads += range_count as u64 - 1; } + + // Identify the thread with the largest range. + let mut max_index = 0; + let mut max_range = OtherRange::default(); + for (i, range) in other_ranges.iter().enumerate() { + if i == self.id { + continue; + } + if range.bounded.len() > max_range.bounded.len() { + max_index = i; + max_range = *range; + } + } + + // Try to steal another thread's range. Retries are needed in case the target + // thread incremented its range or if another thread stole part of the + // target thread's range. + while !max_range.bounded.is_empty() { + // Try to steal half of the range. + let (remaining, stolen) = max_range.bounded.split(); + match self.ranges[max_index].compare_exchange(max_range.loaded, remaining) { + // Theft succeeded. + Ok(()) => { + // Take the first item, and place the rest in this thread's own range. + let (taken, my_new_range) = stolen.increment_start(); + my_atomic_range.store(my_new_range); + #[cfg(feature = "log_parallelism")] + { + self.stats.thefts += 1; + log_trace!( + "[thread {}] Stole range {}:{}..{} from thread {}", + self.id, + taken, + my_new_range.start(), + my_new_range.end(), + max_index + ); + } + return Some(taken as usize); + } + // Theft failed: update the range and retry. 
+ Err(loaded) => { + let bounded = loaded.upper_bound(bound); + let range = OtherRange { loaded, bounded }; + other_ranges[max_index] = range; + #[cfg(feature = "log_parallelism")] + { + self.stats.failed_thefts += 1; + } + + // Re-compute the largest range. + max_range = range; + for (i, range) in other_ranges.iter().enumerate() { + if i == self.id { + continue; + } + if range.bounded.len() > max_range.bounded.len() { + max_index = i; + max_range = *range; + } + } + } + } + } + + // Didn't manage to steal anything: exit the iterator. + #[cfg(feature = "log_parallelism")] + log_debug!("[thread {}] Didn't find anything to steal", self.id); None } } diff --git a/src/core/thread_pool.rs b/src/core/thread_pool.rs index ed17c62..c70cd12 100644 --- a/src/core/thread_pool.rs +++ b/src/core/thread_pool.rs @@ -33,7 +33,7 @@ use std::convert::TryFrom; use std::marker::PhantomData; use std::num::NonZeroUsize; use std::ops::ControlFlow; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; use std::thread::JoinHandle; @@ -162,6 +162,24 @@ impl ThreadPool { .short_circuiting_pipeline(input_len, init, process_item, finalize, reduce) } + /// Processes an input of the given length in parallel and returns the + /// aggregated output. + /// + /// With this variant, the pipeline may skip processing items at larger + /// indices whenever a call to `process_item` returns + /// [`ControlFlow::Break`]. + pub(crate) fn upper_bounded_pipeline( + &mut self, + input_len: usize, + init: impl Fn() -> Accum + Sync, + process_item: impl Fn(Accum, usize) -> ControlFlow + Sync, + finalize: impl Fn(Accum) -> Output + Sync, + reduce: impl Fn(Output, Output) -> Output, + ) -> Output { + self.inner + .upper_bounded_pipeline(input_len, init, process_item, finalize, reduce) + } + /// Processes an input of the given length in parallel and returns the /// aggregated output. pub(crate) fn iter_pipeline( @@ -257,6 +275,30 @@ impl ThreadPoolEnum { } } + /// Processes an input of the given length in parallel and returns the + /// aggregated output. + /// + /// With this variant, the pipeline may skip processing items at larger + /// indices whenever a call to `process_item` returns + /// [`ControlFlow::Break`]. + fn upper_bounded_pipeline( + &mut self, + input_len: usize, + init: impl Fn() -> Accum + Sync, + process_item: impl Fn(Accum, usize) -> ControlFlow + Sync, + finalize: impl Fn(Accum) -> Output + Sync, + reduce: impl Fn(Output, Output) -> Output, + ) -> Output { + match self { + ThreadPoolEnum::Fixed(inner) => { + inner.upper_bounded_pipeline(input_len, init, process_item, finalize, reduce) + } + ThreadPoolEnum::WorkStealing(inner) => { + inner.upper_bounded_pipeline(input_len, init, process_item, finalize, reduce) + } + } + } + /// Processes an input of the given length in parallel and returns the /// aggregated output. fn iter_pipeline( @@ -442,6 +484,43 @@ impl ThreadPoolImpl { .unwrap() } + /// Processes an input of the given length in parallel and returns the + /// aggregated output. + /// + /// With this variant, the pipeline may skip processing items at larger + /// indices whenever a call to `process_item` returns + /// [`ControlFlow::Break`]. 
+ fn upper_bounded_pipeline( + &mut self, + input_len: usize, + init: impl Fn() -> Accum + Sync, + process_item: impl Fn(Accum, usize) -> ControlFlow + Sync, + finalize: impl Fn(Accum) -> Output + Sync, + reduce: impl Fn(Output, Output) -> Output, + ) -> Output { + self.range_orchestrator.reset_ranges(input_len); + + let num_threads = self.threads.len(); + let outputs = (0..num_threads) + .map(|_| Mutex::new(None)) + .collect::>(); + let bound = AtomicUsize::new(usize::MAX); + + self.pipeline.lend(&UpperBoundedPipelineImpl { + bound: CachePadded::new(bound), + outputs: outputs.clone(), + init, + process_item, + finalize, + }); + + outputs + .iter() + .map(move |output| output.lock().unwrap().take().unwrap()) + .reduce(reduce) + .unwrap() + } + /// Processes an input of the given length in parallel and returns the /// aggregated output. fn iter_pipeline( @@ -588,6 +667,45 @@ where } } +struct UpperBoundedPipelineImpl< + Output, + Accum, + Init: Fn() -> Accum, + ProcessItem: Fn(Accum, usize) -> ControlFlow, + Finalize: Fn(Accum) -> Output, +> { + bound: CachePadded, + outputs: Arc<[Mutex>]>, + init: Init, + process_item: ProcessItem, + finalize: Finalize, +} + +impl Pipeline + for UpperBoundedPipelineImpl +where + R: Range, + Init: Fn() -> Accum, + ProcessItem: Fn(Accum, usize) -> ControlFlow, + Finalize: Fn(Accum) -> Output, +{ + fn run(&self, worker_id: usize, range: &R) { + let mut accumulator = (self.init)(); + for i in range.upper_bounded_iter(&self.bound) { + let acc = (self.process_item)(accumulator, i); + accumulator = match acc { + ControlFlow::Continue(acc) => acc, + ControlFlow::Break(acc) => { + self.bound.fetch_min(i, Ordering::Relaxed); + acc + } + }; + } + let output = (self.finalize)(accumulator); + *self.outputs[worker_id].lock().unwrap() = Some(output); + } +} + struct IterPipelineImpl> { outputs: Arc<[Mutex>]>, accum: Accum, diff --git a/src/iter/mod.rs b/src/iter/mod.rs index 0ca098e..a8aa49c 100644 --- a/src/iter/mod.rs +++ b/src/iter/mod.rs @@ -119,6 +119,69 @@ pub trait ParallelIterator: Sized { reduce: impl Fn(Output, Output) -> Output, ) -> Output; + /// Runs the pipeline defined by the given functions on this iterator. + /// + /// # Parameters + /// + /// - `init` function to create a new (per-thread) accumulator, + /// - `process_item` function to accumulate an item into the accumulator, + /// - `finalize` function to transform an accumulator into an output, + /// - `reduce` function to reduce a pair of outputs into one output. + /// + /// Contrary to [`pipeline()`](Self::pipeline), the `process_item` function + /// can return [`ControlFlow::Break`] to indicate that the pipeline should + /// skip processing items at larger indices. + /// + /// ``` + /// # use paralight::iter::{IntoParallelRefSource, ParallelIterator, ParallelSourceExt}; + /// # use paralight::{CpuPinningPolicy, RangeStrategy, ThreadCount, ThreadPoolBuilder}; + /// # use std::ops::ControlFlow; + /// # let mut thread_pool = ThreadPoolBuilder { + /// # num_threads: ThreadCount::AvailableParallelism, + /// # range_strategy: RangeStrategy::WorkStealing, + /// # cpu_pinning: CpuPinningPolicy::No, + /// # } + /// # .build(); + /// let input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + /// let first_even = input + /// .par_iter() + /// .with_thread_pool(&mut thread_pool) + /// .upper_bounded_pipeline( + /// || None, + /// |acc, i, x| { + /// match acc { + /// // Early return if we found something at a previous index. 
+ /// Some((j, _)) if j < i => ControlFlow::Continue(acc), + /// _ => match x % 2 == 0 { + /// true => ControlFlow::Break(Some((i, x))), + /// false => ControlFlow::Continue(acc), + /// }, + /// } + /// }, + /// |acc| acc, + /// |x, y| match (x, y) { + /// (None, None) => None, + /// (Some(found), None) | (None, Some(found)) => Some(found), + /// (Some((i, a)), Some((j, b))) => { + /// if i < j { + /// Some((i, a)) + /// } else { + /// Some((j, b)) + /// } + /// } + /// }, + /// ) + /// .map(|(_, x)| x); + /// assert_eq!(first_even, Some(&2)); + /// ``` + fn upper_bounded_pipeline( + self, + init: impl Fn() -> Accum + Sync, + process_item: impl Fn(Accum, usize, Self::Item) -> ControlFlow + Sync, + finalize: impl Fn(Accum) -> Output + Sync, + reduce: impl Fn(Output, Output) -> Output, + ) -> Output; + /// Runs the pipeline defined by the given functions on this iterator. /// /// # Parameters @@ -262,6 +325,25 @@ impl ParallelIterator for T { ) } + fn upper_bounded_pipeline( + self, + init: impl Fn() -> Accum + Sync, + process_item: impl Fn(Accum, usize, Self::Item) -> ControlFlow + Sync, + finalize: impl Fn(Accum) -> Output + Sync, + reduce: impl Fn(Output, Output) -> Output, + ) -> Output { + let descriptor = self.descriptor(); + descriptor.inner.upper_bounded_pipeline( + init, + |accum, index, item| match (descriptor.transform_item)(item) { + Some(item) => process_item(accum, index, item), + None => ControlFlow::Continue(accum), + }, + finalize, + reduce, + ) + } + fn iter_pipeline( self, accum: impl Accumulator + Sync, @@ -823,6 +905,72 @@ pub trait ParallelIteratorExt: ParallelIterator { ) } + /// Returns the first item that satisfies the predicate `f`, or [`None`] if + /// no item satisfies it. + /// + /// ``` + /// # use paralight::iter::{IntoParallelRefSource, ParallelIteratorExt, ParallelSourceExt}; + /// # use paralight::{CpuPinningPolicy, RangeStrategy, ThreadCount, ThreadPoolBuilder}; + /// # let mut thread_pool = ThreadPoolBuilder { + /// # num_threads: ThreadCount::try_from(2).unwrap(), + /// # range_strategy: RangeStrategy::WorkStealing, + /// # cpu_pinning: CpuPinningPolicy::No, + /// # } + /// # .build(); + /// let input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + /// + /// let four = input + /// .par_iter() + /// .with_thread_pool(&mut thread_pool) + /// .find_first(|&&x| x == 4); + /// assert_eq!(four, Some(&4)); + /// + /// let twenty = input + /// .par_iter() + /// .with_thread_pool(&mut thread_pool) + /// .find_first(|&&x| x == 20); + /// assert_eq!(twenty, None); + /// + /// let first_even = input + /// .par_iter() + /// .with_thread_pool(&mut thread_pool) + /// .copied() + /// .find_first(|&x| x % 2 == 0); + /// assert_eq!(first_even, Some(2)); + /// ``` + fn find_first(self, f: F) -> Option + where + F: Fn(&Self::Item) -> bool + Sync, + Self::Item: Send, + { + self.upper_bounded_pipeline( + || None, + |acc, i, item| { + match acc { + // Early return if we found something at a previous index. + Some((j, _)) if j < i => ControlFlow::Continue(acc), + _ => match f(&item) { + true => ControlFlow::Break(Some((i, item))), + false => ControlFlow::Continue(acc), + }, + } + }, + |acc| acc, + |x, y| match (x, y) { + (None, None) => None, + (Some(found), None) | (None, Some(found)) => Some(found), + (Some((i, a)), Some((j, b))) => { + if i < j { + Some((i, a)) + } else { + Some((j, b)) + } + } + }, + ) + .map(|(_, item)| item) + } + /// Runs `f` on each item of this parallel iterator. 
/// /// See also [`for_each_init()`](Self::for_each_init) if you need to @@ -2211,6 +2359,27 @@ where ) } + fn upper_bounded_pipeline( + self, + init: impl Fn() -> Accum + Sync, + process_item: impl Fn(Accum, usize, Self::Item) -> ControlFlow + Sync, + finalize: impl Fn(Accum) -> Output + Sync, + reduce: impl Fn(Output, Output) -> Output, + ) -> Output { + self.inner.upper_bounded_pipeline( + || ((self.init)(), init()), + |(mut i, accum), index, item| { + let accum = process_item(accum, index, (self.f)(&mut i, item)); + match accum { + ControlFlow::Continue(accum) => ControlFlow::Continue((i, accum)), + ControlFlow::Break(accum) => ControlFlow::Break((i, accum)), + } + }, + |(_, accum)| finalize(accum), + reduce, + ) + } + fn iter_pipeline( self, accum: impl Accumulator + Sync, diff --git a/src/iter/source/mod.rs b/src/iter/source/mod.rs index 548bfd4..0ceb1ce 100644 --- a/src/iter/source/mod.rs +++ b/src/iter/source/mod.rs @@ -822,6 +822,23 @@ impl ParallelIterator for BaseParallelIterator<'_, S> { ) } + fn upper_bounded_pipeline( + self, + init: impl Fn() -> Accum + Sync, + process_item: impl Fn(Accum, usize, Self::Item) -> ControlFlow + Sync, + finalize: impl Fn(Accum) -> Output + Sync, + reduce: impl Fn(Output, Output) -> Output, + ) -> Output { + let source_descriptor = self.source.descriptor(); + self.thread_pool.upper_bounded_pipeline( + source_descriptor.len, + init, + |acc, index| process_item(acc, index, (source_descriptor.fetch_item)(index)), + finalize, + reduce, + ) + } + fn iter_pipeline( self, accum: impl Accumulator + Sync, diff --git a/src/lib.rs b/src/lib.rs index dd93afd..1e2bccf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -161,6 +161,7 @@ mod test { test_adaptor_filter, test_adaptor_filter_map, test_adaptor_find_any, + test_adaptor_find_first, test_adaptor_for_each, test_adaptor_for_each_init, test_adaptor_inspect, @@ -1632,6 +1633,64 @@ mod test { assert_eq!(empty, None); } + fn test_adaptor_find_first(range_strategy: RangeStrategy) { + let mut thread_pool = ThreadPoolBuilder { + num_threads: ThreadCount::AvailableParallelism, + range_strategy, + cpu_pinning: CpuPinningPolicy::No, + } + .build(); + + let input = (0..=INPUT_LEN).collect::>(); + let first = input + .par_iter() + .with_thread_pool(&mut thread_pool) + .copied() + .find_first(|_| true); + assert_eq!(first, Some(0)); + + let last = input + .par_iter() + .with_thread_pool(&mut thread_pool) + .copied() + .find_first(|&x| x >= INPUT_LEN); + assert_eq!(last, Some(INPUT_LEN)); + + let end = input + .par_iter() + .with_thread_pool(&mut thread_pool) + .copied() + .find_first(|&x| x > INPUT_LEN); + assert_eq!(end, None); + + let forty_two = input + .par_iter() + .with_thread_pool(&mut thread_pool) + .copied() + .find_first(|&x| x >= 42); + assert_eq!(forty_two, if INPUT_LEN >= 42 { Some(42) } else { None }); + + let even = input + .par_iter() + .with_thread_pool(&mut thread_pool) + .copied() + .find_first(|&x| x % 2 == 0); + assert_eq!(even, Some(0)); + + let odd = input + .par_iter() + .with_thread_pool(&mut thread_pool) + .copied() + .find_first(|&x| x % 2 == 1); + assert_eq!(odd, Some(1)); + + let empty = [] + .par_iter() + .with_thread_pool(&mut thread_pool) + .find_first(|_: &&u64| true); + assert_eq!(empty, None); + } + fn test_adaptor_for_each(range_strategy: RangeStrategy) { let mut thread_pool = ThreadPoolBuilder { num_threads: ThreadCount::AvailableParallelism,