diff --git a/src/core/range.rs b/src/core/range.rs index 351408b..d857dad 100644 --- a/src/core/range.rs +++ b/src/core/range.rs @@ -49,13 +49,13 @@ pub trait RangeOrchestrator { /// be stolen by other threads. pub trait Range { /// Type of iterator returned by [`iter()`](Self::iter). - type Iter<'a>: Iterator + type Iter<'a>: SkipIterator where Self: 'a; /// Type of iterator returned by /// [`upper_bounded_iter()`](Self::upper_bounded_iter). - type UpperBoundedIter<'a, 'bound>: Iterator + type UpperBoundedIter<'a, 'bound>: SkipIterator where Self: 'a; @@ -72,6 +72,21 @@ pub trait Range { ) -> Self::UpperBoundedIter<'a, 'bound>; } +/// An iterator trait over `usize` that either returns a next index or a range +/// of skipped indices. +pub trait SkipIterator { + /// Returns the next item and/or a range of skipped indices. + /// + /// The iterator is exhausted if and only if this returns a pair of [`None`] + /// values. + fn next(&mut self) -> (Option, Option>); + + /// Returns any remaining range of indices that have been skipped. + /// + /// This iterator must not be used again once this has been called. + fn remaining_range(&self) -> Option>; +} + /// A factory that hands out a fixed range to each thread, without any stealing. pub struct FixedRangeFactory { /// Number of threads that iterate. @@ -156,6 +171,20 @@ impl Range for FixedRange { } } +impl SkipIterator for std::ops::Range { + fn next(&mut self) -> (Option, Option>) { + (Iterator::next(self), None) + } + + fn remaining_range(&self) -> Option> { + if self.is_empty() { + None + } else { + Some(self.clone()) + } + } +} + /// An upper-bounded iterator for a [`FixedRange`]. pub struct UpperBoundedRange<'bound> { /// Underlying contiguous range. @@ -164,19 +193,25 @@ pub struct UpperBoundedRange<'bound> { bound: &'bound AtomicUsize, } -impl Iterator for UpperBoundedRange<'_> { - type Item = usize; +impl SkipIterator for UpperBoundedRange<'_> { + fn next(&mut self) -> (Option, Option>) { + let start = self.range.start; + if start != self.range.end && start <= self.bound.load(Ordering::Relaxed) { + self.range.start += 1; + (Some(start), None) + } else { + // The upper bound can only decrease, so once it's reached the iterator is + // exhausted. + (None, None) + } + } - fn next(&mut self) -> Option { - self.range.next().and_then(|x| { - if x <= self.bound.load(Ordering::Relaxed) { - Some(x) - } else { - // The upper bound can only decrease, so once it's reached the iterator is - // exhausted. - None - } - }) + fn remaining_range(&self) -> Option> { + if self.range.is_empty() { + None + } else { + Some(self.range.clone()) + } } } @@ -383,6 +418,11 @@ impl PackedRange { (self.0 >> 32) as u32 } + #[inline(always)] + fn to_range(self) -> std::ops::Range { + self.start() as usize..self.end() as usize + } + /// Reads the length of the range. #[inline(always)] fn len(self) -> u32 { @@ -397,10 +437,17 @@ impl PackedRange { /// Upper bound this range by the given maximum. #[inline(always)] - fn upper_bound(self, bound: usize) -> Self { - let start = (self.start() as usize).min(bound) as u32; - let end = (self.end() as usize).min(bound) as u32; - Self::new(start, end) + fn upper_bound(self, bound: usize) -> (Self, Self) { + let start = self.start(); + let end = self.end(); + + if end as usize <= bound { + (Self::new(start, end), Self::default()) + } else if start as usize >= bound { + (Self::default(), Self::new(start, end)) + } else { + (Self::new(start, bound as u32), Self::new(bound as u32, end)) + } } /// Increments the start of the range. 
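As an illustration of the `SkipIterator` contract introduced above, here is a minimal sketch of a consumer that drains such an iterator. It is not part of the patch: `on_item` and `on_skipped` are hypothetical callbacks standing in for the pipeline's item processing and the cleanup hook added later in this change.

fn drain_skip_iterator<I: SkipIterator>(
    mut iter: I,
    mut on_item: impl FnMut(usize),
    mut on_skipped: impl FnMut(std::ops::Range<usize>),
) {
    loop {
        match iter.next() {
            // The iterator is exhausted if and only if both sides are `None`.
            (None, None) => break,
            (index, skipped) => {
                // Report any skipped indices so their items can be cleaned up.
                if let Some(range) = skipped {
                    on_skipped(range);
                }
                if let Some(i) = index {
                    on_item(i);
                }
            }
        }
    }
    // Indices still owned by the iterator when iteration stops must also be
    // reported; after this call the iterator is not used again.
    if let Some(range) = iter.remaining_range() {
        on_skipped(range);
    }
}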
@@ -416,6 +463,7 @@ impl PackedRange { fn split(self) -> (Self, Self) { let start = self.start(); let end = self.end(); + // TODO(MSRV >= 1.85.0): Use u32::midpoint(). // The result fits in u32 because the inputs fit in u32. let middle = ((start as u64 + end as u64) / 2) as u32; (Self::new(start, middle), Self::new(middle, end)) @@ -475,10 +523,22 @@ impl Drop for WorkStealingRangeIterator<'_> { } } -impl Iterator for WorkStealingRangeIterator<'_> { - type Item = usize; +impl SkipIterator for WorkStealingRangeIterator<'_> { + fn remaining_range(&self) -> Option> { + let my_atomic_range: &AtomicRange = &self.ranges[self.id]; + let mut my_range: PackedRange = my_atomic_range.load(); + + while !my_range.is_empty() { + match my_atomic_range.compare_exchange(my_range, PackedRange::default()) { + Ok(()) => return Some(my_range.to_range()), + Err(range) => my_range = range, + } + } - fn next(&mut self) -> Option { + None + } + + fn next(&mut self) -> (Option, Option>) { let my_atomic_range: &AtomicRange = &self.ranges[self.id]; let mut my_range: PackedRange = my_atomic_range.load(); @@ -493,13 +553,12 @@ impl Iterator for WorkStealingRangeIterator<'_> { { self.stats.increments += 1; log_trace!( - "[thread {}] Incremented range to {}..{}", + "[thread {}] Incremented range to {:?}", self.id, - my_new_range.start(), - my_new_range.end() + my_new_range.to_range() ); } - return Some(taken as usize); + return (Some(taken as usize), None); } // Increment failed: retry with an updated range. Err(range) => { @@ -508,10 +567,9 @@ impl Iterator for WorkStealingRangeIterator<'_> { { self.stats.failed_increments += 1; log_debug!( - "[thread {}] Failed to increment range, new range is {}..{}", + "[thread {}] Failed to increment range, new range is {:?}", self.id, - range.start(), - range.end() + range.to_range() ); } continue; @@ -533,15 +591,14 @@ impl WorkStealingRangeIterator<'_> { fn steal( &mut self, #[cfg(feature = "log_parallelism")] my_range: PackedRange, - ) -> Option { + ) -> (Option, Option>) { let my_atomic_range: &AtomicRange = &self.ranges[self.id]; #[cfg(feature = "log_parallelism")] log_debug!( - "[thread {}] Range {}..{} is empty, scanning other threads", + "[thread {}] Range {:?} is empty, scanning other threads", self.id, - my_range.start(), - my_range.end() + my_range.to_range() ); let range_count = self.ranges.len(); @@ -590,15 +647,14 @@ impl WorkStealingRangeIterator<'_> { { self.stats.thefts += 1; log_trace!( - "[thread {}] Stole range {}:{}..{} from thread {}", + "[thread {}] Stole range {}:{:?} from thread {}", self.id, taken, - my_new_range.start(), - my_new_range.end(), + my_new_range.to_range(), max_index ); } - return Some(taken as usize); + return (Some(taken as usize), None); } // Theft failed: update the range and retry. Err(range) => { @@ -626,7 +682,7 @@ impl WorkStealingRangeIterator<'_> { // Didn't manage to steal anything: exit the iterator. 
#[cfg(feature = "log_parallelism")] log_debug!("[thread {}] Didn't find anything to steal", self.id); - None + (None, None) } } @@ -655,55 +711,104 @@ impl Drop for UpperBoundedWorkStealingRangeIterator<'_, '_> { } } -impl Iterator for UpperBoundedWorkStealingRangeIterator<'_, '_> { - type Item = usize; +impl SkipIterator for UpperBoundedWorkStealingRangeIterator<'_, '_> { + fn remaining_range(&self) -> Option> { + let my_atomic_range: &AtomicRange = &self.ranges[self.id]; + let mut my_range: PackedRange = my_atomic_range.load(); + + while !my_range.is_empty() { + match my_atomic_range.compare_exchange(my_range, PackedRange::default()) { + Ok(()) => return Some(my_range.to_range()), + Err(range) => my_range = range, + } + } + + None + } - fn next(&mut self) -> Option { + fn next(&mut self) -> (Option, Option>) { let bound = self.bound.load(Ordering::Relaxed); + #[cfg(feature = "log_parallelism")] + log_trace!("[thread {}] Loaded upper bound = {}", self.id, bound); + let my_atomic_range: &AtomicRange = &self.ranges[self.id]; let mut my_loaded_range: PackedRange = my_atomic_range.load(); - let mut my_bounded_range = my_loaded_range.upper_bound(bound); - - #[cfg(feature = "log_parallelism")] - { - log_trace!("[thread {}] Loaded upper bound = {}", self.id, bound); - } + let (mut my_bounded_range, mut my_residual_range) = my_loaded_range.upper_bound(bound); // First phase: try to increment this thread's own range. Retries are needed in // case another thread stole part of the range. - while !my_bounded_range.is_empty() { - let (taken, my_new_range) = my_bounded_range.increment_start(); - match my_atomic_range.compare_exchange(my_loaded_range, my_new_range) { - // Increment succeeded. - Ok(()) => { - #[cfg(feature = "log_parallelism")] - { - self.stats.increments += 1; - log_trace!( - "[thread {}] Incremented range to {}..{}", - self.id, - my_new_range.start(), - my_new_range.end() - ); + loop { + if !my_bounded_range.is_empty() { + let (taken, my_new_range) = my_bounded_range.increment_start(); + match my_atomic_range.compare_exchange(my_loaded_range, my_new_range) { + // Increment succeeded. + Ok(()) => { + #[cfg(feature = "log_parallelism")] + { + self.stats.increments += 1; + log_trace!( + "[thread {}] Incremented range to {:?}", + self.id, + my_new_range.to_range() + ); + } + + let residual = if my_residual_range.is_empty() { + None + } else { + let residual = my_residual_range.to_range(); + #[cfg(feature = "log_parallelism")] + log_debug!( + "[thread {}] Residual range {:?} is not empty (increment), scheduling it for cleanup.", + self.id, + residual + ); + Some(residual) + }; + + return (Some(taken as usize), residual); + } + // Increment failed: retry with an updated range. + Err(range) => { + my_loaded_range = range; + (my_bounded_range, my_residual_range) = my_loaded_range.upper_bound(bound); + #[cfg(feature = "log_parallelism")] + { + self.stats.failed_increments += 1; + log_debug!( + "[thread {}] Failed to increment range, new range is {:?}", + self.id, + range.to_range() + ); + } + continue; } - return Some(taken as usize); } - // Increment failed: retry with an updated range. 
- Err(range) => { - my_loaded_range = range; - my_bounded_range = my_loaded_range.upper_bound(bound); - #[cfg(feature = "log_parallelism")] - { - self.stats.failed_increments += 1; - log_debug!( - "[thread {}] Failed to increment range, new range is {}..{}", - self.id, - range.start(), - range.end() - ); + } else if !my_loaded_range.is_empty() { + // First, let's make sure other threads don't try to steal this range, which can + // happen if they have cached another bound. + match my_atomic_range.compare_exchange(my_loaded_range, my_bounded_range) { + Ok(()) => { + if !my_residual_range.is_empty() { + let residual = my_residual_range.to_range(); + #[cfg(feature = "log_parallelism")] + log_debug!( + "[thread {}] Residual range {:?} is not empty (empty bounded range), scheduling it for cleanup.", + self.id, + residual + ); + return (None, Some(residual)); + }; + break; + } + Err(range) => { + my_loaded_range = range; + (my_bounded_range, my_residual_range) = my_loaded_range.upper_bound(bound); + continue; } - continue; } + } else { + break; } } @@ -720,6 +825,7 @@ impl Iterator for UpperBoundedWorkStealingRangeIterator<'_, '_> { struct OtherRange { loaded: PackedRange, bounded: PackedRange, + residual: PackedRange, } impl UpperBoundedWorkStealingRangeIterator<'_, '_> { @@ -729,15 +835,14 @@ impl UpperBoundedWorkStealingRangeIterator<'_, '_> { &mut self, bound: usize, #[cfg(feature = "log_parallelism")] my_bounded_range: PackedRange, - ) -> Option { + ) -> (Option, Option>) { let my_atomic_range: &AtomicRange = &self.ranges[self.id]; #[cfg(feature = "log_parallelism")] log_debug!( - "[thread {}] Range {}..{} is empty, scanning other threads", + "[thread {}] Range {:?} is empty, scanning other threads", self.id, - my_bounded_range.start(), - my_bounded_range.end() + my_bounded_range.to_range() ); let range_count = self.ranges.len(); @@ -751,8 +856,12 @@ impl UpperBoundedWorkStealingRangeIterator<'_, '_> { continue; } let loaded = self.ranges[i].load(); - let bounded = loaded.upper_bound(bound); - *range = OtherRange { loaded, bounded }; + let (bounded, residual) = loaded.upper_bound(bound); + *range = OtherRange { + loaded, + bounded, + residual, + }; } #[cfg(feature = "log_parallelism")] { @@ -781,6 +890,19 @@ impl UpperBoundedWorkStealingRangeIterator<'_, '_> { match self.ranges[max_index].compare_exchange(max_range.loaded, remaining) { // Theft succeeded. Ok(()) => { + let residual = if max_range.residual.is_empty() { + None + } else { + let residual = max_range.residual.to_range(); + #[cfg(feature = "log_parallelism")] + log_debug!( + "[thread {}] Residual range {:?} is not empty (stolen), scheduling it for cleanup.", + self.id, + residual + ); + Some(residual) + }; + // Take the first item, and place the rest in this thread's own range. let (taken, my_new_range) = stolen.increment_start(); my_atomic_range.store(my_new_range); @@ -788,20 +910,24 @@ impl UpperBoundedWorkStealingRangeIterator<'_, '_> { { self.stats.thefts += 1; log_trace!( - "[thread {}] Stole range {}:{}..{} from thread {}", + "[thread {}] Stole range {}:{:?} from thread {}", self.id, taken, - my_new_range.start(), - my_new_range.end(), + my_new_range.to_range(), max_index ); } - return Some(taken as usize); + + return (Some(taken as usize), residual); } // Theft failed: update the range and retry. 
Err(loaded) => { - let bounded = loaded.upper_bound(bound); - let range = OtherRange { loaded, bounded }; + let (bounded, residual) = loaded.upper_bound(bound); + let range = OtherRange { + loaded, + bounded, + residual, + }; other_ranges[max_index] = range; #[cfg(feature = "log_parallelism")] { @@ -826,7 +952,7 @@ impl UpperBoundedWorkStealingRangeIterator<'_, '_> { // Didn't manage to steal anything: exit the iterator. #[cfg(feature = "log_parallelism")] log_debug!("[thread {}] Didn't find anything to steal", self.id); - None + (None, None) } } @@ -834,6 +960,21 @@ impl UpperBoundedWorkStealingRangeIterator<'_, '_> { mod test { use super::*; + struct SkipIteratorWrapper(T); + + impl Iterator for SkipIteratorWrapper { + type Item = usize; + + fn next(&mut self) -> Option { + loop { + match self.0.next() { + (None, Some(_)) => continue, + (next, _) => return next, + } + } + } + } + #[test] fn test_fixed_range_factory_splits_evenly() { let factory = FixedRangeFactory::new(4); @@ -865,9 +1006,9 @@ mod test { std::thread::scope(|s| { for _ in 0..10 { orchestrator.reset_ranges(100); - let handles = ranges - .each_ref() - .map(|range| s.spawn(move || range.iter().collect::>())); + let handles = ranges.each_ref().map(|range| { + s.spawn(move || SkipIteratorWrapper(range.iter()).collect::>()) + }); let values: [Vec; 4] = handles.map(|handle| handle.join().unwrap()); // The fixed range implementation always yields the same items in order. @@ -893,9 +1034,9 @@ mod test { std::thread::scope(|s| { for _ in 0..10 { orchestrator.reset_ranges(NUM_ELEMENTS); - let handles = ranges - .each_ref() - .map(|range| s.spawn(move || range.iter().collect::>())); + let handles = ranges.each_ref().map(|range| { + s.spawn(move || SkipIteratorWrapper(range.iter()).collect::>()) + }); let values: [Vec; NUM_THREADS] = handles.map(|handle| handle.join().unwrap()); @@ -973,6 +1114,26 @@ mod test { } } + #[test] + fn test_packed_range_upper_bound() { + let range = PackedRange::new(10, 20); + for bound in 0..=10 { + let (left, right) = range.upper_bound(bound as usize); + assert!(left.is_empty()); + assert_eq!((right.start(), right.end()), (10, 20)); + } + for bound in 11..=19 { + let (left, right) = range.upper_bound(bound as usize); + assert_eq!((left.start(), left.end()), (10, bound)); + assert_eq!((right.start(), right.end()), (bound, 20)); + } + for bound in 20..=30 { + let (left, right) = range.upper_bound(bound as usize); + assert_eq!((left.start(), left.end()), (10, 20)); + assert!(right.is_empty()); + } + } + #[test] fn test_packed_range_increment_start() { let mut range = PackedRange::new(0, 10); diff --git a/src/core/thread_pool.rs b/src/core/thread_pool.rs index 3a8684f..303677d 100644 --- a/src/core/thread_pool.rs +++ b/src/core/thread_pool.rs @@ -9,11 +9,12 @@ //! A thread pool implementing parallelism at a lightweight cost. use super::range::{ - FixedRangeFactory, Range, RangeFactory, RangeOrchestrator, WorkStealingRangeFactory, + FixedRangeFactory, Range, RangeFactory, RangeOrchestrator, SkipIterator, + WorkStealingRangeFactory, }; use super::sync::{make_lending_group, Borrower, Lender, WorkerState}; use super::util::LifetimeParameterized; -use crate::iter::Accumulator; +use crate::iter::{Accumulator, SourceCleanup}; use crate::macros::{log_debug, log_error, log_warn}; use crossbeam_utils::CachePadded; // Platforms that support `libc::sched_setaffinity()`. 
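The hunks below thread a new `cleanup: impl SourceCleanup + Sync` argument through the thread pool's pipeline entry points. As a hedged sketch of what such an argument can look like, here is a toy implementation that merely counts skipped items; `SkippedCounter` is a hypothetical type used only for illustration (draining sources provide a real implementation in `owned_slice.rs` further down).

use std::sync::atomic::{AtomicUsize, Ordering};

use crate::iter::SourceCleanup;

/// Counts how many source items the pipeline skipped instead of fetching.
struct SkippedCounter(AtomicUsize);

impl SourceCleanup for SkippedCounter {
    // Nothing needs to be dropped here, but we still want the notifications.
    const NEEDS_CLEANUP: bool = true;

    fn cleanup_item_range(&self, range: std::ops::Range<usize>) {
        self.0.fetch_add(range.len(), Ordering::Relaxed);
    }
}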
@@ -152,9 +153,10 @@ impl ThreadPool { process_item: impl Fn(Accum, usize) -> ControlFlow + Sync, finalize: impl Fn(Accum) -> Output + Sync, reduce: impl Fn(Output, Output) -> Output, + cleanup: impl SourceCleanup + Sync, ) -> Output { self.inner - .upper_bounded_pipeline(input_len, init, process_item, finalize, reduce) + .upper_bounded_pipeline(input_len, init, process_item, finalize, reduce, cleanup) } /// Processes an input of the given length in parallel and returns the @@ -164,8 +166,9 @@ impl ThreadPool { input_len: usize, accum: impl Accumulator + Sync, reduce: impl Accumulator, + cleanup: impl SourceCleanup + Sync, ) -> Output { - self.inner.iter_pipeline(input_len, accum, reduce) + self.inner.iter_pipeline(input_len, accum, reduce, cleanup) } } @@ -221,14 +224,25 @@ impl ThreadPoolEnum { process_item: impl Fn(Accum, usize) -> ControlFlow + Sync, finalize: impl Fn(Accum) -> Output + Sync, reduce: impl Fn(Output, Output) -> Output, + cleanup: impl SourceCleanup + Sync, ) -> Output { match self { - ThreadPoolEnum::Fixed(inner) => { - inner.upper_bounded_pipeline(input_len, init, process_item, finalize, reduce) - } - ThreadPoolEnum::WorkStealing(inner) => { - inner.upper_bounded_pipeline(input_len, init, process_item, finalize, reduce) - } + ThreadPoolEnum::Fixed(inner) => inner.upper_bounded_pipeline( + input_len, + init, + process_item, + finalize, + reduce, + cleanup, + ), + ThreadPoolEnum::WorkStealing(inner) => inner.upper_bounded_pipeline( + input_len, + init, + process_item, + finalize, + reduce, + cleanup, + ), } } @@ -239,10 +253,13 @@ impl ThreadPoolEnum { input_len: usize, accum: impl Accumulator + Sync, reduce: impl Accumulator, + cleanup: impl SourceCleanup + Sync, ) -> Output { match self { - ThreadPoolEnum::Fixed(inner) => inner.iter_pipeline(input_len, accum, reduce), - ThreadPoolEnum::WorkStealing(inner) => inner.iter_pipeline(input_len, accum, reduce), + ThreadPoolEnum::Fixed(inner) => inner.iter_pipeline(input_len, accum, reduce, cleanup), + ThreadPoolEnum::WorkStealing(inner) => { + inner.iter_pipeline(input_len, accum, reduce, cleanup) + } } } } @@ -370,6 +387,7 @@ impl ThreadPoolImpl { process_item: impl Fn(Accum, usize) -> ControlFlow + Sync, finalize: impl Fn(Accum) -> Output + Sync, reduce: impl Fn(Output, Output) -> Output, + cleanup: impl SourceCleanup + Sync, ) -> Output { self.range_orchestrator.reset_ranges(input_len); @@ -385,6 +403,7 @@ impl ThreadPoolImpl { init, process_item, finalize, + cleanup, }); outputs @@ -401,6 +420,7 @@ impl ThreadPoolImpl { input_len: usize, accum: impl Accumulator + Sync, reduce: impl Accumulator, + cleanup: impl SourceCleanup + Sync, ) -> Output { self.range_orchestrator.reset_ranges(input_len); @@ -412,6 +432,7 @@ impl ThreadPoolImpl { self.pipeline.lend(&IterPipelineImpl { outputs: outputs.clone(), accum, + cleanup, }); reduce.accumulate( @@ -463,25 +484,32 @@ struct UpperBoundedPipelineImpl< Init: Fn() -> Accum, ProcessItem: Fn(Accum, usize) -> ControlFlow, Finalize: Fn(Accum) -> Output, + Cleanup: SourceCleanup, > { bound: CachePadded, outputs: Arc<[Mutex>]>, init: Init, process_item: ProcessItem, finalize: Finalize, + cleanup: Cleanup, } -impl Pipeline - for UpperBoundedPipelineImpl +impl Pipeline + for UpperBoundedPipelineImpl where R: Range, Init: Fn() -> Accum, ProcessItem: Fn(Accum, usize) -> ControlFlow, Finalize: Fn(Accum) -> Output, + Cleanup: SourceCleanup, { fn run(&self, worker_id: usize, range: &R) { let mut accumulator = (self.init)(); - for i in range.upper_bounded_iter(&self.bound) { + let iter = 
SkipIteratorWrapper { + iter: range.upper_bounded_iter(&self.bound), + cleanup: &self.cleanup, + }; + for i in iter { let acc = (self.process_item)(accumulator, i); accumulator = match acc { ControlFlow::Continue(acc) => acc, @@ -496,22 +524,59 @@ where } } -struct IterPipelineImpl> { +struct IterPipelineImpl, Cleanup: SourceCleanup> { outputs: Arc<[Mutex>]>, accum: Accum, + cleanup: Cleanup, } -impl Pipeline for IterPipelineImpl +impl Pipeline for IterPipelineImpl where R: Range, Accum: Accumulator, + Cleanup: SourceCleanup, { fn run(&self, worker_id: usize, range: &R) { - let output = self.accum.accumulate(range.iter()); + let iter = SkipIteratorWrapper { + iter: range.iter(), + cleanup: &self.cleanup, + }; + let output = self.accum.accumulate(iter); *self.outputs[worker_id].lock().unwrap() = Some(output); } } +struct SkipIteratorWrapper<'a, I: SkipIterator, Cleanup: SourceCleanup> { + iter: I, + cleanup: &'a Cleanup, +} + +impl Iterator for SkipIteratorWrapper<'_, I, Cleanup> { + type Item = usize; + + fn next(&mut self) -> Option { + loop { + match self.iter.next() { + (index, None) => return index, + (index, Some(skipped_range)) => { + self.cleanup.cleanup_item_range(skipped_range); + if index.is_some() { + return index; + } + } + } + } + } +} + +impl Drop for SkipIteratorWrapper<'_, I, Cleanup> { + fn drop(&mut self) { + if let Some(range) = self.iter.remaining_range() { + self.cleanup.cleanup_item_range(range); + } + } +} + /// Context object owned by a worker thread. struct ThreadContext { /// Thread index. diff --git a/src/iter/mod.rs b/src/iter/mod.rs index 5d9a203..f7128b0 100644 --- a/src/iter/mod.rs +++ b/src/iter/mod.rs @@ -11,12 +11,13 @@ mod source; use crossbeam_utils::CachePadded; +pub use source::owned_slice::{ArrayParallelSource, VecParallelSource}; pub use source::range::{RangeInclusiveParallelSource, RangeParallelSource}; pub use source::slice::{MutSliceParallelSource, SliceParallelSource}; pub use source::zip::{ZipEq, ZipMax, ZipMin, ZipableSource}; pub use source::{ IntoParallelRefMutSource, IntoParallelRefSource, IntoParallelSource, ParallelSource, - ParallelSourceExt, SourceDescriptor, + ParallelSourceExt, SourceCleanup, SourceDescriptor, }; use std::cmp::Ordering; use std::iter::{Product, Sum}; diff --git a/src/iter/source/mod.rs b/src/iter/source/mod.rs index 66c83d5..8775bac 100644 --- a/src/iter/source/mod.rs +++ b/src/iter/source/mod.rs @@ -8,6 +8,7 @@ //! Parallel sources from which parallel iterators are derived. +pub mod owned_slice; pub mod range; pub mod slice; pub mod zip; @@ -17,11 +18,55 @@ use crate::ThreadPool; use std::ops::ControlFlow; /// An object describing how to fetch items from a [`ParallelSource`]. -pub struct SourceDescriptor Item + Sync> { +pub struct SourceDescriptor< + Item: Send, + FetchItem: Fn(usize) -> Item + Sync, + Cleanup: SourceCleanup + Sync, +> { /// Number of items that the source produces. pub len: usize, /// A function to fetch the item at the given index. pub fetch_item: FetchItem, + /// An API to cleanup a range of items that won't be fetched. + pub cleanup: Cleanup, +} + +/// An interface to cleanup a range of items that aren't fetched from a source. +/// +/// There are two reasons why manual cleanup of items is sometimes needed. +/// - If a short-circuiting combinator such as +/// [`find_any()`](super::ParallelIteratorExt::find_any) is used, the pipeline +/// will skip remaining items once a match is found. 
+/// - If a function in an iterator pipeline panics, the remaining items are +/// skipped but should nevertheless be dropped as part of unwinding. +/// +/// A non-trivial cleanup is needed for parallel sources that drain items, such +/// as calling [`into_par_iter()`](IntoParallelSource::into_par_iter) on a +/// [`Vec`], and must correspond to [`drop()`]-ing items. +pub trait SourceCleanup { + /// Set to [`false`] if the cleanup function is guaranteed to be a noop. + /// + /// Typically, cleanup is a noop for sources over [references](reference). + /// For draining sources, this should follow the + /// [`std::mem::needs_drop()`] hint. + const NEEDS_CLEANUP: bool; + + /// Clean up the given range of items from the source. + /// + /// As with [`Drop`], this should not panic. + fn cleanup_item_range(&self, range: std::ops::Range); +} + +/// A [`SourceCleanup`] that does nothing. +/// +/// This is useful for [`ParallelSource`]s whose underlying destructor is a +/// noop, e.g. a parallel source over a slice that yields references. +pub struct NoopSourceCleanup; + +impl SourceCleanup for NoopSourceCleanup { + const NEEDS_CLEANUP: bool = false; + + fn cleanup_item_range(&self, _range: std::ops::Range) {} } /// A source to produce items in parallel. The [`ParallelSourceExt`] trait @@ -41,7 +86,9 @@ pub trait ParallelSource: Sized { type Item: Send; /// Returns an object that describes how to fetch items from this source. - fn descriptor(self) -> SourceDescriptor Self::Item + Sync>; + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync>; } /// Trait for converting into a [`ParallelSource`]. @@ -301,7 +348,7 @@ pub trait ParallelSourceExt: ParallelSource { fn skip(self, n: usize) -> Skip { Skip { inner: self, - len: n, + count: n, } } @@ -349,7 +396,7 @@ pub trait ParallelSourceExt: ParallelSource { fn skip_exact(self, n: usize) -> SkipExact { SkipExact { inner: self, - len: n, + count: n, } } @@ -470,7 +517,7 @@ pub trait ParallelSourceExt: ParallelSource { fn take(self, n: usize) -> Take { Take { inner: self, - len: n, + count: n, } } @@ -517,7 +564,7 @@ pub trait ParallelSourceExt: ParallelSource { fn take_exact(self, n: usize) -> TakeExact { TakeExact { inner: self, - len: n, + count: n, } } @@ -568,7 +615,10 @@ impl, Second: ParallelSource> { type Item = T; - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let descriptor1 = self.first.descriptor(); let descriptor2 = self.second.descriptor(); let len = descriptor1 @@ -586,6 +636,39 @@ impl, Second: ParallelSource> (descriptor2.fetch_item)(index - descriptor1.len) } }, + cleanup: ChainSourceCleanup { + len1: descriptor1.len, + cleanup1: descriptor1.cleanup, + cleanup2: descriptor2.cleanup, + }, + } + } +} + +struct ChainSourceCleanup { + len1: usize, + cleanup1: First, + cleanup2: Second, +} + +impl SourceCleanup for ChainSourceCleanup +where + First: SourceCleanup, + Second: SourceCleanup, +{ + const NEEDS_CLEANUP: bool = First::NEEDS_CLEANUP || Second::NEEDS_CLEANUP; + + fn cleanup_item_range(&self, range: std::ops::Range) { + if Self::NEEDS_CLEANUP { + if range.end <= self.len1 { + self.cleanup1.cleanup_item_range(range); + } else if range.start >= self.len1 { + self.cleanup2 + .cleanup_item_range(range.start - self.len1..range.end - self.len1); + } else { + self.cleanup1.cleanup_item_range(range.start..self.len1); + self.cleanup2.cleanup_item_range(0..range.end - self.len1); + } } } } @@ 
-604,11 +687,15 @@ pub struct Enumerate { impl ParallelSource for Enumerate { type Item = (usize, Inner::Item); - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let descriptor = self.inner.descriptor(); SourceDescriptor { len: descriptor.len, fetch_item: move |index| (index, (descriptor.fetch_item)(index)), + cleanup: descriptor.cleanup, } } } @@ -627,11 +714,34 @@ pub struct Rev { impl ParallelSource for Rev { type Item = Inner::Item; - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let descriptor = self.inner.descriptor(); SourceDescriptor { len: descriptor.len, fetch_item: move |index| (descriptor.fetch_item)(descriptor.len - index - 1), + cleanup: RevSourceCleanup { + inner: descriptor.cleanup, + len: descriptor.len, + }, + } + } +} + +struct RevSourceCleanup { + inner: Inner, + len: usize, +} + +impl SourceCleanup for RevSourceCleanup { + const NEEDS_CLEANUP: bool = Inner::NEEDS_CLEANUP; + + fn cleanup_item_range(&self, range: std::ops::Range) { + if Self::NEEDS_CLEANUP { + self.inner + .cleanup_item_range(self.len - range.end..self.len - range.start) } } } @@ -645,17 +755,49 @@ impl ParallelSource for Rev { #[must_use = "iterator adaptors are lazy"] pub struct Skip { inner: Inner, - len: usize, + count: usize, } impl ParallelSource for Skip { type Item = Inner::Item; - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let descriptor = self.inner.descriptor(); + let count = std::cmp::min(self.count, descriptor.len); SourceDescriptor { - len: descriptor.len - std::cmp::min(self.len, descriptor.len), - fetch_item: move |index| (descriptor.fetch_item)(self.len + index), + len: descriptor.len - count, + fetch_item: move |index| (descriptor.fetch_item)(self.count + index), + cleanup: SkipSourceCleanup { + inner: descriptor.cleanup, + count, + }, + } + } +} + +struct SkipSourceCleanup { + inner: Inner, + count: usize, +} + +impl SourceCleanup for SkipSourceCleanup { + const NEEDS_CLEANUP: bool = Inner::NEEDS_CLEANUP; + + fn cleanup_item_range(&self, range: std::ops::Range) { + if Self::NEEDS_CLEANUP { + self.inner + .cleanup_item_range(self.count + range.start..self.count + range.end) + } + } +} + +impl Drop for SkipSourceCleanup { + fn drop(&mut self) { + if Self::NEEDS_CLEANUP && self.count != 0 { + self.inner.cleanup_item_range(0..self.count) } } } @@ -670,21 +812,28 @@ impl ParallelSource for Skip { #[must_use = "iterator adaptors are lazy"] pub struct SkipExact { inner: Inner, - len: usize, + count: usize, } impl ParallelSource for SkipExact { type Item = Inner::Item; - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let descriptor = self.inner.descriptor(); assert!( - self.len <= descriptor.len, + self.count <= descriptor.len, "called skip_exact() with more items than this source produces" ); SourceDescriptor { - len: descriptor.len - self.len, - fetch_item: move |index| (descriptor.fetch_item)(self.len + index), + len: descriptor.len - self.count, + fetch_item: move |index| (descriptor.fetch_item)(self.count + index), + cleanup: SkipSourceCleanup { + inner: descriptor.cleanup, + count: self.count, + }, } } } @@ -704,7 +853,10 @@ pub struct 
StepBy { impl ParallelSource for StepBy { type Item = Inner::Item; - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let descriptor = self.inner.descriptor(); assert!(self.step != 0, "called step_by() with a step of zero"); let len = if descriptor.len == 0 { @@ -715,6 +867,46 @@ impl ParallelSource for StepBy { SourceDescriptor { len, fetch_item: move |index| (descriptor.fetch_item)(self.step * index), + cleanup: StepByCleanup { + inner: descriptor.cleanup, + step: self.step, + inner_len: descriptor.len, + }, + } + } +} + +struct StepByCleanup { + inner: Inner, + step: usize, + inner_len: usize, +} + +impl SourceCleanup for StepByCleanup { + const NEEDS_CLEANUP: bool = Inner::NEEDS_CLEANUP; + + fn cleanup_item_range(&self, range: std::ops::Range) { + if Self::NEEDS_CLEANUP { + for i in range { + self.inner + .cleanup_item_range(self.step * i..self.step * i + 1); + } + } + } +} + +impl Drop for StepByCleanup { + fn drop(&mut self) { + if Self::NEEDS_CLEANUP && self.step != 1 { + let full_blocks = self.inner_len / self.step; + for i in 0..full_blocks { + self.inner + .cleanup_item_range(self.step * i + 1..self.step * (i + 1)); + } + let last_block = self.step * full_blocks + 1; + if last_block < self.inner_len { + self.inner.cleanup_item_range(last_block..self.inner_len); + } } } } @@ -728,17 +920,50 @@ impl ParallelSource for StepBy { #[must_use = "iterator adaptors are lazy"] pub struct Take { inner: Inner, - len: usize, + count: usize, } impl ParallelSource for Take { type Item = Inner::Item; - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let descriptor = self.inner.descriptor(); + let count = std::cmp::min(self.count, descriptor.len); SourceDescriptor { - len: std::cmp::min(self.len, descriptor.len), + len: count, fetch_item: descriptor.fetch_item, + cleanup: TakeSourceCleanup { + inner: descriptor.cleanup, + count, + inner_len: descriptor.len, + }, + } + } +} + +struct TakeSourceCleanup { + inner: Inner, + count: usize, + inner_len: usize, +} + +impl SourceCleanup for TakeSourceCleanup { + const NEEDS_CLEANUP: bool = Inner::NEEDS_CLEANUP; + + fn cleanup_item_range(&self, range: std::ops::Range) { + if Self::NEEDS_CLEANUP { + self.inner.cleanup_item_range(range) + } + } +} + +impl Drop for TakeSourceCleanup { + fn drop(&mut self) { + if Self::NEEDS_CLEANUP && self.count != self.inner_len { + self.inner.cleanup_item_range(self.count..self.inner_len) } } } @@ -753,21 +978,29 @@ impl ParallelSource for Take { #[must_use = "iterator adaptors are lazy"] pub struct TakeExact { inner: Inner, - len: usize, + count: usize, } impl ParallelSource for TakeExact { type Item = Inner::Item; - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let descriptor = self.inner.descriptor(); assert!( - self.len <= descriptor.len, + self.count <= descriptor.len, "called take_exact() with more items than this source produces" ); SourceDescriptor { - len: self.len, + len: self.count, fetch_item: descriptor.fetch_item, + cleanup: TakeSourceCleanup { + inner: descriptor.cleanup, + count: self.count, + inner_len: descriptor.len, + }, } } } @@ -802,6 +1035,7 @@ impl ParallelIterator for BaseParallelIterator<'_, S> { |acc, index| process_item(acc, index, 
(source_descriptor.fetch_item)(index)), finalize, reduce, + source_descriptor.cleanup, ) } @@ -815,8 +1049,12 @@ impl ParallelIterator for BaseParallelIterator<'_, S> { inner: accum, fetch_item: source_descriptor.fetch_item, }; - self.thread_pool - .iter_pipeline(source_descriptor.len, accumulator, reduce) + self.thread_pool.iter_pipeline( + source_descriptor.len, + accumulator, + reduce, + source_descriptor.cleanup, + ) } } @@ -832,7 +1070,6 @@ where FetchItem: Fn(usize) -> Item, { fn accumulate(&self, iter: impl Iterator) -> Output { - self.inner - .accumulate(iter.map(|index| (self.fetch_item)(index))) + self.inner.accumulate(iter.map(&self.fetch_item)) } } diff --git a/src/iter/source/owned_slice.rs b/src/iter/source/owned_slice.rs new file mode 100644 index 0000000..1826a6c --- /dev/null +++ b/src/iter/source/owned_slice.rs @@ -0,0 +1,272 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use super::{IntoParallelSource, ParallelSource, SourceCleanup, SourceDescriptor}; +use std::mem::ManuallyDrop; + +/// A parallel source over an [array](array). This struct is created by the +/// [`into_par_iter()`](IntoParallelSource::into_par_iter) method on +/// [`IntoParallelSource`]. +/// +/// You most likely won't need to interact with this struct directly, as it +/// implements the [`ParallelSource`] and +/// [`ParallelSourceExt`](super::ParallelSourceExt) traits, but it +/// is nonetheless public because of the `must_use` annotation. +#[must_use = "iterator adaptors are lazy"] +pub struct ArrayParallelSource { + array: [T; N], +} + +impl IntoParallelSource for [T; N] { + type Item = T; + type Source = ArrayParallelSource; + + fn into_par_iter(self) -> Self::Source { + ArrayParallelSource { array: self } + } +} + +impl ParallelSource for ArrayParallelSource { + type Item = T; + + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { + let mut array = ManuallyDrop::new(self.array); + let mut_ptr = array.as_mut_ptr(); + let ptr = PtrWrapper(mut_ptr as *const T); + SourceDescriptor { + len: N, + fetch_item: move |index| { + assert!(index < N); + let base_ptr: *const T = ptr.get(); + // SAFETY: + // - The offset in bytes `index * size_of::()` fits in an `isize`, because + // the index is smaller than the length of the (well-formed) input array. This + // is ensured by the thread pool's `pipeline()` function (which yields indices + // in the range `0..N`), and further confirmed by the assertion. + // - The `base_ptr` is derived from an allocated object (the input array), and + // the entire range between `base_ptr` and the resulting `item_ptr` is in + // bounds of that allocated object. This is because the index is smaller than + // the length of the input array. + let item_ptr: *const T = unsafe { base_ptr.add(index) }; + // SAFETY: + // - The `item_ptr` is properly aligned, as it is constructed by calling `add()` + // on the aligned `base_ptr`. + // - The `item_ptr` points to a properly initialized value of type `T`, the + // element from the input array at position `index`. + // - The `item_ptr` is valid for reads. This is ensured by the thread pool's + // `pipeline()` function (which yields distinct indices in the range `0..N`), + // i.e. this item hasn't been read (and moved out of the array) yet. + // Additionally, there are no concurrent writes to this slot in the array. 
+ let item: T = unsafe { std::ptr::read(item_ptr) }; + item + }, + cleanup: OwnedSliceSourceCleanup { + ptr: MutPtrWrapper(mut_ptr), + }, + } + } +} + +struct OwnedSliceSourceCleanup { + ptr: MutPtrWrapper, +} + +impl SourceCleanup for OwnedSliceSourceCleanup { + const NEEDS_CLEANUP: bool = std::mem::needs_drop::(); + + fn cleanup_item_range(&self, range: std::ops::Range) { + if Self::NEEDS_CLEANUP { + let base_ptr: *mut T = self.ptr.get(); + // SAFETY: + // - The offset in bytes `range.start * size_of::()` fits in an `isize`, + // because the range is included in the length of the (well-formed) input + // array. This is ensured by the thread pool's `pipeline()` function (which + // only yields in-bound ranges for cleanup). + // - The `base_ptr` is derived from an allocated object (the input array), and + // the entire range between `base_ptr` and the resulting `start_ptr` is in + // bounds of that allocated object. This is because the range start is smaller + // than the length of the input array. + let start_ptr: *mut T = unsafe { base_ptr.add(range.start) }; + let slice: *mut [T] = + std::ptr::slice_from_raw_parts_mut(start_ptr, range.end - range.start); + // SAFETY: + // - The `slice` is properly aligned, as it is constructed by calling `add()` on + // the aligned `base_ptr`. + // - The `slice` isn't null, as it is constructed by calling `add()` on the + // non-null `base_ptr`. + // - The `slice` is valid for reads and writes. This is ensured by the thread + // pool's `pipeline()` function, which yields non-overlapping indices and + // cleanup ranges. I.e. the range of items in this slice isn't accessed by + // anything else. + // - The `slice` is valid for dropping, as it is a part of the input array that + // nothing else accesses. + // - Nothing else is accessing the `slice` while `drop_in_place` is executing. + // + // The `slice` is never of size zero, but the above properties (aligned, + // non-null, etc.) would still hold if it was. + unsafe { std::ptr::drop_in_place(slice) }; + } + } +} + +/// A parallel source over a [`Vec`]. This struct is created by the +/// [`into_par_iter()`](IntoParallelSource::into_par_iter) method on +/// [`IntoParallelSource`]. +/// +/// You most likely won't need to interact with this struct directly, as it +/// implements the [`ParallelSource`] and +/// [`ParallelSourceExt`](super::ParallelSourceExt) traits, but it +/// is nonetheless public because of the `must_use` annotation. +#[must_use = "iterator adaptors are lazy"] +pub struct VecParallelSource { + vec: Vec, +} + +impl IntoParallelSource for Vec { + type Item = T; + type Source = VecParallelSource; + + fn into_par_iter(self) -> Self::Source { + VecParallelSource { vec: self } + } +} + +impl IntoParallelSource for Box<[T]> { + type Item = T; + type Source = VecParallelSource; + + fn into_par_iter(self) -> Self::Source { + // There's no Box::<[T]>::from_raw_parts(), so we just piggy back on Vec. 
+ VecParallelSource { + vec: self.into_vec(), + } + } +} + +impl ParallelSource for VecParallelSource { + type Item = T; + + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { + let mut vec = ManuallyDrop::new(self.vec); + let mut_ptr = vec.as_mut_ptr(); + let len = vec.len(); + let capacity = vec.capacity(); + let ptr = PtrWrapper(mut_ptr as *const T); + SourceDescriptor { + len, + fetch_item: move |index| { + assert!(index < len); + let base_ptr: *const T = ptr.get(); + // SAFETY: + // - The offset in bytes `index * size_of::()` fits in an `isize`, because + // the index is smaller than the length of the (well-formed) input vector. + // This is ensured by the thread pool's `pipeline()` function (which yields + // indices in the range `0..len`), and further confirmed by the assertion. + // - The `base_ptr` is derived from an allocated object (the input vector), and + // the entire range between `base_ptr` and the resulting `item_ptr` is in + // bounds of that allocated object. This is because the index is smaller than + // the length of the input vector. + let item_ptr: *const T = unsafe { base_ptr.add(index) }; + // SAFETY: + // - The `item_ptr` is properly aligned, as it is constructed by calling `add()` + // on the aligned `base_ptr`. + // - The `item_ptr` points to a properly initialized value of type `T`, the + // element from the input vector at position `index`. + // - The `item_ptr` is valid for reads. This is ensured by the thread pool's + // `pipeline()` function (which yields distinct indices in the range + // `0..len`), i.e. this item hasn't been read (and moved out of the vector) + // yet. Additionally, there are no concurrent writes to this slot in the + // vector. + let item: T = unsafe { std::ptr::read(item_ptr) }; + item + }, + cleanup: VecSourceCleanup { + slice: OwnedSliceSourceCleanup { + ptr: MutPtrWrapper(mut_ptr), + }, + capacity, + }, + } + } +} + +struct VecSourceCleanup { + slice: OwnedSliceSourceCleanup, + capacity: usize, +} + +impl Drop for VecSourceCleanup { + fn drop(&mut self) { + let base_ptr: *mut T = self.slice.ptr.get(); + // SAFETY: + // - The `base_ptr` has been allocated with the global allocator, as it is + // derived from the source vector. + // - `T` has the same alignement as what `base_ptr` was allocated with, because + // `base_ptr` derives from a vector of `T`s. + // - `T * capacity` is the size of what `base_ptr` was allocated with, because + // that's the capacity of the source vector. + // - `length <= capacity` because the `length` is set to zero here. + // - The first `length` values are properly initialized values of type `T` + // because the `length` is set to zero. + // - The allocated size in bytes isn't larger than `isize::MAX`, because that's + // derived from the source vector. + let vec: Vec = unsafe { Vec::from_raw_parts(base_ptr, 0, self.capacity) }; + drop(vec); + } +} + +impl SourceCleanup for VecSourceCleanup { + const NEEDS_CLEANUP: bool = OwnedSliceSourceCleanup::::NEEDS_CLEANUP; + + fn cleanup_item_range(&self, range: std::ops::Range) { + self.slice.cleanup_item_range(range); + } +} + +/// A helper struct for the implementation of [`OwnedSliceSourceCleanup`], that +/// wraps a [`*mut T`](pointer). This enables sending [`&mut [T]`](slice) to +/// other threads. 
+struct MutPtrWrapper(*mut T); +impl MutPtrWrapper { + fn get(&self) -> *mut T { + self.0 + } +} + +/// SAFETY: +/// +/// A [`MutPtrWrapper`] is meant to be shared among threads as a way to send +/// items of type [`&mut [T]`](slice) to other threads (see the safety +/// comments in [`OwnedSliceSourceCleanup::cleanup_item_range`]). Therefore we +/// make it [`Sync`] if and only if [`&mut [T]`](slice) is [`Send`], which is +/// when `T` is [`Send`]. +unsafe impl Sync for MutPtrWrapper {} + +/// A helper struct for the implementation of [`ArrayParallelSource`] and +/// [`VecParallelSource`], that wraps a [`*const T`](pointer). This enables +/// sending `T` to other threads. +struct PtrWrapper(*const T); +impl PtrWrapper { + fn get(&self) -> *const T { + self.0 + } +} + +/// SAFETY: +/// +/// A [`PtrWrapper`] is meant to be shared among threads as a way to send items +/// of type `T` to other threads (see the safety comments in +/// [`ArrayParallelSource::descriptor`] and [`VecParallelSource::descriptor`]). +/// Therefore we make it [`Sync`] if and only if `T` is [`Send`]. +unsafe impl Sync for PtrWrapper {} diff --git a/src/iter/source/range.rs b/src/iter/source/range.rs index c3bac10..d1e07f2 100644 --- a/src/iter/source/range.rs +++ b/src/iter/source/range.rs @@ -6,7 +6,9 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use super::{IntoParallelSource, ParallelSource, SourceDescriptor}; +use super::{ + IntoParallelSource, NoopSourceCleanup, ParallelSource, SourceCleanup, SourceDescriptor, +}; #[cfg(feature = "nightly")] use std::iter::Step; use std::ops::{Range, RangeInclusive}; @@ -38,7 +40,10 @@ impl IntoParallelSource for Range { impl ParallelSource for RangeParallelSource { type Item = T; - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let range = self.range; let (len_hint, len) = T::steps_between(&range.start, &range.end); let len = len.unwrap_or_else(|| { @@ -54,6 +59,7 @@ impl ParallelSource for RangeParallelSource { SourceDescriptor { len, fetch_item: move |index| T::forward(range.start, index), + cleanup: NoopSourceCleanup, } } } @@ -72,7 +78,10 @@ impl IntoParallelSource for Range { impl ParallelSource for RangeParallelSource { type Item = usize; - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let range = self.range; let len = range .end @@ -81,6 +90,7 @@ impl ParallelSource for RangeParallelSource { SourceDescriptor { len, fetch_item: move |index| range.start + index, + cleanup: NoopSourceCleanup, } } } @@ -112,7 +122,10 @@ impl IntoParallelSource for RangeInclusive { impl ParallelSource for RangeInclusiveParallelSource { type Item = T; - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let (start, end) = self.range.into_inner(); let (len_hint, len) = T::steps_between(&start, &end); let len = len.unwrap_or_else(|| { @@ -134,6 +147,7 @@ impl ParallelSource for RangeInclusiveParallelSour SourceDescriptor { len, fetch_item: move |index| T::forward(start, index), + cleanup: NoopSourceCleanup, } } } @@ -152,7 +166,10 @@ impl IntoParallelSource for RangeInclusive { impl ParallelSource for RangeInclusiveParallelSource { type Item = usize; - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> 
{ + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let (start, end) = self.range.into_inner(); let len = end .checked_sub(start) @@ -166,6 +183,7 @@ impl ParallelSource for RangeInclusiveParallelSource { SourceDescriptor { len, fetch_item: move |index| start + index, + cleanup: NoopSourceCleanup, } } } diff --git a/src/iter/source/slice.rs b/src/iter/source/slice.rs index a62e871..12283ca 100644 --- a/src/iter/source/slice.rs +++ b/src/iter/source/slice.rs @@ -6,7 +6,9 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use super::{IntoParallelSource, ParallelSource, SourceDescriptor}; +use super::{ + IntoParallelSource, NoopSourceCleanup, ParallelSource, SourceCleanup, SourceDescriptor, +}; /// A parallel source over a [slice](slice). This struct is created by the /// [`par_iter()`](super::IntoParallelRefSource::par_iter) method on @@ -33,10 +35,14 @@ impl<'data, T: Sync> IntoParallelSource for &'data [T] { impl<'data, T: Sync> ParallelSource for SliceParallelSource<'data, T> { type Item = &'data T; - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { SourceDescriptor { len: self.slice.len(), fetch_item: |index| &self.slice[index], + cleanup: NoopSourceCleanup, } } } @@ -67,7 +73,10 @@ impl<'data, T: Send> IntoParallelSource for &'data mut [T] { impl<'data, T: Send> ParallelSource for MutSliceParallelSource<'data, T> { type Item = &'data mut T; - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let len = self.slice.len(); let ptr = MutPtrWrapper(self.slice.as_mut_ptr()); SourceDescriptor { @@ -112,6 +121,7 @@ impl<'data, T: Send> ParallelSource for MutSliceParallelSource<'data, T> { let item: &mut T = unsafe { &mut *item_ptr }; item }, + cleanup: NoopSourceCleanup, } } } diff --git a/src/iter/source/zip.rs b/src/iter/source/zip.rs index 6ad47f8..70fedec 100644 --- a/src/iter/source/zip.rs +++ b/src/iter/source/zip.rs @@ -6,7 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use super::{ParallelSource, SourceDescriptor}; +use super::{ParallelSource, SourceCleanup, SourceDescriptor}; /// A helper trait for zipping together multiple [`ParallelSource`]s into a /// single [`ParallelSource`] that produces items grouped from the original @@ -184,6 +184,12 @@ macro_rules! min_lens { } } +macro_rules! or_bools { + ( $tuple:expr, $zero:tt, $($i:tt),* ) => { + $tuple.0 $( || $tuple.$i )* + } +} + macro_rules! zipable_tuple { ( $($tuple:ident $i:tt),+ ) => { impl<$($tuple),+> ZipableSource for ($($tuple),+) @@ -193,7 +199,10 @@ macro_rules! zipable_tuple { where $($tuple: ParallelSource),+ { type Item = ( $($tuple::Item),+ ); - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let tuple = self.0; let descriptors = ( $(tuple.$i.descriptor()),+ ); assert_eq_lens!(descriptors, $($i),+); @@ -202,6 +211,7 @@ macro_rules! zipable_tuple { fetch_item: move |index| { ( $( (descriptors.$i.fetch_item)(index) ),+ ) }, + cleanup: ( $(descriptors.$i.cleanup),+ ), } } } @@ -210,7 +220,10 @@ macro_rules! 
zipable_tuple { where $($tuple: ParallelSource),+ { type Item = ( $(Option<$tuple::Item>),+ ); - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let tuple = self.0; let descriptors = ( $(tuple.$i.descriptor()),+ ); let mut len = 0; @@ -226,6 +239,7 @@ macro_rules! zipable_tuple { } ),+ ) }, + cleanup: ( $(descriptors.$i.cleanup),+ ), } } } @@ -234,7 +248,10 @@ macro_rules! zipable_tuple { where $($tuple: ParallelSource),+ { type Item = ( $($tuple::Item),+ ); - fn descriptor(self) -> SourceDescriptor Self::Item + Sync> { + fn descriptor( + self, + ) -> SourceDescriptor Self::Item + Sync, impl SourceCleanup + Sync> + { let tuple = self.0; let descriptors = ( $(tuple.$i.descriptor()),+ ); let len = min_lens!(descriptors, $($i),+); @@ -243,6 +260,21 @@ macro_rules! zipable_tuple { fetch_item: move |index| { ( $( (descriptors.$i.fetch_item)(index) ),+ ) }, + cleanup: ( $(descriptors.$i.cleanup),+ ), + } + } + } + + impl<$($tuple),+> SourceCleanup for ($($tuple),+) + where $($tuple: SourceCleanup),+ { + const NEEDS_CLEANUP: bool = { + let need_cleanups = ( $($tuple::NEEDS_CLEANUP),+ ); + or_bools!(need_cleanups, $($i),+) + }; + + fn cleanup_item_range(&self, range: std::ops::Range) { + if Self::NEEDS_CLEANUP { + $( self.$i.cleanup_item_range(range.clone()); )+ } } } diff --git a/src/lib.rs b/src/lib.rs index 82e5d81..a4e7bd5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -129,18 +129,29 @@ mod test { test_source_range_inclusive_u64_too_large => fail("cannot iterate over a range with more than usize::MAX items"), #[cfg(feature = "nightly")] test_source_range_inclusive_u128_too_large => fail("cannot iterate over a range with more than usize::MAX items"), + test_source_vec, + test_source_vec_boxed, + test_source_vec_find_any, + test_source_vec_find_first, + test_source_vec_panic => fail("worker thread(s) panicked!"), + test_source_vec_find_any_panic => fail("worker thread(s) panicked!"), + test_source_vec_find_first_panic => fail("worker thread(s) panicked!"), test_source_adaptor_chain, test_source_adaptor_chain_overflow => fail("called chain() with sources that together produce more than usize::MAX items"), test_source_adaptor_enumerate, test_source_adaptor_rev, test_source_adaptor_skip, + test_source_adaptor_skip_cleanup, test_source_adaptor_skip_exact, test_source_adaptor_skip_exact_too_much => fail("called skip_exact() with more items than this source produces"), test_source_adaptor_step_by, + test_source_adaptor_step_by_cleanup, test_source_adaptor_step_by_zero => fail("called step_by() with a step of zero"), test_source_adaptor_step_by_zero_empty => fail("called step_by() with a step of zero"), test_source_adaptor_take, + test_source_adaptor_take_cleanup, test_source_adaptor_take_exact, + test_source_adaptor_take_exact_cleanup, test_source_adaptor_take_exact_too_much => fail("called take_exact() with more items than this source produces"), test_source_adaptor_zip_eq, test_source_adaptor_zip_eq_unequal => fail("called zip_eq() with sources of different lengths"), @@ -1043,6 +1054,133 @@ mod test { .sum::(); } + fn test_source_vec(range_strategy: RangeStrategy) { + let mut thread_pool = ThreadPoolBuilder { + num_threads: ThreadCount::AvailableParallelism, + range_strategy, + cpu_pinning: CpuPinningPolicy::No, + } + .build(); + + let input = (0..=INPUT_LEN).collect::>(); + let sum = input + .into_par_iter() + .with_thread_pool(&mut thread_pool) + .sum::(); + assert_eq!(sum, INPUT_LEN * 
(INPUT_LEN + 1) / 2); + } + + fn test_source_vec_boxed(range_strategy: RangeStrategy) { + let mut thread_pool = ThreadPoolBuilder { + num_threads: ThreadCount::AvailableParallelism, + range_strategy, + cpu_pinning: CpuPinningPolicy::No, + } + .build(); + + let input = (0..=INPUT_LEN).map(Box::new).collect::>>(); + let sum = input + .into_par_iter() + .with_thread_pool(&mut thread_pool) + .map(|x| *x) + .sum::(); + assert_eq!(sum, INPUT_LEN * (INPUT_LEN + 1) / 2); + } + + fn test_source_vec_find_any(range_strategy: RangeStrategy) { + let mut thread_pool = ThreadPoolBuilder { + num_threads: ThreadCount::AvailableParallelism, + range_strategy, + cpu_pinning: CpuPinningPolicy::No, + } + .build(); + + let input = (0..=INPUT_LEN).map(Box::new).collect::>>(); + let needle = input + .into_par_iter() + .with_thread_pool(&mut thread_pool) + .find_any(|x| **x % 10 == 9); + assert!(needle.is_some()); + assert_eq!(*needle.unwrap() % 10, 9); + } + + fn test_source_vec_find_first(range_strategy: RangeStrategy) { + let mut thread_pool = ThreadPoolBuilder { + num_threads: ThreadCount::AvailableParallelism, + range_strategy, + cpu_pinning: CpuPinningPolicy::No, + } + .build(); + + let input = (0..=INPUT_LEN).map(Box::new).collect::>>(); + let needle = input + .into_par_iter() + .with_thread_pool(&mut thread_pool) + .find_first(|x| **x % 10 == 9); + assert_eq!(needle, Some(Box::new(9))); + } + + fn test_source_vec_panic(range_strategy: RangeStrategy) { + let mut thread_pool = ThreadPoolBuilder { + num_threads: ThreadCount::AvailableParallelism, + range_strategy, + cpu_pinning: CpuPinningPolicy::No, + } + .build(); + + let input = (0..=INPUT_LEN).map(Box::new).collect::>>(); + input + .into_par_iter() + .with_thread_pool(&mut thread_pool) + .for_each(|x| { + if *x % 2 == 1 { + panic!("arithmetic panic"); + } + }); + } + + fn test_source_vec_find_any_panic(range_strategy: RangeStrategy) { + let mut thread_pool = ThreadPoolBuilder { + num_threads: ThreadCount::AvailableParallelism, + range_strategy, + cpu_pinning: CpuPinningPolicy::No, + } + .build(); + + let input = (0..=INPUT_LEN).map(Box::new).collect::>>(); + input + .into_par_iter() + .with_thread_pool(&mut thread_pool) + .find_any(|x| { + if **x % 2 == 1 { + true + } else { + panic!("arithmetic panic"); + } + }); + } + + fn test_source_vec_find_first_panic(range_strategy: RangeStrategy) { + let mut thread_pool = ThreadPoolBuilder { + num_threads: ThreadCount::AvailableParallelism, + range_strategy, + cpu_pinning: CpuPinningPolicy::No, + } + .build(); + + let input = (0..=INPUT_LEN).map(Box::new).collect::>>(); + input + .into_par_iter() + .with_thread_pool(&mut thread_pool) + .find_first(|x| { + if **x % 2 == 1 { + true + } else { + panic!("arithmetic panic"); + } + }); + } + fn test_source_adaptor_chain(range_strategy: RangeStrategy) { let mut thread_pool = ThreadPoolBuilder { num_threads: ThreadCount::AvailableParallelism, @@ -1140,6 +1278,36 @@ mod test { assert_eq!(sum_empty, 0); } + fn test_source_adaptor_skip_cleanup(range_strategy: RangeStrategy) { + let mut thread_pool = ThreadPoolBuilder { + num_threads: ThreadCount::AvailableParallelism, + range_strategy, + cpu_pinning: CpuPinningPolicy::No, + } + .build(); + + let input = (1..=2 * INPUT_LEN).map(Box::new).collect::>>(); + let sum = input + .into_par_iter() + .skip(INPUT_LEN as usize / 2) + .with_thread_pool(&mut thread_pool) + .map(|x| *x) + .sum::(); + assert_eq!( + sum, + ((3 * INPUT_LEN + 1) / 2) * ((5 * INPUT_LEN) / 2 + 1) / 2 + ); + + let input = (1..=2 * 
INPUT_LEN).map(Box::new).collect::>>(); + let sum = input + .into_par_iter() + .skip(3 * INPUT_LEN as usize / 2) + .with_thread_pool(&mut thread_pool) + .map(|x| *x) + .sum::(); + assert_eq!(sum, ((INPUT_LEN + 1) / 2) * ((7 * INPUT_LEN) / 2 + 1) / 2); + } + fn test_source_adaptor_skip_exact(range_strategy: RangeStrategy) { let mut thread_pool = ThreadPoolBuilder { num_threads: ThreadCount::AvailableParallelism, @@ -1212,6 +1380,24 @@ mod test { assert_eq!(sum_empty, 0); } + fn test_source_adaptor_step_by_cleanup(range_strategy: RangeStrategy) { + let mut thread_pool = ThreadPoolBuilder { + num_threads: ThreadCount::AvailableParallelism, + range_strategy, + cpu_pinning: CpuPinningPolicy::No, + } + .build(); + + let input = (0..=2 * INPUT_LEN).map(Box::new).collect::>>(); + let sum_by_2 = input + .into_par_iter() + .step_by(2) + .with_thread_pool(&mut thread_pool) + .map(|x| *x) + .sum::(); + assert_eq!(sum_by_2, INPUT_LEN * (INPUT_LEN + 1)); + } + fn test_source_adaptor_step_by_zero(range_strategy: RangeStrategy) { let mut thread_pool = ThreadPoolBuilder { num_threads: ThreadCount::AvailableParallelism, @@ -1266,6 +1452,33 @@ mod test { assert_eq!(sum_all, INPUT_LEN * (INPUT_LEN + 1) / 2); } + fn test_source_adaptor_take_cleanup(range_strategy: RangeStrategy) { + let mut thread_pool = ThreadPoolBuilder { + num_threads: ThreadCount::AvailableParallelism, + range_strategy, + cpu_pinning: CpuPinningPolicy::No, + } + .build(); + + let input = (1..=2 * INPUT_LEN).map(Box::new).collect::>>(); + let sum = input + .into_par_iter() + .take(INPUT_LEN as usize / 2) + .with_thread_pool(&mut thread_pool) + .map(|x| *x) + .sum::(); + assert_eq!(sum, ((INPUT_LEN / 2) * (INPUT_LEN / 2 + 1)) / 2); + + let input = (1..=2 * INPUT_LEN).map(Box::new).collect::>>(); + let sum = input + .into_par_iter() + .take(3 * INPUT_LEN as usize / 2) + .with_thread_pool(&mut thread_pool) + .map(|x| *x) + .sum::(); + assert_eq!(sum, ((3 * INPUT_LEN / 2) * (3 * INPUT_LEN / 2 + 1)) / 2); + } + fn test_source_adaptor_take_exact(range_strategy: RangeStrategy) { let mut thread_pool = ThreadPoolBuilder { num_threads: ThreadCount::AvailableParallelism, @@ -1283,6 +1496,33 @@ mod test { assert_eq!(sum, ((INPUT_LEN / 2) * (INPUT_LEN / 2 + 1)) / 2); } + fn test_source_adaptor_take_exact_cleanup(range_strategy: RangeStrategy) { + let mut thread_pool = ThreadPoolBuilder { + num_threads: ThreadCount::AvailableParallelism, + range_strategy, + cpu_pinning: CpuPinningPolicy::No, + } + .build(); + + let input = (1..=2 * INPUT_LEN).map(Box::new).collect::>>(); + let sum = input + .into_par_iter() + .take_exact(INPUT_LEN as usize / 2) + .with_thread_pool(&mut thread_pool) + .map(|x| *x) + .sum::(); + assert_eq!(sum, ((INPUT_LEN / 2) * (INPUT_LEN / 2 + 1)) / 2); + + let input = (1..=2 * INPUT_LEN).map(Box::new).collect::>>(); + let sum = input + .into_par_iter() + .take_exact(3 * INPUT_LEN as usize / 2) + .with_thread_pool(&mut thread_pool) + .map(|x| *x) + .sum::(); + assert_eq!(sum, ((3 * INPUT_LEN / 2) * (3 * INPUT_LEN / 2 + 1)) / 2); + } + fn test_source_adaptor_take_exact_too_much(range_strategy: RangeStrategy) { let mut thread_pool = ThreadPoolBuilder { num_threads: ThreadCount::AvailableParallelism,