diff --git a/src/lib.rs b/src/lib.rs
index 44a8e13..f45de23 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -76,6 +76,7 @@ mod test {
             test_several_functions,
             test_several_accumulators,
             test_several_input_types,
+            test_several_pipelines,
         );
     };
 }
@@ -219,7 +220,7 @@ mod test {
         // The scope should accept FnOnce() parameter. We test it with a closure that
         // captures and consumes a non-Copy type.
         let token = Box::new(());
-        pool_builder.scope::<(), ()>(|_| drop(token));
+        pool_builder.scope(|_| drop(token));
     }
 
     fn test_local_sum(range_strategy: RangeStrategy) {
@@ -372,6 +373,48 @@ mod test {
         assert_eq!(sum_lengths, expected_sum_lengths(INPUT_LEN));
     }
 
+    fn test_several_pipelines(range_strategy: RangeStrategy) {
+        let pool_builder = ThreadPoolBuilder {
+            num_threads: NonZeroUsize::try_from(4).unwrap(),
+            range_strategy,
+        };
+        let (sum, sum_pairs) = pool_builder.scope(|thread_pool| {
+            // Pipelines with different types can be used successively.
+            let input = (0..=INPUT_LEN).collect::<Vec<u64>>();
+            let sum = thread_pool.pipeline(
+                &input,
+                || 0u64,
+                |acc, _, &x| *acc += x,
+                |acc| acc,
+                |a, b| a + b,
+            );
+
+            let input = (0..=INPUT_LEN)
+                .map(|i| (2 * i, 2 * i + 1))
+                .collect::<Vec<(u64, u64)>>();
+            let sum_pairs = thread_pool.pipeline(
+                &input,
+                || (0u64, 0u64),
+                |(a, b), _, &(x, y)| {
+                    *a += x;
+                    *b += y;
+                },
+                |acc| acc,
+                |(a, b), (x, y)| (a + x, b + y),
+            );
+
+            (sum, sum_pairs)
+        });
+        assert_eq!(sum, INPUT_LEN * (INPUT_LEN + 1) / 2);
+        assert_eq!(
+            sum_pairs,
+            (
+                INPUT_LEN * (INPUT_LEN + 1),
+                (INPUT_LEN + 1) * (INPUT_LEN + 1)
+            )
+        );
+    }
+
     const fn expected_sum_lengths(max: u64) -> u64 {
         if max < 10 {
             max + 1
diff --git a/src/thread_pool.rs b/src/thread_pool.rs
index d878a92..a4553a0 100644
--- a/src/thread_pool.rs
+++ b/src/thread_pool.rs
@@ -64,7 +64,7 @@ impl ThreadPoolBuilder {
     /// });
     /// assert_eq!(sum, 5 * 11);
     /// ```
-    pub fn scope<Output: Send, R>(&self, f: impl FnOnce(ThreadPool<Output>) -> R) -> R {
+    pub fn scope<R>(&self, f: impl FnOnce(ThreadPool) -> R) -> R {
         std::thread::scope(|scope| {
             let thread_pool = ThreadPool::new(scope, self.num_threads, self.range_strategy);
             f(thread_pool)
@@ -112,10 +112,9 @@ impl RoundColor {
 
 /// A thread pool tied to a scope, that can process inputs into outputs of the
 /// given types.
-#[allow(clippy::type_complexity)]
-pub struct ThreadPool<'scope, Output> {
+pub struct ThreadPool<'scope> {
     /// Handles to all the worker threads in the pool.
-    threads: Vec<WorkerThreadHandle<'scope, Output>>,
+    threads: Vec<WorkerThreadHandle<'scope>>,
     /// Number of worker threads active in the current round.
     num_active_threads: Arc<AtomicUsize>,
     /// Color of the current round.
@@ -129,15 +128,13 @@ pub struct ThreadPool<'scope, Output> {
     /// everything.
     range_orchestrator: Box<dyn RangeOrchestrator>,
     /// Pipeline to map and reduce inputs into the output.
-    pipeline: Arc<RwLock<Option<Box<dyn Pipeline<Output> + Send + Sync + 'scope>>>>,
+    pipeline: Arc<RwLock<Option<Box<dyn Pipeline + Send + Sync + 'scope>>>>,
 }
 
 /// Handle to a worker thread in the pool.
-struct WorkerThreadHandle<'scope, Output> {
+struct WorkerThreadHandle<'scope> {
     /// Thread handle object.
     handle: ScopedJoinHandle<'scope, ()>,
-    /// Storage for this thread's computation output.
-    output: Arc<Mutex<Option<Output>>>,
 }
 
 /// Strategy to distribute ranges of work items among threads.
@@ -149,7 +146,7 @@ pub enum RangeStrategy {
     WorkStealing,
 }
 
-impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> {
+impl<'scope> ThreadPool<'scope> {
     /// Creates a new pool tied to the given scope, spawning the given number of
     /// worker threads.
     fn new<'env>(
@@ -200,15 +197,12 @@ impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> {
         log_warn!("Pinning threads to CPUs is not implemented on this platform.");
         let threads = (0..num_threads)
             .map(|id| {
-                let output = Arc::new(Mutex::new(None));
                 let context = ThreadContext {
-                    #[cfg(feature = "log")]
                     id,
                     num_active_threads: num_active_threads.clone(),
                     worker_status: worker_status.clone(),
                     main_status: main_status.clone(),
                     range: range_factory.range(id),
-                    output: output.clone(),
                     pipeline: pipeline.clone(),
                 };
                 WorkerThreadHandle {
@@ -234,7 +228,6 @@ impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> {
                         }
                         context.run()
                     }),
-                    output,
                 }
             })
             .collect();
@@ -281,7 +274,7 @@ impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> {
     /// assert_eq!(sum, 5 * 11);
     /// # });
     /// ```
-    pub fn pipeline<Input: Sync, Accum>(
+    pub fn pipeline<Input: Sync, Output: Send, Accum>(
         &self,
         input: &[Input],
         init: impl Fn() -> Accum + Send + Sync + 'static,
@@ -298,8 +291,13 @@ impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> {
         round.toggle();
         self.round.set(round);
 
+        let outputs = (0..num_threads)
+            .map(|_| Mutex::new(None))
+            .collect::<Arc<[_]>>();
+
         *self.pipeline.write().unwrap() = Some(Box::new(PipelineImpl {
             input: SliceView::new(input),
+            outputs: outputs.clone(),
             init: Box::new(init),
             process_item: Box::new(process_item),
             finalize: Box::new(finalize),
@@ -323,15 +321,15 @@ impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> {
         log_debug!("[main thread, round {round:?}] All threads have now finished this round.");
         *self.pipeline.write().unwrap() = None;
 
-        self.threads
+        outputs
             .iter()
-            .map(move |t| t.output.lock().unwrap().take().unwrap())
+            .map(move |output| output.lock().unwrap().take().unwrap())
             .reduce(reduce)
             .unwrap()
     }
 }
 
-impl<Output> Drop for ThreadPool<'_, Output> {
+impl Drop for ThreadPool<'_> {
     /// Joins all the threads in the pool.
     #[allow(clippy::single_match, clippy::unused_enumerate_index)]
     fn drop(&mut self) {
@@ -353,20 +351,21 @@ impl<Output> Drop for ThreadPool<'_, Output> {
     }
 }
 
-trait Pipeline<Output> {
-    fn run(&self, range: &mut dyn Iterator<Item = usize>) -> Output;
+trait Pipeline {
+    fn run(&self, worker_id: usize, range: &mut dyn Iterator<Item = usize>);
 }
 
 #[allow(clippy::type_complexity)]
 struct PipelineImpl<Input, Output, Accum> {
     input: SliceView<Input>,
+    outputs: Arc<[Mutex<Option<Output>>]>,
     init: Box<dyn Fn() -> Accum + Send + Sync>,
     process_item: Box<dyn Fn(&mut Accum, usize, &Input) + Send + Sync>,
     finalize: Box<dyn Fn(Accum) -> Output + Send + Sync>,
 }
 
-impl<Input: Sync, Output: Send, Accum> Pipeline<Output> for PipelineImpl<Input, Output, Accum> {
-    fn run(&self, range: &mut dyn Iterator<Item = usize>) -> Output {
+impl<Input: Sync, Output: Send, Accum> Pipeline for PipelineImpl<Input, Output, Accum> {
+    fn run(&self, worker_id: usize, range: &mut dyn Iterator<Item = usize>) {
         // SAFETY: the underlying input slice is valid and not mutated for the whole
         // lifetime of this block.
         let input = unsafe { self.input.get().unwrap() };
@@ -374,15 +373,14 @@ impl<Input: Sync, Output: Send, Accum> Pipeline<Output> for PipelineImpl<Input, Output, Accum> {
         let mut accumulator = (self.init)();
         for i in range {
             (self.process_item)(&mut accumulator, i, &input[i]);
         }
-        (self.finalize)(accumulator)
+        *self.outputs[worker_id].lock().unwrap() = Some((self.finalize)(accumulator));
     }
 }
 
 /// Context object owned by a worker thread.
-struct ThreadContext<'scope, Rn: Range, Output> {
+struct ThreadContext<'scope, Rn: Range> {
     /// Thread index.
-    #[cfg(feature = "log")]
     id: usize,
     /// Number of worker threads active in the current round.
     num_active_threads: Arc<AtomicUsize>,
@@ -392,13 +390,11 @@ struct ThreadContext<'scope, Rn: Range, Output> {
     main_status: Arc<Mutex<MainStatus>>,
     /// Range of items that this worker thread needs to process.
     range: Rn,
-    /// Output that this thread writes to.
-    output: Arc<Mutex<Option<Output>>>,
     /// Pipeline to map and reduce inputs into the output.
-    pipeline: Arc<RwLock<Option<Box<dyn Pipeline<Output> + Send + Sync + 'scope>>>>,
+    pipeline: Arc<RwLock<Option<Box<dyn Pipeline + Send + Sync + 'scope>>>>,
 }
 
-impl<Rn: Range, Output> ThreadContext<'_, Rn, Output> {
+impl<Rn: Range> ThreadContext<'_, Rn> {
     /// Main function run by this thread.
     fn run(&self) {
         let mut round = RoundColor::Blue;
@@ -439,7 +435,7 @@ impl<Rn: Range, Output> ThreadContext<'_, Rn, Output> {
             {
                 let guard = self.pipeline.read().unwrap();
                 let pipeline = guard.as_ref().unwrap();
-                *self.output.lock().unwrap() = Some(pipeline.run(&mut self.range.iter()));
+                pipeline.run(self.id, &mut self.range.iter());
             }
             std::mem::forget(panic_notifier);