Skip to content

Commit

Permalink
Add step_by() adaptor.
Browse files Browse the repository at this point in the history
  • Loading branch information
gendx committed Nov 27, 2024
1 parent ad5e880 commit 44e3d12
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 0 deletions.
103 changes: 103 additions & 0 deletions src/iter/source/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,79 @@ pub trait ParallelSourceExt: ParallelSource {
}
}

/// Returns a parallel source that produces every `n`-th item from this
/// source, starting with the first one.
///
/// In other words, the returned source produces the items at indices `0`,
/// `n`, `2*n`, etc.
///
/// ```
/// # use paralight::iter::{IntoParallelRefSource, ParallelIteratorExt, ParallelSourceExt};
/// # use paralight::{CpuPinningPolicy, RangeStrategy, ThreadCount, ThreadPoolBuilder};
/// # let mut thread_pool = ThreadPoolBuilder {
/// # num_threads: ThreadCount::AvailableParallelism,
/// # range_strategy: RangeStrategy::WorkStealing,
/// # cpu_pinning: CpuPinningPolicy::No,
/// # }
/// # .build();
/// let input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
///
/// let sum = input
/// .par_iter()
/// .step_by(2)
/// .with_thread_pool(&mut thread_pool)
/// .sum::<i32>();
/// assert_eq!(sum, 1 + 3 + 5 + 7 + 9);
///
/// let sum = input
/// .par_iter()
/// .step_by(3)
/// .with_thread_pool(&mut thread_pool)
/// .sum::<i32>();
/// assert_eq!(sum, 1 + 4 + 7 + 10);
/// ```
///
/// This panics if the step is zero, even if the underlying source is empty.
///
/// ```should_panic
/// # use paralight::iter::{IntoParallelRefSource, ParallelIteratorExt, ParallelSourceExt};
/// # use paralight::{CpuPinningPolicy, RangeStrategy, ThreadCount, ThreadPoolBuilder};
/// # let mut thread_pool = ThreadPoolBuilder {
/// # num_threads: ThreadCount::AvailableParallelism,
/// # range_strategy: RangeStrategy::WorkStealing,
/// # cpu_pinning: CpuPinningPolicy::No,
/// # }
/// # .build();
/// let input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
/// let _ = input
/// .par_iter()
/// .step_by(0)
/// .with_thread_pool(&mut thread_pool)
/// .sum::<i32>();
/// ```
///
/// ```should_panic
/// # use paralight::iter::{IntoParallelRefSource, ParallelIteratorExt, ParallelSourceExt};
/// # use paralight::{CpuPinningPolicy, RangeStrategy, ThreadCount, ThreadPoolBuilder};
/// # let mut thread_pool = ThreadPoolBuilder {
/// # num_threads: ThreadCount::AvailableParallelism,
/// # range_strategy: RangeStrategy::WorkStealing,
/// # cpu_pinning: CpuPinningPolicy::No,
/// # }
/// # .build();
/// let _ = []
/// .par_iter()
/// .step_by(0)
/// .with_thread_pool(&mut thread_pool)
/// .sum::<i32>();
/// ```
fn step_by(self, n: usize) -> StepBy<Self> {
StepBy {
inner: self,
step: n,
}
}

/// Returns a parallel source that produces the first `n` items from this
/// source, or all the items if this source has fewer than `n` items.
///
Expand Down Expand Up @@ -615,6 +688,36 @@ impl<Inner: ParallelSource> ParallelSource for SkipExact<Inner> {
}
}

/// This struct is created by the
/// [`step_by()`](ParallelSourceExt::step_by) method on [`ParallelSourceExt`].
///
/// You most likely won't need to interact with this struct directly, as it
/// implements the [`ParallelSource`] and [`ParallelSourceExt`] traits, but it
/// is nonetheless public because of the `must_use` annotation.
#[must_use = "iterator adaptors are lazy"]
pub struct StepBy<Inner> {
inner: Inner,
step: usize,
}

