Skip to content

Commit

Permalink
Merge branch 'ar/faster-includes' into 'master'
Browse files Browse the repository at this point in the history
[extract, pileup] faster when regions provided

See merge request machine-learning/modkit!221
  • Loading branch information
ArtRand committed Sep 14, 2024
2 parents 708985c + 2142781 commit f4e1d42
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 55 deletions.
10 changes: 10 additions & 0 deletions src/extract/subcommand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,16 @@ impl ExtractMods {
);
let reference_records =
get_targets(reader.header(), region);
let reference_records =
if let Some(pf) = include_positions.as_ref() {
pf.optimize_reference_records(
reference_records,
self.interval_size,
)
} else {
reference_records
};

let feeder = ReferenceIntervalsFeeder::new(
reference_records,
(self.threads as f32 * 1.5f32).floor() as usize,
Expand Down
43 changes: 14 additions & 29 deletions src/interval_chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::position_filter::{GenomeIntervals, Iv, StrandedPositionFilter};
use crate::util::{ReferenceRecord, StrandRule};
use anyhow::{anyhow, bail};
use derive_new::new;
use itertools::Itertools;
use log::debug;
use rustc_hash::FxHashMap;

Expand Down Expand Up @@ -498,36 +499,20 @@ impl ReferenceIntervalsFeeder {
if combine_strands & !multi_motif_locations.is_some() {
bail!("cannot combine strands without a motif")
}
let mut contigs = if let Some(position_filter) =
position_filter.as_ref()
{
// todo do more aggressive "narrowing"
reference_records
.into_iter()
.filter_map(|contig| {
position_filter.contig_ends(&contig.tid).map(|(s, t)| {
let length = t.checked_sub(s).unwrap_or(0u64);
debug!(
"narrowing record {} to {s}-{t} ({length} bases)",
contig.name.as_str()
);
ReferenceRecord::new(
contig.tid,
s as u32,
length as u32,
contig.name,
)
})
})
.collect::<VecDeque<_>>()
} else {
reference_records.into_iter().collect::<VecDeque<_>>()
};

if contigs.len() == 1 {
debug!("there is a single contig to work on");
let mut contigs =
reference_records.into_iter().collect::<VecDeque<_>>();
let n_contigs = contigs.iter().map(|r| r.tid).unique().count();

if n_contigs == 1 {
debug!(
"there is a single contig to work on (in {} parts)",
contigs.len()
);
} else {
debug!("there are {} contigs to work on", contigs.len());
debug!(
"there are {n_contigs} contig(s) to work on ({} parts)",
contigs.len()
);
}
let curr_contig = contigs
.pop_front()
Expand Down
10 changes: 10 additions & 0 deletions src/pileup/subcommand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,11 @@ impl ModBamPileup {
}

let (snd, rx) = bounded(self.queue_size);
let reference_records = if let Some(pf) = position_filter.as_ref() {
pf.optimize_reference_records(reference_records, self.interval_size)
} else {
reference_records
};
let feeder = ReferenceIntervalsFeeder::new(
reference_records,
chunk_size,
Expand Down Expand Up @@ -1322,6 +1327,11 @@ impl DuplexModBamPileup {

// from here down could also be its own "Processor"
let (snd, rx) = bounded(self.queue_size); // todo figure out sane default for this?
let reference_records = if let Some(pf) = position_filter.as_ref() {
pf.optimize_reference_records(reference_records, self.interval_size)
} else {
reference_records
};
let feeder = ReferenceIntervalsFeeder::new(
reference_records,
chunk_size,
Expand Down
151 changes: 125 additions & 26 deletions src/position_filter.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
use crate::mod_base_code::DnaBase;
use crate::util::{get_targets, get_ticker, Strand};
use std::collections::{HashMap, HashSet, VecDeque};
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;

use anyhow::bail;
use itertools::Itertools;
use log::info;
use log_once::info_once;
use rust_htslib::bam::{self, Read};
use rust_lapper as lapper;
use rustc_hash::FxHashMap;
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;

use crate::mod_base_code::DnaBase;
use crate::util::{get_targets, get_ticker, ReferenceRecord, Strand};

/// A single interval over `u64` coordinates that carries no payload value.
pub(crate) type Iv = lapper::Interval<u64, ()>;
/// An interval collection (`lapper::Lapper`) over `u64` coordinates whose
/// intervals carry a payload of type `T`.
pub(crate) type GenomeIntervals<T> = lapper::Lapper<u64, T>;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct StrandedPositionFilter<T: Send + Sync + Eq + Clone> {
pub(crate) pos_positions: FxHashMap<u32, GenomeIntervals<T>>,
pub(crate) neg_positions: FxHashMap<u32, GenomeIntervals<T>>,
Expand Down Expand Up @@ -72,19 +75,6 @@ impl<T: Send + Sync + Eq + Clone> StrandedPositionFilter<T> {
}
}

pub fn iter_intervals(
&self,
) -> impl Iterator<Item = (u32, &lapper::Interval<u64, T>)> + '_ {
self.pos_positions
.iter()
.flat_map(|(tid, lp)| lp.intervals.iter().map(|iv| (*tid, iv)))
.chain(
self.neg_positions.iter().flat_map(|(tid, lp)| {
lp.intervals.iter().map(|iv| (*tid, iv))
}),
)
}

pub fn contig_ends(&self, contig_id: &u32) -> Option<(u64, u64)> {
let get_start_end = |positions: &FxHashMap<u32, GenomeIntervals<T>>| -> Option<(u64, u64)> {
positions.get(&contig_id)
Expand All @@ -109,6 +99,114 @@ impl<T: Send + Sync + Eq + Clone> StrandedPositionFilter<T> {
_ => None,
}
}

/// Coalesce the intervals of one contig into coarser [`ReferenceRecord`]s.
///
/// Intervals are consumed in the order they are stored in
/// `genome_intervals`. Each following interval is folded into the running
/// group (by extending the group's `stop`) until the group spans more than
/// `interval_size` bases; that group is then emitted and the next interval
/// starts a new one. Any gap between intervals folded into the same group is
/// covered by the resulting record.
///
/// Returns an empty vector when there are no intervals. Every returned record
/// carries the `tid` and `name` of `reference_record`.
///
/// # Panics
/// Panics if an interval has `stop < start`.
fn group_genome_intervals(
    genome_intervals: GenomeIntervals<T>,
    reference_record: &ReferenceRecord,
    interval_size: u64,
) -> Vec<ReferenceRecord> {
    let mut intervals = genome_intervals.intervals.into_iter();
    let mut current = match intervals.next() {
        Some(first) => first,
        None => return Vec::new(),
    };

    let mut agg = Vec::new();
    for iv in intervals {
        // The panic message is only formatted on failure; `expect(&format!())`
        // would allocate the string on every iteration.
        let span = current
            .stop
            .checked_sub(current.start)
            .unwrap_or_else(|| {
                panic!(
                    "invalid interval coordinates, {}:{}",
                    current.start, current.stop
                )
            });
        if span > interval_size {
            // The running group is large enough: emit it and start a new
            // group at `iv`.
            agg.push(std::mem::replace(&mut current, iv));
        } else {
            current.stop = iv.stop;
        }
    }
    agg.push(current);

    agg.into_iter()
        .map(|iv| {
            // NOTE(review): these `as u32` casts truncate silently for
            // coordinates above `u32::MAX` -- confirm contigs never get
            // that large.
            let start = iv.start as u32;
            let length = iv
                .stop
                .checked_sub(iv.start)
                .expect("invalid final interval")
                as u32;
            ReferenceRecord::new(
                reference_record.tid,
                start,
                length,
                reference_record.name.clone(),
            )
        })
        .collect()
}

pub(crate) fn optimize_reference_records(
&self,
reference_records: Vec<ReferenceRecord>,
interval_size: u32,
) -> Vec<ReferenceRecord> {
let lut = reference_records
.into_iter()
.map(|rec| (rec.tid, rec))
.collect::<HashMap<u32, ReferenceRecord>>();

let contig_ids = self
.pos_positions
.keys()
.chain(self.neg_positions.keys())
.unique()
.copied()
.collect::<Vec<u32>>();

contig_ids
.iter()
// shouldn't really need this filter
.filter_map(|tid| lut.get(tid))
.flat_map(|ref_record| {
let tid = &ref_record.tid;
let mut pos = self
.pos_positions
.get(tid)
.map(|ivs| ivs.intervals.clone())
.unwrap_or_else(Vec::new);
let mut neg = self
.neg_positions
.get(tid)
.map(|ivs| ivs.intervals.clone())
.unwrap_or_else(Vec::new);
pos.append(&mut neg);
#[cfg(debug_assertions)]
{
for iv in pos.iter() {
assert!(iv.start <= iv.stop);
}
}
let mut genome_intervals = GenomeIntervals::new(pos);
#[cfg(debug_assertions)]
{
for iv in genome_intervals.intervals.iter() {
assert!(iv.start <= iv.stop);
}
}
genome_intervals.merge_overlaps();
#[cfg(debug_assertions)]
{
for iv in genome_intervals.intervals.iter() {
assert!(iv.start <= iv.stop);
}
}
Self::group_genome_intervals(
genome_intervals,
ref_record,
interval_size as u64,
)
})
.collect()
}
}

impl StrandedPositionFilter<()> {
Expand Down Expand Up @@ -155,10 +253,10 @@ impl StrandedPositionFilter<()> {
if warned.contains(chrom_name) {
continue;
}
let raw_start = &parts[1].parse::<u64>();
let raw_end = &parts[2].parse::<u64>();
let raw_start = parts[1].parse::<u64>();
let raw_end = parts[2].parse::<u64>();
let (start, stop) = match (raw_start, raw_end) {
(Ok(start), Ok(end)) => (*start, *end),
(Ok(start), Ok(end)) => (start, end),
_ => {
info!(
"improperly formatted BED line, failed to parse start \
Expand Down Expand Up @@ -198,6 +296,7 @@ impl StrandedPositionFilter<()> {
continue;
}
};
debug_assert!(start <= stop, "start should be before stop");
if let Some(chrom_id) = chrom_to_target_id.get(chrom_name) {
if pos_strand {
pos_positions
Expand All @@ -222,7 +321,7 @@ impl StrandedPositionFilter<()> {
bail!("zero valid positions parsed from BED file")
}

let pos_lapper = pos_positions
let pos_intervals = pos_positions
.into_iter()
.map(|(chrom_id, intervals)| {
let mut lp = lapper::Lapper::new(intervals);
Expand All @@ -231,7 +330,7 @@ impl StrandedPositionFilter<()> {
})
.collect::<FxHashMap<u32, GenomeIntervals<()>>>();

let neg_lapper = neg_positions
let neg_intervals = neg_positions
.into_iter()
.map(|(chrom_id, intervals)| {
let mut lp = lapper::Lapper::new(intervals);
Expand All @@ -243,7 +342,7 @@ impl StrandedPositionFilter<()> {
lines_processed.finish_and_clear();
info!("processed {} BED lines", lines_processed.position());

Ok(Self { pos_positions: pos_lapper, neg_positions: neg_lapper })
Ok(Self { pos_positions: pos_intervals, neg_positions: neg_intervals })
}
}

Expand Down

0 comments on commit f4e1d42

Please sign in to comment.