Skip to content

Commit

Permalink
updates to chd
Browse files Browse the repository at this point in the history
  • Loading branch information
zouter committed Jan 2, 2024
1 parent fe7a8ef commit e7124d5
Show file tree
Hide file tree
Showing 11 changed files with 218 additions and 79 deletions.
1 change: 1 addition & 0 deletions src/chromatinhd/biomart/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .dataset import Dataset
from .tss import get_canonical_transcripts, get_exons, get_transcripts, map_symbols
from . import tss
from .homology import get_orthologs

__all__ = ["Dataset", "get_canonical_transcripts", "get_exons", "get_transcripts", "tss"]
24 changes: 24 additions & 0 deletions src/chromatinhd/biomart/homology.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from .dataset import Dataset
import numpy as np


def get_orthologs(biomart_dataset: Dataset, gene_ids, organism="mmusculus"):
"""
Map ensembl gene ids to orthologs in another organism
"""

gene_ids_to_map = np.unique(gene_ids)
mapping = biomart_dataset.get_batched(
[
biomart_dataset.attribute("ensembl_gene_id"),
biomart_dataset.attribute("external_gene_name"),
biomart_dataset.attribute(f"{organism}_homolog_ensembl_gene"),
biomart_dataset.attribute(f"{organism}_homolog_associated_gene_name"),
],
filters=[
biomart_dataset.filter("ensembl_gene_id", value=gene_ids_to_map),
],
)
mapping = mapping.groupby("ensembl_gene_id").first()

