// Patch summary (commentary only; everything from `diff --git` onward is unchanged).
//
// Purpose: when a `StrandedPositionFilter` is in use, replace the reference
// records handed to `ReferenceIntervalsFeeder::new` with smaller per-contig
// "parts" derived from the filter's own intervals.
//
// src/extract/subcommand.rs (`ExtractMods`) and src/pileup/subcommand.rs
// (`ModBamPileup`, `DuplexModBamPileup`):
//   If a position filter is present, `reference_records` is passed through
//   `optimize_reference_records(reference_records, self.interval_size)` before
//   the feeder is built; otherwise the records are used as they were.
//
// src/interval_chunks.rs (`ReferenceIntervalsFeeder`):
//   Removes the `contig_ends`-based narrowing of each record. The records are
//   now collected exactly as given, and the debug messages report the number
//   of distinct `tid`s (via `Itertools::unique`) plus the number of parts.
//
// src/position_filter.rs:
//   - Regroups the imports and adds `VecDeque` and `ReferenceRecord`.
//   - `StrandedPositionFilter` additionally derives `Clone`.
//   - Removes `iter_intervals`.
//   - Adds `group_genome_intervals`: takes the intervals in their stored order
//     and builds groups. While the current group's span (`stop - start`) is not
//     greater than `interval_size`, the group's `stop` is moved to the next
//     interval's `stop`; once the span exceeds `interval_size` the group is
//     closed and the next interval starts a new one. Every group becomes a
//     `ReferenceRecord` carrying the contig's `tid` and `name`, the group's
//     start, and `stop - start` as its length. Empty input yields an empty Vec.
//   - Adds `optimize_reference_records`: for each contig id that appears in
//     either strand map of the filter and also in the supplied records, it
//     concatenates the intervals of both strands, calls `merge_overlaps`, and
//     groups the result with `group_genome_intervals`. Contigs that the filter
//     does not mention produce no records.
//   - BED parsing: the parse results are matched by value instead of through a
//     reference, a `debug_assert!(start <= stop)` is added, and the locals
//     `pos_lapper`/`neg_lapper` are renamed to `pos_intervals`/`neg_intervals`.
//
// NOTE(review): this copy of the patch has lost its line breaks, and the
//   contents of angle brackets look stripped (e.g. `collect::>()`), so it is
//   not applicable as-is; regenerate it from the repository before applying.
// NOTE(review): the records returned by `optimize_reference_records` follow
//   the iteration order of the filter's hash-map keys, not the order of the
//   input `reference_records` - confirm nothing downstream relies on the
//   input order.
// NOTE(review): `group_genome_intervals` reads only `tid` and `name` from the
//   reference record, so a start/length already present on the input record
//   (for example from the `region` given to `get_targets`) is not applied to
//   the new records - confirm the region is enforced elsewhere.
// NOTE(review): `lut` is keyed by `rec.tid`; presumably it is a map, in which
//   case two input records sharing a `tid` would collapse to one - verify the
//   input never contains duplicates.
// NOTE(review): the size check happens before the group is extended, and the
//   extension ignores the distance to the next interval, so a group can end up
//   much longer than `interval_size` when intervals are sparse.
// NOTE(review): `iv.start as u32` and the length `as u32` truncate silently
//   for values above `u32::MAX`; `u32::try_from` would surface that.
// NOTE(review): `.expect(&format!(..))` builds its message on every loop
//   iteration, including the success path (Clippy `expect_fun_call`).
// NOTE(review): `debug_assert!` is compiled out of release builds, while the
//   `checked_sub(..).expect(..)` calls in `group_genome_intervals` would panic
//   on an inverted interval - confirm BED lines with start > stop are rejected
//   before they reach the filter.
diff --git a/src/extract/subcommand.rs b/src/extract/subcommand.rs index a0815ff..f82d1cd 100644 --- a/src/extract/subcommand.rs +++ b/src/extract/subcommand.rs @@ -474,6 +474,16 @@ impl ExtractMods { ); let reference_records = get_targets(reader.header(), region); + let reference_records = + if let Some(pf) = include_positions.as_ref() { + pf.optimize_reference_records( + reference_records, + self.interval_size, + ) + } else { + reference_records + }; + let feeder = ReferenceIntervalsFeeder::new( reference_records, (self.threads as f32 * 1.5f32).floor() as usize, diff --git a/src/interval_chunks.rs b/src/interval_chunks.rs index 8548799..dac0626 100644 --- a/src/interval_chunks.rs +++ b/src/interval_chunks.rs @@ -8,6 +8,7 @@ use crate::position_filter::{GenomeIntervals, Iv, StrandedPositionFilter}; use crate::util::{ReferenceRecord, StrandRule}; use anyhow::{anyhow, bail}; use derive_new::new; +use itertools::Itertools; use log::debug; use rustc_hash::FxHashMap; @@ -498,36 +499,20 @@ impl ReferenceIntervalsFeeder { if combine_strands & !multi_motif_locations.is_some() { bail!("cannot combine strands without a motif") } - let mut contigs = if let Some(position_filter) = - position_filter.as_ref() - { - // todo do more aggressive "narrowing" - reference_records - .into_iter() - .filter_map(|contig| { - position_filter.contig_ends(&contig.tid).map(|(s, t)| { - let length = t.checked_sub(s).unwrap_or(0u64); - debug!( - "narrowing record {} to {s}-{t} ({length} bases)", - contig.name.as_str() - ); - ReferenceRecord::new( - contig.tid, - s as u32, - length as u32, - contig.name, - ) - }) - }) - .collect::>() - } else { - reference_records.into_iter().collect::>() - }; - - if contigs.len() == 1 { - debug!("there is a single contig to work on"); + let mut contigs = + reference_records.into_iter().collect::>(); + let n_contigs = contigs.iter().map(|r| r.tid).unique().count(); + + if n_contigs == 1 { + debug!( + "there is a single contig to work on (in {} parts)", 
contigs.len() + ); } else { - debug!("there are {} contigs to work on", contigs.len()); + debug!( + "there are {n_contigs} contig(s) to work on ({} parts)", + contigs.len() + ); } let curr_contig = contigs .pop_front() diff --git a/src/pileup/subcommand.rs b/src/pileup/subcommand.rs index a17e524..51562e4 100644 --- a/src/pileup/subcommand.rs +++ b/src/pileup/subcommand.rs @@ -657,6 +657,11 @@ impl ModBamPileup { } let (snd, rx) = bounded(self.queue_size); + let reference_records = if let Some(pf) = position_filter.as_ref() { + pf.optimize_reference_records(reference_records, self.interval_size) + } else { + reference_records + }; let feeder = ReferenceIntervalsFeeder::new( reference_records, chunk_size, @@ -1322,6 +1327,11 @@ impl DuplexModBamPileup { // from here down could also be it's own "Processor" let (snd, rx) = bounded(self.queue_size); // todo figure out sane default for this? + let reference_records = if let Some(pf) = position_filter.as_ref() { + pf.optimize_reference_records(reference_records, self.interval_size) + } else { + reference_records + }; let feeder = ReferenceIntervalsFeeder::new( reference_records, chunk_size, diff --git a/src/position_filter.rs b/src/position_filter.rs index ded4bec..d718bb5 100644 --- a/src/position_filter.rs +++ b/src/position_filter.rs @@ -1,20 +1,23 @@ -use crate::mod_base_code::DnaBase; -use crate::util::{get_targets, get_ticker, Strand}; +use std::collections::{HashMap, HashSet, VecDeque}; +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::PathBuf; + use anyhow::bail; +use itertools::Itertools; use log::info; use log_once::info_once; use rust_htslib::bam::{self, Read}; use rust_lapper as lapper; use rustc_hash::FxHashMap; -use std::collections::{HashMap, HashSet}; -use std::fs::File; -use std::io::{BufRead, BufReader}; -use std::path::PathBuf; + +use crate::mod_base_code::DnaBase; +use crate::util::{get_targets, get_ticker, ReferenceRecord, Strand}; pub(crate) type Iv = lapper::Interval; 
pub(crate) type GenomeIntervals = lapper::Lapper; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct StrandedPositionFilter { pub(crate) pos_positions: FxHashMap>, pub(crate) neg_positions: FxHashMap>, @@ -72,19 +75,6 @@ impl StrandedPositionFilter { } } - pub fn iter_intervals( - &self, - ) -> impl Iterator)> + '_ { - self.pos_positions - .iter() - .flat_map(|(tid, lp)| lp.intervals.iter().map(|iv| (*tid, iv))) - .chain( - self.neg_positions.iter().flat_map(|(tid, lp)| { - lp.intervals.iter().map(|iv| (*tid, iv)) - }), - ) - } - pub fn contig_ends(&self, contig_id: &u32) -> Option<(u64, u64)> { let get_start_end = |positions: &FxHashMap>| -> Option<(u64, u64)> { positions.get(&contig_id) @@ -109,6 +99,114 @@ impl StrandedPositionFilter { _ => None, } } + + fn group_genome_intervals( + genome_intervals: GenomeIntervals, + reference_record: &ReferenceRecord, + interval_size: u64, + ) -> Vec { + let mut intervals = + genome_intervals.intervals.into_iter().collect::>(); + if intervals.is_empty() { + return Vec::new(); + } + + let mut current = intervals.pop_front().unwrap(); + let mut agg = Vec::new(); + while let Some(iv) = intervals.pop_front() { + if current.stop.checked_sub(current.start).expect(&format!( + "invalid interval coordinates, {}:{}", + current.start, current.stop + )) > interval_size + { + let finished = std::mem::replace(&mut current, iv); + agg.push(finished); + continue; + } + current.stop = iv.stop + } + agg.push(current); + agg.into_iter() + .map(|iv| { + let start = iv.start as u32; + let length = iv + .stop + .checked_sub(iv.start) + .expect("invalid final interval") + as u32; + ReferenceRecord::new( + reference_record.tid, + start, + length, + reference_record.name.clone(), + ) + }) + .collect() + } + + pub(crate) fn optimize_reference_records( + &self, + reference_records: Vec, + interval_size: u32, + ) -> Vec { + let lut = reference_records + .into_iter() + .map(|rec| (rec.tid, rec)) + .collect::>(); + + let contig_ids = self + 
.pos_positions + .keys() + .chain(self.neg_positions.keys()) + .unique() + .copied() + .collect::>(); + + contig_ids + .iter() + // shouldn't really need this filter + .filter_map(|tid| lut.get(tid)) + .flat_map(|ref_record| { + let tid = &ref_record.tid; + let mut pos = self + .pos_positions + .get(tid) + .map(|ivs| ivs.intervals.clone()) + .unwrap_or_else(Vec::new); + let mut neg = self + .neg_positions + .get(tid) + .map(|ivs| ivs.intervals.clone()) + .unwrap_or_else(Vec::new); + pos.append(&mut neg); + #[cfg(debug_assertions)] + { + for iv in pos.iter() { + assert!(iv.start <= iv.stop); + } + } + let mut genome_intervals = GenomeIntervals::new(pos); + #[cfg(debug_assertions)] + { + for iv in genome_intervals.intervals.iter() { + assert!(iv.start <= iv.stop); + } + } + genome_intervals.merge_overlaps(); + #[cfg(debug_assertions)] + { + for iv in genome_intervals.intervals.iter() { + assert!(iv.start <= iv.stop); + } + } + Self::group_genome_intervals( + genome_intervals, + ref_record, + interval_size as u64, + ) + }) + .collect() + } } impl StrandedPositionFilter<()> { @@ -155,10 +253,10 @@ impl StrandedPositionFilter<()> { if warned.contains(chrom_name) { continue; } - let raw_start = &parts[1].parse::(); - let raw_end = &parts[2].parse::(); + let raw_start = parts[1].parse::(); + let raw_end = parts[2].parse::(); let (start, stop) = match (raw_start, raw_end) { - (Ok(start), Ok(end)) => (*start, *end), + (Ok(start), Ok(end)) => (start, end), _ => { info!( "improperly formatted BED line, failed to parse start \ @@ -198,6 +296,7 @@ impl StrandedPositionFilter<()> { continue; } }; + debug_assert!(start <= stop, "start should be before stop"); if let Some(chrom_id) = chrom_to_target_id.get(chrom_name) { if pos_strand { pos_positions @@ -222,7 +321,7 @@ impl StrandedPositionFilter<()> { bail!("zero valid positions parsed from BED file") } - let pos_lapper = pos_positions + let pos_intervals = pos_positions .into_iter() .map(|(chrom_id, intervals)| { let mut lp = 
lapper::Lapper::new(intervals); @@ -231,7 +330,7 @@ impl StrandedPositionFilter<()> { }) .collect::>>(); - let neg_lapper = neg_positions + let neg_intervals = neg_positions .into_iter() .map(|(chrom_id, intervals)| { let mut lp = lapper::Lapper::new(intervals); @@ -243,7 +342,7 @@ impl StrandedPositionFilter<()> { lines_processed.finish_and_clear(); info!("processed {} BED lines", lines_processed.position()); - Ok(Self { pos_positions: pos_lapper, neg_positions: neg_lapper }) + Ok(Self { pos_positions: pos_intervals, neg_positions: neg_intervals }) } }