Skip to content

Commit

Permalink
remove torch_scatter dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
zouter committed Oct 24, 2024
1 parent 5361f3e commit 0408663
Show file tree
Hide file tree
Showing 46 changed files with 1,624 additions and 5,193 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,7 @@ docs/tex/overview.fls
docs/tex/overview.pdf
docs/tex/overview.synctex.gz

output/*
output/*


src/*.c
10 changes: 5 additions & 5 deletions docs/source/benchmark/diff/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,24 +251,24 @@
model_info["ix"] = np.arange(model_info.shape[0])

# %%
fig = chd.grid.Figure(chd.grid.Wrap(padding_width=0.1))
fig = polyptich.grid.Figure(polyptich.grid.Wrap(padding_width=0.1))
height = len(scores) * 0.2

plotdata = scores.copy().loc[model_info.index]

panel, ax = fig.main.add(chd.grid.Ax((1, height)))
panel, ax = fig.main.add(polyptich.grid.Panel((1, height)))
ax.barh(plotdata.index, plotdata["lr_test"])
ax.axvline(0, color="black", linestyle="--", lw=1)
ax.set_title("Test")
ax.set_xlabel("Log-likelihood ratio")

panel, ax = fig.main.add(chd.grid.Ax((1, height)))
panel, ax = fig.main.add(polyptich.grid.Panel((1, height)))
ax.set_yticks([])
ax.barh(plotdata.index, plotdata["lr_validation"])
ax.axvline(0, color="black", linestyle="--", lw=1)
ax.set_title("Validation")

panel, ax = fig.main.add(chd.grid.Ax((1, height)))
panel, ax = fig.main.add(polyptich.grid.Panel((1, height)))
ax.set_yticks([])
ax.barh(plotdata.index, plotdata["lr_train"])
ax.axvline(0, color="black", linestyle="--", lw=1)
Expand Down Expand Up @@ -304,7 +304,7 @@
genepositional.score(fragments, clustering, [models[model_id]], force=True, genes=transcriptome.gene_id([symbol]))

# %%
fig = chd.grid.Figure(chd.grid.Grid(padding_height=0.05, padding_width=0.05))
fig = polyptich.grid.Figure(polyptich.grid.Grid(padding_height=0.05, padding_width=0.05))
width = 10

region = fragments.regions.coordinates.loc[transcriptome.gene_id(symbol)]
Expand Down
22 changes: 11 additions & 11 deletions docs/source/benchmark/diff/simulated.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,10 @@ def apply_trunc_normal(value, loc, scale, a=0, b=1):
model_info["ix"] = np.arange(model_info.shape[0])

# %%
fig = chd.grid.Figure(chd.grid.Wrap(padding_width=0.1))
fig = polyptich.grid.Figure(polyptich.grid.Wrap(padding_width=0.1))
height = len(model_info) * 0.2

panel, ax = fig.main.add(chd.grid.Ax((1, height)))
panel, ax = fig.main.add(polyptich.grid.Panel((1, height)))
plotdata = scores.xs("test", level="phase").loc[model_info.index]

ax.barh(plotdata.index, plotdata["lr"])
Expand All @@ -245,15 +245,15 @@ def apply_trunc_normal(value, loc, scale, a=0, b=1):
ax.set_title("Test")
ax.set_xlabel("Log-likelihood ratio")

panel, ax = fig.main.add(chd.grid.Ax((1, height)))
panel, ax = fig.main.add(polyptich.grid.Panel((1, height)))
plotdata = scores.xs("validation", level="phase").loc[model_info.index]
ax.set_yticks([])
ax.barh(plotdata.index, plotdata["lr"])
ax.barh(plotdata.index, plotdata["lr_position"], alpha=0.5)
ax.axvline(0, color="black", linestyle="--", lw=1)
ax.set_title("Validation")

panel, ax = fig.main.add(chd.grid.Ax((1, height)))
panel, ax = fig.main.add(polyptich.grid.Panel((1, height)))
plotdata = scores.xs("train", level="phase").loc[model_info.index]
ax.set_yticks([])
ax.barh(plotdata.index, plotdata["lr"])
Expand Down Expand Up @@ -296,9 +296,9 @@ def apply_trunc_normal(value, loc, scale, a=0, b=1):
)

# %%
fig = chd.grid.Figure(chd.grid.Grid(padding_height=0))
fig = polyptich.grid.Figure(polyptich.grid.Grid(padding_height=0))
width = 10
panel, ax = fig.main.add_under(chd.grid.Panel((width, 0.5)))
panel, ax = fig.main.add_under(polyptich.grid.Panel((width, 0.5)))
plotdata = design.loc[design["size"] == 0].set_index(["cluster_ix", "coordinate"])[["prob_left"]]
for cluster_ix, plotdata_cluster in plotdata.groupby("cluster_ix"):
plotdata_cluster = plotdata_cluster.droplevel("cluster_ix").sort_index()
Expand All @@ -312,7 +312,7 @@ def apply_trunc_normal(value, loc, scale, a=0, b=1):
ax2.scatter(plotdata["center"], [0] * len(plotdata), c=plotdata["size_mean"])
ax2.set_xlim(*fragments.regions.window)

panel, ax = fig.main.add_under(chd.grid.Panel((width, 2)))
panel, ax = fig.main.add_under(polyptich.grid.Panel((width, 2)))
plotdata = np.exp(design.groupby(["size", "coordinate"]).mean()["prob_right"].unstack())
ax.matshow(plotdata, aspect="auto", extent=(*fragments.regions.window, *plotdata.index[[-1, 0]]), cmap="viridis")
ax.set_xticks([])
Expand All @@ -324,14 +324,14 @@ def apply_trunc_normal(value, loc, scale, a=0, b=1):
fig.plot()

# %%
main = chd.grid.Grid(padding_height=0.1)
fig = chd.grid.Figure(main)
main = polyptich.grid.Grid(padding_height=0.1)
fig = polyptich.grid.Figure(main)

nbins = np.array(model.mixture.transform.nbins)
bincuts = np.concatenate([[0], np.cumsum(nbins)])
binmids = bincuts[:-1] + nbins / 2

ax = main[0, 0] = chd.grid.Ax((10, 0.25))
ax = main[0, 0] = polyptich.grid.Panel((10, 0.25))
ax = ax.ax
plotdata = (model.mixture.transform.unnormalized_heights.data.cpu().numpy())[[gene_ix]]
ax.imshow(plotdata, aspect="auto")
Expand All @@ -342,7 +342,7 @@ def apply_trunc_normal(value, loc, scale, a=0, b=1):
ax.set_xticks([])
ax.set_ylabel("$h_0$", rotation=0, ha="right", va="center")

ax = main[1, 0] = chd.grid.Ax(dim=(10, model.n_clusters * 0.25))
ax = main[1, 0] = polyptich.grid.Panel(dim=(10, model.n_clusters * 0.25))
ax = ax.ax
plotdata = model.decoder.delta_height_weight.data[gene_ix].cpu().numpy()
ax.imshow(plotdata, aspect="auto", cmap=mpl.cm.RdBu_r, vmax=np.log(2), vmin=np.log(1 / 2))
Expand Down
6 changes: 3 additions & 3 deletions docs/source/benchmark/diff/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,20 @@
scores.loc[device, "inference"] = end - start

# %%
fig = chd.grid.Figure(chd.grid.Wrap(padding_width=0.1))
fig = polyptich.grid.Figure(polyptich.grid.Wrap(padding_width=0.1))
height = len(scores) * 0.2

plotdata = scores.copy().loc[devices.index]

panel, ax = fig.main.add(chd.grid.Ax((1, height)))
panel, ax = fig.main.add(polyptich.grid.Panel((1, height)))
ax.barh(plotdata.index, plotdata["train"])
ax.set_yticks(np.arange(len(devices)))
ax.set_yticklabels(devices.label)
ax.axvline(0, color="black", linestyle="--", lw=1)
ax.set_title("Training")
ax.set_xlabel("seconds")

panel, ax = fig.main.add(chd.grid.Ax((1, height)))
panel, ax = fig.main.add(polyptich.grid.Panel((1, height)))
ax.barh(plotdata.index, plotdata["inference"])
ax.axvline(0, color="black", linestyle="--", lw=1)
ax.set_title("Inference")
Expand Down
6 changes: 1 addition & 5 deletions docs/source/quickstart/1_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,6 @@
folds = chd.data.folds.Folds(dataset_folder / "folds" / "5x1").sample_cells(fragments, 5, 1)
folds

# %%
folds = chd.data.folds.Folds(dataset_folder / "folds" / "5x5").sample_cells(fragments, 5, 5)
folds

# %% [markdown]
# ## Optional data

Expand Down Expand Up @@ -215,7 +211,7 @@
# %%
import genomepy

genomepy.install_genome("GRCh38", genomes_dir="/data/genome/")
genomepy.install_genome("GRCh38", genomes_dir="/srv/data/genome/")

# Keep this path under the genomes_dir given to genomepy.install_genome,
# which places the FASTA at <genomes_dir>/<name>/<name>.fa.
fasta_file = "/srv/data/genome/GRCh38/GRCh38.fa"

Expand Down
10 changes: 5 additions & 5 deletions docs/source/quickstart/2_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@
# %%
symbol = "IRF1"

fig = chd.grid.Figure(chd.grid.Grid(padding_height=0.05))
fig = polyptich.grid.Figure(polyptich.grid.Grid(padding_height=0.05))
width = 10

region = fragments.regions.coordinates.loc[transcriptome.gene_id(symbol)]
Expand Down Expand Up @@ -166,10 +166,10 @@
# %%
# decrease the lost_cutoff to see more regions
regions = regionmultiwindow.select_regions(gene_id, lost_cutoff = 0.15)
breaking = chd.grid.Breaking(regions)
breaking = polyptich.grid.Breaking(regions)

# %%
fig = chd.grid.Figure(chd.grid.Grid(padding_height=0.05))
fig = polyptich.grid.Figure(polyptich.grid.Grid(padding_height=0.05))

region = fragments.regions.coordinates.loc[transcriptome.gene_id(symbol)]
panel_genes = chd.plot.genome.genes.GenesBroken.from_region(region, breaking=breaking, genome = "GRCh38")
Expand Down Expand Up @@ -208,10 +208,10 @@

# %%
windows = regionmultiwindow.select_regions(gene_id, lost_cutoff = 0.2)
breaking = chd.grid.Breaking(windows)
breaking = polyptich.grid.Breaking(windows)

# %%
fig = chd.grid.Figure(chd.grid.Grid(padding_height=0.05))
fig = polyptich.grid.Figure(polyptich.grid.Grid(padding_height=0.05))
width = 10

# genes
Expand Down
6 changes: 3 additions & 3 deletions docs/source/quickstart/3_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
window = [-10000, 10000]

# %%
fig = chd.grid.Figure(chd.grid.Grid(padding_height=0.05, padding_width=0.05))
fig = polyptich.grid.Figure(polyptich.grid.Grid(padding_height=0.05, padding_width=0.05))
width = 10

region = fragments.regions.coordinates.loc[transcriptome.gene_id(symbol)]
Expand All @@ -115,10 +115,10 @@

# %%
windows = regionpositional.select_windows(gene_id)
breaking = chd.grid.Breaking(windows)
breaking = polyptich.grid.Breaking(windows)

# %%
fig = chd.grid.Figure(chd.grid.Grid(padding_height=0.05, padding_width=0.05))
fig = polyptich.grid.Figure(polyptich.grid.Grid(padding_height=0.05, padding_width=0.05))
width = 10

region = fragments.regions.coordinates.loc[transcriptome.gene_id(symbol)]
Expand Down
8 changes: 1 addition & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,10 @@ disable = [
]

[tool.ruff]
line-length = 500
ignore-init-module-imports = true
ignore = ['F401']
line-length = 90
include = ['src/**/*.py']
exclude = ['scripts/*']