impl<Inner: ParallelSource> ParallelSource for StepBy<Inner> {
type Item = Inner::Item;

fn descriptor(self) -> SourceDescriptor<Self::Item, impl Fn(usize) -> Self::Item + Sync> {
let descriptor = self.inner.descriptor();
assert!(self.step != 0, "called step_by() with a step of zero");
let len = if descriptor.len == 0 {
0
} else {
(descriptor.len - 1) / self.step + 1
};
SourceDescriptor {
len,
fetch_item: move |index| (descriptor.fetch_item)(self.step * index),
}
}
}

/// This struct is created by the [`take()`](ParallelSourceExt::take) method on
/// [`ParallelSourceExt`].
///
Expand Down
72 changes: 72 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ mod test {
test_source_adaptor_skip,
test_source_adaptor_skip_exact,
test_source_adaptor_skip_exact_too_much => fail("called skip_exact() with more items than this source produces"),
test_source_adaptor_step_by,
test_source_adaptor_step_by_zero => fail("called step_by() with a step of zero"),
test_source_adaptor_step_by_zero_empty => fail("called step_by() with a step of zero"),
test_source_adaptor_take,
test_source_adaptor_take_exact,
test_source_adaptor_take_exact_too_much => fail("called take_exact() with more items than this source produces"),
Expand Down Expand Up @@ -1148,6 +1151,75 @@ mod test {
.sum::<u64>();
}

fn test_source_adaptor_step_by(range_strategy: RangeStrategy) {
let mut thread_pool = ThreadPoolBuilder {
num_threads: ThreadCount::AvailableParallelism,
range_strategy,
cpu_pinning: CpuPinningPolicy::No,
}
.build();

let mut input = (0..=2 * INPUT_LEN).collect::<Vec<u64>>();
let sum_by_1 = input
.par_iter()
.step_by(1)
.with_thread_pool(&mut thread_pool)
.sum::<u64>();
assert_eq!(sum_by_1, INPUT_LEN * (2 * INPUT_LEN + 1));

let sum_by_2 = input
.par_iter()
.step_by(2)
.with_thread_pool(&mut thread_pool)
.sum::<u64>();
assert_eq!(sum_by_2, INPUT_LEN * (INPUT_LEN + 1));

input.truncate(2 * INPUT_LEN as usize);
let sum_by_2 = input
.par_iter()
.step_by(2)
.with_thread_pool(&mut thread_pool)
.sum::<u64>();
assert_eq!(sum_by_2, (INPUT_LEN - 1) * INPUT_LEN);

let sum_empty = []
.par_iter()
.step_by(2)
.with_thread_pool(&mut thread_pool)
.sum::<u64>();
assert_eq!(sum_empty, 0);
}

fn test_source_adaptor_step_by_zero(range_strategy: RangeStrategy) {
let mut thread_pool = ThreadPoolBuilder {
num_threads: ThreadCount::AvailableParallelism,
range_strategy,
cpu_pinning: CpuPinningPolicy::No,
}
.build();

let input = (0..=INPUT_LEN).collect::<Vec<u64>>();
input
.par_iter()
.step_by(0)
.with_thread_pool(&mut thread_pool)
.sum::<u64>();
}

fn test_source_adaptor_step_by_zero_empty(range_strategy: RangeStrategy) {
let mut thread_pool = ThreadPoolBuilder {
num_threads: ThreadCount::AvailableParallelism,
range_strategy,
cpu_pinning: CpuPinningPolicy::No,
}
.build();

[].par_iter()
.step_by(0)
.with_thread_pool(&mut thread_pool)
.sum::<u64>();
}

fn test_source_adaptor_take(range_strategy: RangeStrategy) {
let mut thread_pool = ThreadPoolBuilder {
num_threads: ThreadCount::AvailableParallelism,
Expand Down

0 comments on commit 44e3d12

Please sign in to comment.