return mapping[f"{organism}_homolog_ensembl_gene"].reindex(gene_ids).values
7 changes: 7 additions & 0 deletions src/chromatinhd/data/motifscan/motifscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ def get_slice(
return_indptr=False,
return_scores=True,
return_strands=True,
motif_ixs=None,
):
"""
Get a slice of the motifscan
Expand Down Expand Up @@ -450,11 +451,17 @@ def get_slice(

coordinates = self.coordinates[indptr_start:indptr_end]
indices = self.indices[indptr_start:indptr_end]

out = [coordinates, indices]
if return_scores:
out.append(self.scores[indptr_start:indptr_end])
if return_strands:
out.append(self.strands[indptr_start:indptr_end])

if motif_ixs is not None:
selection = np.isin(indices, motif_ixs)
out = [x[selection] for x in out]

if return_indptr:
out.append(indptr)

Expand Down
4 changes: 3 additions & 1 deletion src/chromatinhd/data/motifscan/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ def __init__(self, motifscan, gene, motifs_oi, breaking, group_info=None, panel_
motifs_oi, group_info, motifdata = _process_grouped_motifs(gene, motifs_oi, motifscan, group_info=group_info)

for group, group_info_oi in group_info.iterrows():
broken = self.add_under(Broken(breaking, height=panel_height), padding=0)
broken = self.add_under(
Broken(breaking, height=panel_height, margin_height=0.0, padding_height=0.01), padding=0
)
group_motifs = motifs_oi.query("group == @group")

panel, ax = broken[0, -1]
Expand Down
12 changes: 9 additions & 3 deletions src/chromatinhd/data/motifscan/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def get_slice(
return_scores=True,
return_strands=True,
return_indptr=False,
motif_ixs=None,
) -> tuple:
"""
Get the positions/scores/strandedness of motifs within a slice of the motifscan
Expand Down Expand Up @@ -226,15 +227,20 @@ def get_slice(

indptr_start, indptr_end = indptr[0], indptr[-1]

out = []
positions = (self.parent.coordinates[indptr_start:indptr_end] - region["tss"]) * region["strand"]
indices = self.parent.indices[indptr_start:indptr_end]

out.append((self.parent.coordinates[indptr_start:indptr_end] - region["tss"]) * region["strand"])
out.append(self.parent.indices[indptr_start:indptr_end])
out = [positions, indices]

if return_scores:
out.append(self.parent.scores[indptr_start:indptr_end])
if return_strands:
out.append(self.parent.strands[indptr_start:indptr_end])

if motif_ixs is not None:
selection = np.isin(indices, motif_ixs)
out = [x[selection] for x in out]

if return_indptr:
out.append(indptr - indptr_start)

Expand Down
150 changes: 81 additions & 69 deletions src/chromatinhd/data/peakcounts/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import pathlib


def center_peaks(peaks, promoter):
def center_peaks(peaks, promoter, columns=["start", "end"]):
peaks = peaks.copy()
if peaks.shape[0] == 0:
peaks = pd.DataFrame(columns=["start", "end"])
else:
peaks[["start", "end"]] = [
peaks[columns] = [
[
(peak["start"] - promoter["tss"]) * int(promoter["strand"]),
(peak["end"] - promoter["tss"]) * int(promoter["strand"]),
Expand All @@ -23,12 +23,12 @@ def center_peaks(peaks, promoter):
return peaks


def uncenter_peaks(peaks, promoter):
def uncenter_peaks(peaks, promoter, columns=["start", "end"]):
peaks = peaks.copy()
if peaks.shape[0] == 0:
peaks = pd.DataFrame(columns=["start", "end"])
else:
peaks[["start", "end"]] = [
peaks[columns] = [
[
(peak["start"] * int(promoter["strand"]) + promoter["tss"]),
(peak["end"] * int(promoter["strand"]) + promoter["tss"]),
Expand All @@ -38,6 +38,26 @@ def uncenter_peaks(peaks, promoter):
return peaks


def uncenter_multiple_peaks(slices, coordinates):
if "region_ix" not in slices.columns:
slices["region_ix"] = coordinates.index.get_indexer(slices["region"])
coordinates_oi = coordinates.iloc[slices["region_ix"]].copy()

slices["chrom"] = coordinates_oi["chrom"].values

slices["start_genome"] = np.where(
coordinates_oi["strand"] == 1,
(slices["start"] * coordinates_oi["strand"].astype(int).values + coordinates_oi["tss"].values),
(slices["end"] * coordinates_oi["strand"].astype(int).values + coordinates_oi["tss"].values),
)
slices["end_genome"] = np.where(
coordinates_oi["strand"] == 1,
(slices["end"] * coordinates_oi["strand"].astype(int).values + coordinates_oi["tss"].values),
(slices["start"] * coordinates_oi["strand"].astype(int).values + coordinates_oi["tss"].values),
)
return slices


def get_usecols_and_names(peakcaller):
if peakcaller in ["macs2_leiden_0.1"]:
usecols = [0, 1, 2, 6]
Expand Down Expand Up @@ -86,32 +106,7 @@ def __init__(
for peakcaller, peaks_peakcaller in peaks.groupby("peakcaller"):
y = peakcallers.index.get_loc(peakcaller)

if len(peaks_peakcaller) == 0:
continue
if ("cluster" not in peaks_peakcaller.columns) or pd.isnull(peaks_peakcaller["cluster"]).all():
for _, peak in peaks_peakcaller.iterrows():
rect = mpl.patches.Rectangle(
(peak["start"], y),
peak["end"] - peak["start"],
1,
fc="#333",
lw=0,
)
ax.add_patch(rect)
ax.plot([peak["start"]] * 2, [y, y + 1], color="grey", lw=0.5)
ax.plot([peak["end"]] * 2, [y, y + 1], color="grey", lw=0.5)
else:
n_clusters = peaks_peakcaller["cluster"].max() + 1
h = 1 / n_clusters
for _, peak in peaks_peakcaller.iterrows():
rect = mpl.patches.Rectangle(
(peak["start"], y + peak["cluster"] / n_clusters),
peak["end"] - peak["start"],
h,
fc="#333",
lw=0,
)
ax.add_patch(rect)
_plot_peaks(ax, peaks_peakcaller, y)
if y > 0:
ax.axhline(y, color="#DDD", zorder=10, lw=0.5)

Expand Down Expand Up @@ -174,9 +169,9 @@ def __init__(
super().__init__(breaking, height=row_height * len(peakcallers) / 5)

# y axis
ax = self.elements[0, -1]
panel, ax = self[0, -1]

ax.set_ylim(len(peakcallers), 0)
# label methods
if label_methods:
ax.set_yticks(np.arange(len(peakcallers)) + 0.5)
ax.set_yticks(np.arange(len(peakcallers) + 1), minor=True)
Expand All @@ -188,12 +183,6 @@ def __init__(
else:
ax.set_yticks([])

if label_rows is True:
ax.set_ylabel("Putative\nCREs", rotation=0, ha="right", va="center")
elif label_rows is not False:
ax.set_ylabel(label_rows, rotation=0, ha="right", va="center")
else:
ax.set_ylabel("")
ax.tick_params(
axis="y",
which="major",
Expand All @@ -202,6 +191,7 @@ def __init__(
right=label_methods_side == "right",
left=not label_methods_side == "left",
)

ax.tick_params(
axis="y",
which="minor",
Expand All @@ -212,39 +202,32 @@ def __init__(
)
ax.yaxis.tick_right()

ax.set_xticks([])
# label y
panel, ax = self[0, 0]
if label_rows is True:
ax.set_ylabel("Putative\nCREs", rotation=0, ha="right", va="center")
elif label_rows is not False:
ax.set_ylabel(label_rows, rotation=0, ha="right", va="center")
else:
ax.set_ylabel("")

# plot peaks
# set ylim for each panel
for (region, region_info), (panel, ax) in zip(breaking.regions.iterrows(), self):
for peakcaller, peaks_peakcaller in peaks.groupby("peakcaller"):
y = peakcallers.index.get_loc(peakcaller)

if len(peaks_peakcaller) == 0:
continue
if ("cluster" not in peaks_peakcaller.columns) or pd.isnull(peaks_peakcaller["cluster"]).all():
for _, peak in peaks_peakcaller.iterrows():
rect = mpl.patches.Rectangle(
(peak["start"], y),
peak["end"] - peak["start"],
1,
fc="#333",
lw=0,
)
ax.add_patch(rect)
ax.plot([peak["start"]] * 2, [y, y + 1], color="grey", lw=0.5)
ax.plot([peak["end"]] * 2, [y, y + 1], color="grey", lw=0.5)
else:
n_clusters = peaks_peakcaller["cluster"].max() + 1
h = 1 / n_clusters
for _, peak in peaks_peakcaller.iterrows():
rect = mpl.patches.Rectangle(
(peak["start"], y + peak["cluster"] / n_clusters),
peak["end"] - peak["start"],
h,
fc="#333",
lw=0,
)
ax.add_patch(rect)
ax.set_xticks([])
ax.set_ylim(len(peakcallers), 0)

# plot peaks
for peakcaller, peaks_peakcaller in peaks.groupby("peakcaller"):
y = peakcallers.index.get_loc(peakcaller)
for (region, region_info), (panel, ax) in zip(breaking.regions.iterrows(), self):
plotdata = peaks_peakcaller.loc[
~(
(peaks_peakcaller["start"] > region_info["end"])
| (peaks_peakcaller["end"] < region_info["start"])
)
]

_plot_peaks(ax, plotdata, y)
if y > 0:
ax.axhline(y, color="#DDD", zorder=10, lw=0.5)

Expand Down Expand Up @@ -275,3 +258,32 @@ def _get_peaks(region, peakcallers):
peaks = pd.concat(peaks).reset_index().set_index(["peakcaller", "peak"])
peaks["size"] = peaks["end"] - peaks["start"]
return peaks


def _plot_peaks(ax, plotdata, y):
if len(plotdata) == 0:
return
if ("cluster" not in plotdata.columns) or pd.isnull(plotdata["cluster"]).all():
for _, peak in plotdata.iterrows():
rect = mpl.patches.Rectangle(
(peak["start"], y),
peak["end"] - peak["start"],
1,
fc="#333",
lw=0,
)
ax.add_patch(rect)
ax.plot([peak["start"]] * 2, [y, y + 1], color="grey", lw=0.5)
ax.plot([peak["end"]] * 2, [y, y + 1], color="grey", lw=0.5)
else:
n_clusters = plotdata["cluster"].max() + 1
h = 1 / n_clusters
for _, peak in plotdata.iterrows():
rect = mpl.patches.Rectangle(
(peak["start"], y + peak["cluster"] / n_clusters),
peak["end"] - peak["start"],
h,
fc="#333",
lw=0,
)
ax.add_patch(rect)
7 changes: 6 additions & 1 deletion src/chromatinhd/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,12 @@ def __getitem__(self, key):
return list(self.elements)[key]

def get_bottom_left_corner(self):
return self.elements[self.ncol * ((len(self.elements) % self.ncol) - 1)]
nrow = (len(self.elements) - 1) // self.ncol
print(len(self.elements))
ix = (nrow) * self.ncol
print(nrow, self.ncol, ix)
return self.elements[ix]
# return self.elements[self.ncol * ((len(self.elements) % self.ncol) - 1)]


class WrapAutobreak(Wrap):
Expand Down
Loading

0 comments on commit e7124d5

Please sign in to comment.