diff --git a/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz b/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz
index 01ac1cd8..9bbd83c6 100644
Binary files a/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz and b/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz differ
diff --git a/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz.tbi b/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz.tbi
index 8216d8f5..ce1dacd2 100644
Binary files a/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz.tbi and b/src/chromatinhd/data/examples/pbmc10ktiny/fragments.tsv.gz.tbi differ
diff --git a/src/chromatinhd/data/examples/pbmc10ktiny/transcriptome.h5ad b/src/chromatinhd/data/examples/pbmc10ktiny/transcriptome.h5ad
index fae3f41f..9062dabf 100644
Binary files a/src/chromatinhd/data/examples/pbmc10ktiny/transcriptome.h5ad and b/src/chromatinhd/data/examples/pbmc10ktiny/transcriptome.h5ad differ
diff --git a/src/chromatinhd/data/peakcounts/plot.py b/src/chromatinhd/data/peakcounts/plot.py
index 0df3485e..9cb9c802 100644
--- a/src/chromatinhd/data/peakcounts/plot.py
+++ b/src/chromatinhd/data/peakcounts/plot.py
@@ -104,6 +104,10 @@ def extract_peaks(peaks_bed, promoter, peakcaller):
         peaks = peaks.rename(columns={"name": "cluster"})
         peaks["cluster"] = peaks["cluster"].astype(int)
 
+    if peakcaller == "rolling_500":
+        peaks["start"] = peaks["start"] + 250
+        peaks["end"] = peaks["end"] + 250
+
     if len(peaks) > 0:
         peaks["peak"] = peaks["chrom"] + ":" + peaks["start"].astype(str) + "-" + peaks["end"].astype(str)
         peaks = center_peaks(peaks, promoter)
@@ -337,7 +341,7 @@ def _plot_peaks(ax, plotdata, y, lw=0.5, fc="#555"):
             peak["end"] - peak["start"],
             1,
             fc=fc,
-            lw=0,
+            lw=0.5,
         )
         ax.add_patch(rect)
         # ax.plot([peak["start"]] * 2, [y, y + 1], color="grey", lw=lw)
diff --git a/src/chromatinhd/grid/__init__.py b/src/chromatinhd/grid/__init__.py
index 9b3ea42f..f49c7f49 100644
--- a/src/chromatinhd/grid/__init__.py
+++ b/src/chromatinhd/grid/__init__.py
@@ -1,2 +1,2 @@
 from .grid import Grid, Figure, Panel, Wrap, Ax
-from .broken import Broken, BrokenGrid
+from .broken import Broken, BrokenGrid, Breaking
diff --git a/src/chromatinhd/grid/broken.py b/src/chromatinhd/grid/broken.py
index 1798eefc..e293b7e7 100644
--- a/src/chromatinhd/grid/broken.py
+++ b/src/chromatinhd/grid/broken.py
@@ -7,7 +7,7 @@
 @dataclasses.dataclass
 class Breaking:
     regions: pd.DataFrame
-    gap: int
+    gap: float = 0.1
     resolution: int = 5000
 
     @property
@@ -29,7 +29,11 @@ def __init__(self, breaking, height=0.5, margin_height=0.0, *args, **kwargs):
         regions["ix"] = np.arange(len(regions))
 
         for i, (region, region_info) in enumerate(regions.iterrows()):
-            subpanel_width = region_info["width"] / breaking.resolution
+            if "resolution" in region_info.index:
+                resolution = region_info["resolution"]
+            else:
+                resolution = breaking.resolution
+            subpanel_width = region_info["width"] / resolution
             panel, ax = self.add_right(
                 Panel((subpanel_width, height + 1e-4)),
             )
diff --git a/src/chromatinhd/grid/grid.py b/src/chromatinhd/grid/grid.py
index 812c2815..b1c4d58b 100644
--- a/src/chromatinhd/grid/grid.py
+++ b/src/chromatinhd/grid/grid.py
@@ -59,6 +59,10 @@ def height(self):
             h += AXIS_HEIGHT
         return h
 
+    @height.setter
+    def height(self, value):
+        self.dim = (self.dim[0], value)
+
     @property
     def width(self):
         w = self.dim[0]
@@ -66,6 +70,10 @@ def width(self):
             w += AXIS_WIDTH
         return w
 
+    @width.setter
+    def width(self, value):
+        self.dim = (value, self.dim[1])
+
     def align(self):
         pass
 
@@ -218,9 +226,7 @@ def __getitem__(self, key):
 
     def get_bottom_left_corner(self):
         nrow = (len(self.elements) - 1) // self.ncol
-        print(len(self.elements))
         ix = (nrow) * self.ncol
-        print(nrow, self.ncol, ix)
         return self.elements[ix]
 
         # return self.elements[self.ncol * ((len(self.elements) % self.ncol) - 1)]
diff --git a/src/chromatinhd/models/pred/interpret/censorers.py b/src/chromatinhd/models/pred/interpret/censorers.py
index 9680574f..e4bd19e7 100644
--- a/src/chromatinhd/models/pred/interpret/censorers.py
+++ b/src/chromatinhd/models/pred/interpret/censorers.py
@@ -61,7 +61,7 @@ def __call__(self, data):
 
 
 class MultiWindowCensorer:
-    def __init__(self, window, window_sizes=(50, 100, 200, 500), relative_stride=0.5):
+    def __init__(self, window, window_sizes=(25, 50, 100, 200, 500), relative_stride=0.5):
         design = [{"window": "control"}]
         for window_size in window_sizes:
             cuts = np.arange(*window, step=int(window_size * relative_stride))
diff --git a/src/chromatinhd/models/pred/interpret/regionmultiwindow.py b/src/chromatinhd/models/pred/interpret/regionmultiwindow.py
index dc762e19..f1dff735 100644
--- a/src/chromatinhd/models/pred/interpret/regionmultiwindow.py
+++ b/src/chromatinhd/models/pred/interpret/regionmultiwindow.py
@@ -60,8 +60,6 @@ def create(
 
         regions = fragments.regions.var.index
 
-        print(len(folds))
-
         coords_pointed = {
             regions.name: regions,
             "fold": pd.Index(range(len(folds)), name="fold"),
@@ -305,8 +303,6 @@ def _interpolate(self, region):
         effects = np.stack(effects)
         losts = np.stack(losts)
 
-        print(deltacors.min())
-
        scores_statistical = []
         for i in range(deltacors.shape[1]):
             if deltacors.shape[0] > 1:
diff --git a/src/chromatinhd/models/pred/model/__init__.py b/src/chromatinhd/models/pred/model/__init__.py
index 103fdfb3..9afde205 100644
--- a/src/chromatinhd/models/pred/model/__init__.py
+++ b/src/chromatinhd/models/pred/model/__init__.py
@@ -1,2 +1,3 @@
 from . import additive
 from . import nonadditive
+from . import better as multiscale
diff --git a/src/chromatinhd/models/pred/model/better.py b/src/chromatinhd/models/pred/model/better.py
index 1a34f0ae..def91e85 100644
--- a/src/chromatinhd/models/pred/model/better.py
+++ b/src/chromatinhd/models/pred/model/better.py
@@ -288,16 +288,10 @@ class LibrarySizeEncoder(torch.nn.Module):
     def __init__(self, fragments, n_layers=1, scale=1.0):
         super().__init__()
 
-        # library_size = np.bincount(fragments.mapping[:, 0], minlength=fragments.n_cells)
-        # self.register_buffer(
-        #     "differential_library_size",
-        #     torch.from_numpy((library_size - library_size.mean()) / library_size.std()).float() * scale,
-        # )
         self.scale = scale
         self.n_embedding_dimensions = 1
 
     def forward(self, data):
-        # return self.differential_library_size[data.minibatch.cells_oi].reshape(-1, 1)
         return data.fragments.libsize.reshape(-1, 1) * self.scale
 
 
@@ -321,8 +315,7 @@ def __init__(
         **kwargs,
     ):
         self.n_input_embedding_dimensions = n_input_embedding_dimensions
-        # self.n_embedding_dimensions = n_embedding_dimensions
-        self.n_embedding_dimensions = n_input_embedding_dimensions  # required for residual layers, sorry..
+        self.n_embedding_dimensions = n_input_embedding_dimensions
         self.residual = residual
 
         super().__init__()
@@ -382,31 +375,6 @@ def forward(self, cell_region_embedding):
         return embedding.reshape(cell_region_embedding.shape[:-1])
 
 
-class catchtime(object):
-    def __init__(self, dict, name):
-        self.name = name
-        self.dict = dict
-
-    def __enter__(self):
-        self.t = time.time()
-        return self
-
-    def __exit__(self, type, value, traceback):
-        self.t = time.time() - self.t
-        self.dict[self.name] += self.t
-
-
-import collections
-
-
-class timer(object):
-    def __init__(self):
-        self.times = collections.defaultdict(float)
-
-    def catch(self, name):
-        return catchtime(self.times, name)
-
-
 class Model(FlowModel):
     """
     Predicting region expression from raw fragments using an additive model across fragments from the same cell
@@ -447,27 +415,27 @@ def __init__(
         transcriptome: Transcriptome | None = None,
         fold=None,
         dummy: bool = False,
-        n_frequencies: int = 50,
+        n_frequencies: tuple = (1000, 500, 250, 125, 63, 31),
         reduce: str = "sum",
-        nonlinear: bool = True,
-        n_embedding_dimensions: int = 10,
+        nonlinear: str = "silu",
+        n_embedding_dimensions: int = 100,
         embedding_to_expression_initialization: str = "default",
         dropout_rate_fragment_embedder: float = 0.0,
         n_layers_fragment_embedder=1,
-        residual_fragment_embedder=False,
+        residual_fragment_embedder=True,
         batchnorm_fragment_embedder=False,
         layernorm_fragment_embedder=False,
-        n_layers_embedding2expression=1,
+        n_layers_embedding2expression=5,
         dropout_rate_embedding2expression: float = 0.0,
-        residual_embedding2expression=False,
+        residual_embedding2expression=True,
         batchnorm_embedding2expression=False,
-        layernorm_embedding2expression=False,
+        layernorm_embedding2expression=True,
         layer=None,
         reset=False,
         encoder=None,
         pooler=None,
-        distance_encoder=None,
-        library_size_encoder=None,
+        distance_encoder="direct",
+        library_size_encoder="linear",
         library_size_encoder_kwargs=None,
         region_oi=None,
         encoder_kwargs=None,
@@ -574,81 +542,55 @@ def forward_region_loss(self, data):
         # return region_pairzmse_loss(expression_predicted, expression_true)
 
     def forward_multiple(self, data, fragments_oi, min_fragments=1):
-        timing = timer()
+        fragment_embedding = self.fragment_embedder(data)
 
-        with timing.catch("prep"):
-            fragment_embedding = self.fragment_embedder(data)
+        total_n_fragments = torch.bincount(
+            data.fragments.local_cellxregion_ix,
+            minlength=data.minibatch.n_regions * data.minibatch.n_cells,
+        ).reshape((data.minibatch.n_cells, data.minibatch.n_regions))
 
-            total_n_fragments = torch.bincount(
-                data.fragments.local_cellxregion_ix,
-                minlength=data.minibatch.n_regions * data.minibatch.n_cells,
-            ).reshape((data.minibatch.n_cells, data.minibatch.n_regions))
+        total_cell_region_embedding = self.embedding_region_pooler.forward(
+            fragment_embedding,
+            data.fragments.local_cellxregion_ix,
+            data.minibatch.n_cells,
+            data.minibatch.n_regions,
+        )
+        cell_region_embedding = total_cell_region_embedding
 
-            total_cell_region_embedding = self.embedding_region_pooler.forward(
-                fragment_embedding,
-                data.fragments.local_cellxregion_ix,
-                data.minibatch.n_cells,
-                data.minibatch.n_regions,
+        if hasattr(self, "library_size_encoder"):
+            cell_region_embedding = torch.cat(
+                [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)], dim=-1
            )
-            cell_region_embedding = total_cell_region_embedding
 
-            if hasattr(self, "library_size_encoder"):
-                cell_region_embedding = torch.cat(
-                    [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)], dim=-1
+        total_expression_predicted = self.embedding_to_expression.forward(cell_region_embedding)
+
+        for fragments_oi_ in fragments_oi:
+            if (fragments_oi_ is not None) and ((~fragments_oi_).sum().item() > min_fragments):
+                lost_fragments_oi = ~fragments_oi_
+                lost_local_cellxregion_ix = data.fragments.local_cellxregion_ix[lost_fragments_oi]
+                n_fragments = total_n_fragments - torch.bincount(
+                    lost_local_cellxregion_ix,
+                    minlength=data.minibatch.n_regions * data.minibatch.n_cells,
+                ).reshape((data.minibatch.n_cells, data.minibatch.n_regions))
+
+                cell_region_embedding = total_cell_region_embedding - self.embedding_region_pooler.forward(
+                    fragment_embedding[lost_fragments_oi],
+                    lost_local_cellxregion_ix,
+                    data.minibatch.n_cells,
+                    data.minibatch.n_regions,
                 )
-            total_expression_predicted = self.embedding_to_expression.forward(cell_region_embedding)
-
-        # with timing.catch("selecting_fragments"):
-        #     for fragments_oi_ in fragments_oi:
-        #         pass
-
-        with timing.catch("main_loop"):
-            for fragments_oi_ in fragments_oi:
-                with timing.catch("within"):
-                    if (fragments_oi_ is not None) and ((~fragments_oi_).sum().item() > min_fragments):
-                        with timing.catch("lost"):
-                            with timing.catch("selector"):
-                                lost_fragments_oi = ~fragments_oi_
-                                lost_local_cellxregion_ix = data.fragments.local_cellxregion_ix[lost_fragments_oi]
-                            with timing.catch("counter"):
-                                n_fragments = total_n_fragments - torch.bincount(
-                                    lost_local_cellxregion_ix,
-                                    minlength=data.minibatch.n_regions * data.minibatch.n_cells,
-                                ).reshape((data.minibatch.n_cells, data.minibatch.n_regions))
-
-                            with timing.catch("pooler"):
-                                cell_region_embedding = (
-                                    total_cell_region_embedding
-                                    - self.embedding_region_pooler.forward(
-                                        fragment_embedding[lost_fragments_oi],
-                                        lost_local_cellxregion_ix,
-                                        data.minibatch.n_cells,
-                                        data.minibatch.n_regions,
-                                    )
-                                )
-
-                            with timing.catch("libsize"):
-                                if hasattr(self, "library_size_encoder"):
-                                    cell_region_embedding = torch.cat(
-                                        [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)], dim=-1
-                                    )
-
-                            with timing.catch("embedding_to_expression"):
-                                # changed = torch.where((n_fragments != total_n_fragments).any(1))[0]
-                                # print(expression_predicted.shape)
-                                expression_predicted = self.embedding_to_expression.forward(cell_region_embedding)
-                                # expression_predicted = total_expression_predicted.clone()
-                                # expression_predicted[changed] = self.embedding_to_expression.forward(cell_region_embedding[changed])
-                    else:
-                        n_fragments = total_n_fragments
-                        expression_predicted = total_expression_predicted
+                if hasattr(self, "library_size_encoder"):
+                    cell_region_embedding = torch.cat(
+                        [cell_region_embedding, self.library_size_encoder(data).unsqueeze(-2)], dim=-1
+                    )
+
+                expression_predicted = self.embedding_to_expression.forward(cell_region_embedding)
+            else:
+                n_fragments = total_n_fragments
+                expression_predicted = total_expression_predicted
 
-                with timing.catch("yield"):
-                    yield expression_predicted, n_fragments
+            yield expression_predicted, n_fragments
 
-        print(timing.times)
-
     def forward_multiple2(self, data, fragments_oi, min_fragments=1):
         fragment_embedding = self.fragment_embedder(data)
@@ -734,7 +676,7 @@ def train_model(
         pbar=True,
         n_regions_step=1,
         n_cells_step=20000,
-        weight_decay=1e-2,
+        weight_decay=1e-1,
         checkpoint_every_epoch=1,
         optimizer="adam",
         n_cells_train=None,
@@ -846,7 +788,6 @@ def train_model(
         self.trace = trainer.trace
 
         trainer.train()
-        # trainer.trace.plot()
 
     def get_prediction(
         self,
@@ -1039,15 +980,11 @@
                 pred_mb,
                 n_fragments_oi_mb,
             ) in enumerate(self.forward_multiple(data, fragments_oi, min_fragments=min_fragments)):
-                # predicted.append(pred_mb.cpu().numpy())
-                # n_fragments.append(n_fragments_oi_mb.cpu().numpy())
                 predicted.append(pred_mb)
                 n_fragments.append(n_fragments_oi_mb)
             expected = data.transcriptome.value.cpu().numpy()
 
         self.to("cpu")
 
-        # predicted = np.stack(predicted, axis=0)
-        # n_fragments = np.stack(n_fragments, axis=0)
         predicted = torch.stack(predicted, axis=0).cpu().numpy()
         n_fragments = torch.stack(n_fragments, axis=0).cpu().numpy()
@@ -1168,13 +1105,18 @@ def models_path(self):
         path.mkdir(exist_ok=True)
         return path
 
-    def train_models(self, device=None, pbar=True, regions_oi=None, **kwargs):
+    def train_models(
+        self, device=None, pbar=True, transcriptome=None, fragments=None, folds=None, regions_oi=None, **kwargs
+    ):
         if "device" in self.train_params and device is None:
             device = self.train_params["device"]
 
-        fragments = self.fragments
-        transcriptome = self.transcriptome
-        folds = self.folds
+        if fragments is None:
+            fragments = self.fragments
+        if transcriptome is None:
+            transcriptome = self.transcriptome
+        if folds is None:
+            folds = self.folds
 
         if regions_oi is None:
             if self.regions_oi is None:
@@ -1234,22 +1176,35 @@ def get_region_cors(self, fragments, transcriptome, folds, device=None):
         cor_n_fragments = np.zeros((len(fragments.var.index), len(folds)))
         n_fragments = np.zeros((len(fragments.var.index), len(folds)))
 
+        regions_oi = fragments.var.index if self.regions_oi is None else self.regions_oi
+
+        from itertools import product
+
+        cors = []
+
         if device is None:
             device = get_default_device()
 
-        for model_ix, (model, fold) in enumerate(zip(self, folds)):
-            prediction = model.get_prediction(fragments, transcriptome, cell_ixs=fold["cells_test"], device=device)
-
-            cor_predicted[:, model_ix] = paircor(prediction["predicted"].values, prediction["expected"].values)
-            cor_n_fragments[:, model_ix] = paircor(prediction["n_fragments"].values, prediction["expected"].values)
+        for region_id, (fold_ix, fold) in product(regions_oi, enumerate(folds)):
+            if region_id + "_" + str(fold_ix) in self:
+                model = self[region_id + "_" + str(fold_ix)]
+                prediction = model.get_prediction(fragments, transcriptome, cell_ixs=fold["cells_test"], device=device)
+
+                cors.append(
+                    {
+                        fragments.var.index.name: region_id,
+                        "cor": np.corrcoef(prediction["predicted"].values[:, 0], prediction["expected"].values[:, 0])[
+                            0, 1
+                        ],
+                        "cor_n_fragments": np.corrcoef(
+                            prediction["n_fragments"].values[:, 0], prediction["expected"].values[:, 0]
+                        )[0, 1],
+                    }
+                )
 
-            n_fragments[:, model_ix] = prediction["n_fragments"].values.sum(0)
-
-        cor_predicted = pd.Series(cor_predicted.mean(1), index=fragments.var.index, name="cor_predicted")
-        cor_n_fragments = pd.Series(cor_n_fragments.mean(1), index=fragments.var.index, name="cor_n_fragments")
-        n_fragments = pd.Series(n_fragments.mean(1), index=fragments.var.index, name="n_fragments")
-        result = pd.concat([cor_predicted, cor_n_fragments, n_fragments], axis=1)
-        result["deltacor"] = result["cor_predicted"] - result["cor_n_fragments"]
+        cors = pd.DataFrame(cors).set_index(fragments.var.index.name)
+        cors["deltacor"] = cors["cor"] - cors["cor_n_fragments"]
 
-        return result
+        return cors
 
     @property
     def design(self):
diff --git a/src/chromatinhd/models/pred/plot/copredictivity.py b/src/chromatinhd/models/pred/plot/copredictivity.py
index f4d93c1d..26d94024 100644
--- a/src/chromatinhd/models/pred/plot/copredictivity.py
+++ b/src/chromatinhd/models/pred/plot/copredictivity.py
@@ -3,6 +3,7 @@
 import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
+import itertools
 
 
 class Copredictivity(chromatinhd.grid.Panel):
@@ -49,3 +50,43 @@ def from_regionpairwindow(cls, regionpairwindow, gene, width):
         """
         plotdata = regionpairwindow.get_plotdata(gene).reset_index()
         return cls(plotdata, width)
+
+
+class CopredictivityBroken(chromatinhd.grid.Panel):
+    """
+    Plot co-predictivity for different regions
+    """
+
+    def __init__(self, plotdata, breaking):
+        pass
+
+    @classmethod
+    def from_regionpairwindow(cls, regionpairwindow, gene, breaking):
+        plotdata_windows = regionpairwindow.scores[gene].mean("fold").to_dataframe()
+        plotdata_interaction = regionpairwindow.interaction[gene].mean("fold").to_pandas().unstack().to_frame("cor")
+
+        plotdata = plotdata_interaction.copy()
+
+        # make plotdata, making sure we have all window combinations, otherwise nan
+        plotdata = (
+            pd.DataFrame(itertools.combinations(windows.index, 2), columns=["window1", "window2"])
+            .set_index(["window1", "window2"])
+            .join(plotdata_interaction)
+        )
+        plotdata.loc[np.isnan(plotdata["cor"]), "cor"] = 0.0
+        plotdata["dist"] = (
+            windows.loc[plotdata.index.get_level_values("window2"), "window_mid"].values
+            - windows.loc[plotdata.index.get_level_values("window1"), "window_mid"].values
+        )
+
+        transform = chd.grid.broken.TransformBroken(breaking)
+        plotdata["window1_broken"] = transform(
+            windows.loc[plotdata.index.get_level_values("window1"), "window_mid"].values
+        )
+        plotdata["window2_broken"] = transform(
+            windows.loc[plotdata.index.get_level_values("window2"), "window_mid"].values
+        )
+
+        plotdata = plotdata.loc[~pd.isnull(plotdata["window1_broken"]) & ~pd.isnull(plotdata["window2_broken"])]
+
+        plotdata.loc[plotdata["dist"] < 1000, "cor"] = 0.0
diff --git a/src/chromatinhd/plot/genome/genes.py b/src/chromatinhd/plot/genome/genes.py
index 7e52591b..ff3ba680 100644
--- a/src/chromatinhd/plot/genome/genes.py
+++ b/src/chromatinhd/plot/genome/genes.py
@@ -82,6 +82,7 @@ def __init__(
         width,
         full_ticks=False,
         label_genome=False,
+        annotate_tss=True,
         symbol=None,
     ):
         super().__init__((width, len(plotdata_genes) * 0.08 + 0.01))
@@ -204,7 +205,8 @@ def __init__(
             ax.add_patch(rect)
 
         # vline at tss
-        ax.axvline(0, color="#888888", lw=0.5, zorder=-1, dashes=(2, 2))
+        if annotate_tss:
+            ax.axvline(0, color="#888888", lw=0.5, zorder=-1, dashes=(2, 2))
 
     @classmethod
     def from_region(cls, region, genome="GRCh38", window=None, use_cache=True, show_genes=True, **kwargs):
diff --git a/src/chromatinhd/utils/timing.py b/src/chromatinhd/utils/timing.py
new file mode 100644
index 00000000..beff5ec8
--- /dev/null
+++ b/src/chromatinhd/utils/timing.py
@@ -0,0 +1,24 @@
+import collections
+import time
+
+
+class catchtime(object):
+    def __init__(self, dict, name):
+        self.name = name
+        self.dict = dict
+
+    def __enter__(self):
+        self.t = time.time()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.t = time.time() - self.t
+        self.dict[self.name] += self.t
+
+
+class timer(object):
+    def __init__(self):
+        self.times = collections.defaultdict(float)
+
+    def catch(self, name):
+        return catchtime(self.times, name)
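
Note on the relocated timing helpers: the ad-hoc catchtime/timer classes removed from better.py above now live in the new src/chromatinhd/utils/timing.py. A minimal usage sketch, assuming the module is importable as chromatinhd.utils.timing; the section names ("embed", "pool") are illustrative only:

    from chromatinhd.utils.timing import timer

    timing = timer()

    # catch() returns a context manager that accumulates elapsed
    # wall-clock seconds under the given name
    with timing.catch("embed"):
        sum(i * i for i in range(100_000))

    with timing.catch("pool"):
        sum(i * i for i in range(100_000))

    print(dict(timing.times))  # e.g. {'embed': 0.01, 'pool': 0.01}

Because timer.times is a collections.defaultdict(float), repeated sections with the same name are summed rather than overwritten, which is what allows per-section totals to be accumulated across minibatches.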