From fd450e7f133491018d5eaf91e42f05779acdcbdb Mon Sep 17 00:00:00 2001 From: Guillaume Endignoux Date: Thu, 5 Dec 2024 10:17:48 +0100 Subject: [PATCH] Add an API to get the number of threads spawned in a pool. --- src/core/thread_pool.rs | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/core/thread_pool.rs b/src/core/thread_pool.rs index c70cd12..0eb327c 100644 --- a/src/core/thread_pool.rs +++ b/src/core/thread_pool.rs @@ -115,7 +115,8 @@ impl ThreadPoolBuilder { /// A thread pool that can execute parallel pipelines. /// -/// This type doesn't expose any public methods. You can interact with it via +/// This type doesn't expose any public methods other than +/// [`num_threads()`](Self::num_threads). You can interact with it via /// the [`ThreadPoolBuilder::build()`] function to create a thread pool, and the /// [`with_thread_pool()`](crate::iter::ParallelSourceExt::with_thread_pool) /// method to attach a thread pool to a parallel iterator. @@ -131,6 +132,12 @@ impl ThreadPool { } } + /// Returns the number of worker threads that have been spawned in this + /// thread pool. + pub fn num_threads(&self) -> NonZeroUsize { + self.inner.num_threads() + } + /// Processes an input of the given length in parallel and returns the /// aggregated output. pub(crate) fn pipeline( @@ -232,6 +239,15 @@ impl ThreadPoolEnum { } } + /// Returns the number of worker threads that have been spawned in this + /// thread pool. + fn num_threads(&self) -> NonZeroUsize { + match self { + ThreadPoolEnum::Fixed(inner) => inner.num_threads(), + ThreadPoolEnum::WorkStealing(inner) => inner.num_threads(), + } + } + /// Processes an input of the given length in parallel and returns the /// aggregated output. fn pipeline( @@ -418,6 +434,12 @@ impl ThreadPoolImpl { } } + /// Returns the number of worker threads that have been spawned in this + /// thread pool. + fn num_threads(&self) -> NonZeroUsize { + self.threads.len().try_into().unwrap() + } + /// Processes an input of the given length in parallel and returns the /// aggregated output. fn pipeline(