[tool.black]
line-length = 120
target-version = ['py37', 'py38']


[tool.jupytext]
formats = "ipynb,py:percent"
4 changes: 2 additions & 2 deletions src/chromatinhd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from . import sparse
from . import utils
from . import flow
from . import grid
from . import plot
from . import data
from . import train
from . import embedding
from . import optim
from . import biomart
from . import models
from polyptich import grid

__all__ = [
"get_git_root",
Expand All @@ -26,8 +26,8 @@
"train",
"embedding",
"optim",
"grid",
"biomart",
"models",
"plot",
"grid",
]
4 changes: 2 additions & 2 deletions src/chromatinhd/biomart/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def attribute(self, name, **kwargs):
def filter(self, name, **kwargs):
# Convenience factory: build a Filter from `name`, forwarding any extra keyword arguments unchanged.
return Filter(name, **kwargs)

def get(self, attributes=[], filters=[], use_cache=True) -> pd.DataFrame:
def get(self, attributes=[], filters=[], use_cache=True, timeout = 20) -> pd.DataFrame:
"""
Get the result with a given set of attributes and filters
Expand Down Expand Up @@ -176,7 +176,7 @@ def get(self, attributes=[], filters=[], use_cache=True) -> pd.DataFrame:
result = cache[url]
else:
try:
response = requests.get(url, timeout=10)
response = requests.get(url, timeout=timeout)
except requests.exceptions.Timeout:
raise ValueError("Ensembl web service timed out")
# check response status
Expand Down
11 changes: 8 additions & 3 deletions src/chromatinhd/data/associations/plot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from chromatinhd.grid.broken import Broken, Panel
from polyptich.grid.broken import Broken, Panel
import pandas as pd
import adjustText
import seaborn as sns

import matplotlib as mpl

def center_position(peaks, region):
peaks = peaks.copy()
Expand Down Expand Up @@ -136,7 +136,7 @@ def get_snp_color(snp_main):
)

for i, (index, row) in enumerate(plotdata_region.iterrows()):
ax.text(
text = ax.text(
row["pos"],
0.8,
row["rsid"],
Expand All @@ -145,6 +145,11 @@ def get_snp_color(snp_main):
va="bottom",
color=colors[i],
)
text.set_path_effects(
[
mpl.patheffects.withStroke(linewidth=2, foreground="white"),
]
)
# ax.scatter(
# plotdata_region["position"],
# [1] * len(plotdata_region),
Expand Down
Loading

0 comments on commit 0408663

Please sign in to comment.