Skip to content

Commit

Permalink
Make the output type dynamic.
Browse files Browse the repository at this point in the history
  • Loading branch information
gendx committed Sep 23, 2024
1 parent 152529f commit 3adc098
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 31 deletions.
45 changes: 44 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ mod test {
test_several_functions,
test_several_accumulators,
test_several_input_types,
test_several_pipelines,
);
};
}
Expand Down Expand Up @@ -219,7 +220,7 @@ mod test {
// The scope should accept FnOnce() parameter. We test it with a closure that
// captures and consumes a non-Copy type.
let token = Box::new(());
pool_builder.scope::<(), ()>(|_| drop(token));
pool_builder.scope(|_| drop(token));
}

fn test_local_sum(range_strategy: RangeStrategy) {
Expand Down Expand Up @@ -372,6 +373,48 @@ mod test {
assert_eq!(sum_lengths, expected_sum_lengths(INPUT_LEN));
}

fn test_several_pipelines(range_strategy: RangeStrategy) {
let pool_builder = ThreadPoolBuilder {
num_threads: NonZeroUsize::try_from(4).unwrap(),
range_strategy,
};
let (sum, sum_pairs) = pool_builder.scope(|thread_pool| {
// Pipelines with different types can be used successively.
let input = (0..=INPUT_LEN).collect::<Vec<u64>>();
let sum = thread_pool.pipeline(
&input,
|| 0u64,
|acc, _, &x| *acc += x,
|acc| acc,
|a, b| a + b,
);

let input = (0..=INPUT_LEN)
.map(|i| (2 * i, 2 * i + 1))
.collect::<Vec<(u64, u64)>>();
let sum_pairs = thread_pool.pipeline(
&input,
|| (0u64, 0u64),
|(a, b), _, &(x, y)| {
*a += x;
*b += y;
},
|acc| acc,
|(a, b), (x, y)| (a + x, b + y),
);

(sum, sum_pairs)
});
assert_eq!(sum, INPUT_LEN * (INPUT_LEN + 1) / 2);
assert_eq!(
sum_pairs,
(
INPUT_LEN * (INPUT_LEN + 1),
(INPUT_LEN + 1) * (INPUT_LEN + 1)
)
);
}

const fn expected_sum_lengths(max: u64) -> u64 {
if max < 10 {
max + 1
Expand Down
56 changes: 26 additions & 30 deletions src/thread_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl ThreadPoolBuilder {
/// });
/// assert_eq!(sum, 5 * 11);
/// ```
pub fn scope<Output: Send, R>(&self, f: impl FnOnce(ThreadPool<Output>) -> R) -> R {
pub fn scope<R>(&self, f: impl FnOnce(ThreadPool) -> R) -> R {
std::thread::scope(|scope| {
let thread_pool = ThreadPool::new(scope, self.num_threads, self.range_strategy);
f(thread_pool)
Expand Down Expand Up @@ -112,10 +112,9 @@ impl RoundColor {

/// A thread pool tied to a scope, that can process inputs into outputs of the
/// given types.
#[allow(clippy::type_complexity)]
pub struct ThreadPool<'scope, Output> {
pub struct ThreadPool<'scope> {
/// Handles to all the worker threads in the pool.
threads: Vec<WorkerThreadHandle<'scope, Output>>,
threads: Vec<WorkerThreadHandle<'scope>>,
/// Number of worker threads active in the current round.
num_active_threads: Arc<AtomicUsize>,
/// Color of the current round.
Expand All @@ -129,15 +128,13 @@ pub struct ThreadPool<'scope, Output> {
/// everything.
range_orchestrator: Box<dyn RangeOrchestrator>,
/// Pipeline to map and reduce inputs into the output.
pipeline: Arc<RwLock<Option<Box<dyn Pipeline<Output> + Send + Sync + 'scope>>>>,
pipeline: Arc<RwLock<Option<Box<dyn Pipeline + Send + Sync + 'scope>>>>,
}

/// Handle to a worker thread in the pool.
struct WorkerThreadHandle<'scope, Output> {
struct WorkerThreadHandle<'scope> {
/// Thread handle object.
handle: ScopedJoinHandle<'scope, ()>,
/// Storage for this thread's computation output.
output: Arc<Mutex<Option<Output>>>,
}

/// Strategy to distribute ranges of work items among threads.
Expand All @@ -149,7 +146,7 @@ pub enum RangeStrategy {
WorkStealing,
}

impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> {
impl<'scope> ThreadPool<'scope> {
/// Creates a new pool tied to the given scope, spawning the given number of
/// worker threads.
fn new<'env>(
Expand Down Expand Up @@ -200,15 +197,12 @@ impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> {
log_warn!("Pinning threads to CPUs is not implemented on this platform.");
let threads = (0..num_threads)
.map(|id| {
let output = Arc::new(Mutex::new(None));
let context = ThreadContext {
#[cfg(feature = "log")]
id,
num_active_threads: num_active_threads.clone(),
worker_status: worker_status.clone(),
main_status: main_status.clone(),
range: range_factory.range(id),
output: output.clone(),
pipeline: pipeline.clone(),
};
WorkerThreadHandle {
Expand All @@ -234,7 +228,6 @@ impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> {
}
context.run()
}),
output,
}
})
.collect();
Expand Down Expand Up @@ -281,7 +274,7 @@ impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> {
/// assert_eq!(sum, 5 * 11);
/// # });
/// ```
pub fn pipeline<Input: Sync + 'scope, Accum: 'scope>(
pub fn pipeline<Input: Sync + 'scope, Output: Send + 'scope, Accum: 'scope>(
&self,
input: &[Input],
init: impl Fn() -> Accum + Send + Sync + 'static,
Expand All @@ -298,8 +291,13 @@ impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> {
round.toggle();
self.round.set(round);

let outputs = (0..num_threads)
.map(|_| Mutex::new(None))
.collect::<Arc<[_]>>();

*self.pipeline.write().unwrap() = Some(Box::new(PipelineImpl {
input: SliceView::new(input),
outputs: outputs.clone(),
init: Box::new(init),
process_item: Box::new(process_item),
finalize: Box::new(finalize),
Expand All @@ -323,15 +321,15 @@ impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> {
log_debug!("[main thread, round {round:?}] All threads have now finished this round.");
*self.pipeline.write().unwrap() = None;

self.threads
outputs
.iter()
.map(move |t| t.output.lock().unwrap().take().unwrap())
.map(move |output| output.lock().unwrap().take().unwrap())
.reduce(reduce)
.unwrap()
}
}

impl<Output> Drop for ThreadPool<'_, Output> {
impl Drop for ThreadPool<'_> {
/// Joins all the threads in the pool.
#[allow(clippy::single_match, clippy::unused_enumerate_index)]
fn drop(&mut self) {
Expand All @@ -353,36 +351,36 @@ impl<Output> Drop for ThreadPool<'_, Output> {
}
}

trait Pipeline<Output> {
fn run(&self, range: &mut dyn Iterator<Item = usize>) -> Output;
trait Pipeline {
fn run(&self, worker_id: usize, range: &mut dyn Iterator<Item = usize>);
}

#[allow(clippy::type_complexity)]
struct PipelineImpl<Input, Output, Accum> {
input: SliceView<Input>,
outputs: Arc<[Mutex<Option<Output>>]>,
init: Box<dyn Fn() -> Accum + Send + Sync>,
process_item: Box<dyn Fn(&mut Accum, usize, &Input) + Send + Sync>,
finalize: Box<dyn Fn(Accum) -> Output + Send + Sync>,
}

impl<Input, Output, Accum> Pipeline<Output> for PipelineImpl<Input, Output, Accum> {
fn run(&self, range: &mut dyn Iterator<Item = usize>) -> Output {
impl<Input, Output, Accum> Pipeline for PipelineImpl<Input, Output, Accum> {
fn run(&self, worker_id: usize, range: &mut dyn Iterator<Item = usize>) {
// SAFETY: the underlying input slice is valid and not mutated for the whole
// lifetime of this block.
let input = unsafe { self.input.get().unwrap() };
let mut accumulator = (self.init)();
for i in range {
(self.process_item)(&mut accumulator, i, &input[i]);
}
(self.finalize)(accumulator)
let output = (self.finalize)(accumulator);
*self.outputs[worker_id].lock().unwrap() = Some(output);
}
}

/// Context object owned by a worker thread.
#[allow(clippy::type_complexity)]
struct ThreadContext<'scope, Rn: Range, Output> {
struct ThreadContext<'scope, Rn: Range> {
/// Thread index.
#[cfg(feature = "log")]
id: usize,
/// Number of worker threads active in the current round.
num_active_threads: Arc<AtomicUsize>,
Expand All @@ -392,13 +390,11 @@ struct ThreadContext<'scope, Rn: Range, Output> {
main_status: Arc<Status<MainStatus>>,
/// Range of items that this worker thread needs to process.
range: Rn,
/// Output that this thread writes to.
output: Arc<Mutex<Option<Output>>>,
/// Pipeline to map and reduce inputs into the output.
pipeline: Arc<RwLock<Option<Box<dyn Pipeline<Output> + Send + Sync + 'scope>>>>,
pipeline: Arc<RwLock<Option<Box<dyn Pipeline + Send + Sync + 'scope>>>>,
}

impl<Rn: Range, Output> ThreadContext<'_, Rn, Output> {
impl<Rn: Range> ThreadContext<'_, Rn> {
/// Main function run by this thread.
fn run(&self) {
let mut round = RoundColor::Blue;
Expand Down Expand Up @@ -439,7 +435,7 @@ impl<Rn: Range, Output> ThreadContext<'_, Rn, Output> {
{
let guard = self.pipeline.read().unwrap();
let pipeline = guard.as_ref().unwrap();
*self.output.lock().unwrap() = Some(pipeline.run(&mut self.range.iter()));
pipeline.run(self.id, &mut self.range.iter());
}
std::mem::forget(panic_notifier);

Expand Down

0 comments on commit 3adc098

Please sign in to comment.