Skip to content

Commit

Permalink
Add novel method to predict modality task (#339)
Browse files Browse the repository at this point in the history
* WIP integral copy

* WIP update

* split method  in train and run comp

* fix errors

* update train model

* refactor method comp

* rename method_predict and method_train api files

* refactor novel components (WIP)

* WIP refactor novel

* refactor train_test split novel method

* Add scritpt to test all test files

* Make dim a variable

* add hvg

* Remove unused code

* update train image

* Update predict part

* Update novel wf

* Add novel to benchmark

* update directives

* Update test scripts

* update train output

* fix config predict

* Update run config

* add submission info

* update lib ref

* Add batches to train data

* update subworkflow

* set output defaults

* reorder helper functions

* update directives

* fix directives

* set to hightime

* fix run config

* remove views

* Fix config

* Add pref norm

* Update directives

* Update predict config

* Remove setState

* Apply suggestion

Co-authored-by: Robrecht Cannoodt <[email protected]>

* add back setstate

* Add fix for test

* prevent divide by zero

* Add workaround for nextflow error

* Fix if variable is empty

* Update dataset_name for neurips 2021

---------

Co-authored-by: Robrecht Cannoodt <[email protected]>
  • Loading branch information
KaiWaldrant and rcannood authored Aug 26, 2024
1 parent d0bc03f commit 7cea6a5
Show file tree
Hide file tree
Showing 17 changed files with 698 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ param_list:
input: "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE194nnn/GSE194122/suppl/GSE194122%5Fopenproblems%5Fneurips2021%5Fcite%5FBMMC%5Fprocessed%2Eh5ad%2Egz"
mod1: GEX
mod2: ADT
dataset_name: OpenProblems NeurIPS2021 CITE-Seq
dataset_name: NeurIPS2021 CITE-Seq
dataset_organism: homo_sapiens
dataset_summary: Single-cell CITE-Seq (GEX+ADT) data collected from bone marrow mononuclear cells of 12 healthy human donors.
dataset_description: "Single-cell CITE-Seq data collected from bone marrow mononuclear cells of 12 healthy human donors using the 10X 3 prime Single-Cell Gene Expression kit with Feature Barcoding in combination with the BioLegend TotalSeq B Universal Human Panel v1.0. The dataset was generated to support Multimodal Single-Cell Data Integration Challenge at NeurIPS 2021. Samples were prepared using a standard protocol at four sites. The resulting data was then annotated to identify cell types and remove doublets. The dataset was designed with a nested batch layout such that some donor samples were measured at multiple sites with some donors measured at a single site."
Expand All @@ -19,7 +19,7 @@ param_list:
input: "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE194nnn/GSE194122/suppl/GSE194122%5Fopenproblems%5Fneurips2021%5Fmultiome%5FBMMC%5Fprocessed%2Eh5ad%2Egz"
mod1: GEX
mod2: ATAC
dataset_name: OpenProblems NeurIPS2021 Multiome
dataset_name: NeurIPS2021 Multiome
dataset_organism: homo_sapiens
dataset_summary: Single-cell Multiome (GEX+ATAC) data collected from bone marrow mononuclear cells of 12 healthy human donors.
dataset_description: "Single-cell CITE-Seq data collected from bone marrow mononuclear cells of 12 healthy human donors using the 10X Multiome Gene Expression and Chromatin Accessibility kit. The dataset was generated to support Multimodal Single-Cell Data Integration Challenge at NeurIPS 2021. Samples were prepared using a standard protocol at four sites. The resulting data was then annotated to identify cell types and remove doublets. The dataset was designed with a nested batch layout such that some donor samples were measured at multiple sites with some donors measured at a single site."
Expand All @@ -35,21 +35,12 @@ output_state: '$id/state.yaml'
publish_dir: s3://openproblems-data/resources/datasets
HERE

cat > /tmp/nextflow.config << HERE
process {
withName:'.*publishStatesProc' {
memory = '16GB'
disk = '100GB'
}
}
HERE

tw launch https://github.com/openproblems-bio/openproblems-v2.git \
--revision main_build \
--pull-latest \
--main-script target/nextflow/datasets/workflows/process_openproblems_neurips2021_bmmc/main.nf \
--workspace 53907369739130 \
--compute-env 6TeIFgV5OY4pJCk8I0bfOh \
--params-file "$params_file" \
--config /tmp/nextflow.config \
--labels openproblems_neurips2021_bmmc,dataset_loader \
--config src/wf_utils/labels_tw.config \
--labels neurips2021,dataset_loader \
4 changes: 2 additions & 2 deletions src/tasks/predict_modality/api/comp_method_predict.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ functionality:
- name: "--input_train_mod1"
__merge__: file_train_mod1.yaml
direction: input
required: true
required: false
- name: "--input_train_mod2"
__merge__: file_train_mod2.yaml
direction: input
required: true
required: false
- name: "--input_test_mod1"
__merge__: file_test_mod1.yaml
direction: input
Expand Down
10 changes: 10 additions & 0 deletions src/tasks/predict_modality/api/file_common_dataset_mod1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ info:
name: hvg_score
description: A score for the feature indicating how highly variable it is.
required: true

- type: boolean
name: hvg
description: Whether or not the feature is considered to be a 'highly variable gene'
required: true

- type: double
name: hvg_score
description: A ranking of the features by hvg.
required: true
uns:
- type: string
name: dataset_id
Expand Down
10 changes: 10 additions & 0 deletions src/tasks/predict_modality/api/file_common_dataset_mod2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ info:
name: hvg_score
description: A score for the feature indicating how highly variable it is.
required: true

- type: boolean
name: hvg
description: Whether or not the feature is considered to be a 'highly variable gene'
required: true

- type: double
name: hvg_score
description: A ranking of the features by hvg.
required: true
uns:
- type: string
name: dataset_id
Expand Down
247 changes: 247 additions & 0 deletions src/tasks/predict_modality/methods/novel/helper_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
import torch

from torch import nn
import torch.nn.functional as F

from torch.utils.data import Dataset

from typing import Optional

import anndata
import numpy as np
import pandas as pd
import scipy.sparse
import sklearn.decomposition
import sklearn.feature_extraction.text
import sklearn.preprocessing
import sklearn.neighbors
import sklearn.utils.extmath

class tfidfTransformer():
def __init__(self):
self.idf = None
self.fitted = False

def fit(self, X):
self.idf = X.shape[0] / X.sum(axis=0)
self.fitted = True

def transform(self, X):
if not self.fitted:
raise RuntimeError('Transformer was not fitted on any data')
if scipy.sparse.issparse(X):
tf = X.multiply(1 / X.sum(axis=1))
return tf.multiply(self.idf)
else:
tf = X / X.sum(axis=1, keepdims=True)
return tf * self.idf

def fit_transform(self, X):
self.fit(X)
return self.transform(X)

class lsiTransformer():
def __init__(self,
n_components: int = 20,
use_highly_variable = None
):
self.n_components = n_components
self.use_highly_variable = use_highly_variable
self.tfidfTransformer = tfidfTransformer()
self.normalizer = sklearn.preprocessing.Normalizer(norm="l1")
self.pcaTransformer = sklearn.decomposition.TruncatedSVD(n_components = self.n_components, random_state=777)
# self.lsi_mean = None
# self.lsi_std = None
self.fitted = None

def fit(self, adata: anndata.AnnData):
if self.use_highly_variable is None:
self.use_highly_variable = "hvg" in adata.var
adata_use = adata[:, adata.var["hvg"]] if self.use_highly_variable else adata
X = self.tfidfTransformer.fit_transform(adata_use.X)
X_norm = self.normalizer.fit_transform(X)
X_norm = np.log1p(X_norm * 1e4)
X_lsi = self.pcaTransformer.fit_transform(X_norm)
# self.lsi_mean = X_lsi.mean(axis=1, keepdims=True)
# self.lsi_std = X_lsi.std(axis=1, ddof=1, keepdims=True)
self.fitted = True

def transform(self, adata):
if not self.fitted:
raise RuntimeError('Transformer was not fitted on any data')
adata_use = adata[:, adata.var["hvg"]] if self.use_highly_variable else adata
X = self.tfidfTransformer.transform(adata_use.X)
X_norm = self.normalizer.transform(X)
X_norm = np.log1p(X_norm * 1e4)
X_lsi = self.pcaTransformer.transform(X_norm)
X_lsi -= X_lsi.mean(axis=1, keepdims=True)
X_lsi /= X_lsi.std(axis=1, ddof=1, keepdims=True)
lsi_df = pd.DataFrame(X_lsi, index = adata_use.obs_names)
return lsi_df

def fit_transform(self, adata):
self.fit(adata)
return self.transform(adata)

class ModalityMatchingDataset(Dataset):
def __init__(
self, df_modality1, df_modality2, is_train=True
):
super().__init__()
self.df_modality1 = df_modality1
self.df_modality2 = df_modality2
self.is_train = is_train
def __len__(self):
return self.df_modality1.shape[0]

def __getitem__(self, index: int):
if self.is_train == True:
x = self.df_modality1.iloc[index].values
y = self.df_modality2.iloc[index].values
return x, y
else:
x = self.df_modality1.iloc[index].values
return x

class Swish(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
result = i * sigmoid(i)
ctx.save_for_backward(i)
return result
@staticmethod
def backward(ctx, grad_output):
i = ctx.saved_variables[0]
sigmoid_i = sigmoid(i)
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))

class Swish_module(nn.Module):
def forward(self, x):
return Swish.apply(x)

sigmoid = torch.nn.Sigmoid()

class ModelRegressionGex2Atac(nn.Module):
def __init__(self, dim_mod1, dim_mod2):
super(ModelRegressionGex2Atac, self).__init__()
#self.bn = torch.nn.BatchNorm1d(1024)
self.input_ = nn.Linear(dim_mod1, 1024)
self.fc = nn.Linear(1024, 256)
self.fc1 = nn.Linear(256, 2048)
self.dropout1 = nn.Dropout(p=0.298885630228993)
self.dropout2 = nn.Dropout(p=0.11289717442776658)
self.dropout3 = nn.Dropout(p=0.13523634924414762)
self.output = nn.Linear(2048, dim_mod2)
def forward(self, x):
x = F.gelu(self.input_(x))
x = self.dropout1(x)
x = F.gelu(self.fc(x))
x = self.dropout2(x)
x = F.gelu(self.fc1(x))
x = self.dropout3(x)
x = F.gelu(self.output(x))
return x

class ModelRegressionAtac2Gex(nn.Module): #
def __init__(self, dim_mod1, dim_mod2):
super(ModelRegressionAtac2Gex, self).__init__()
self.input_ = nn.Linear(dim_mod1, 2048)
self.fc = nn.Linear(2048, 2048)
self.fc1 = nn.Linear(2048, 512)
self.dropout1 = nn.Dropout(p=0.2649138776004753)
self.dropout2 = nn.Dropout(p=0.1769628308148758)
self.dropout3 = nn.Dropout(p=0.2516791883012817)
self.output = nn.Linear(512, dim_mod2)
def forward(self, x):
x = F.gelu(self.input_(x))
x = self.dropout1(x)
x = F.gelu(self.fc(x))
x = self.dropout2(x)
x = F.gelu(self.fc1(x))
x = self.dropout3(x)
x = F.gelu(self.output(x))
return x

class ModelRegressionAdt2Gex(nn.Module):
def __init__(self, dim_mod1, dim_mod2):
super(ModelRegressionAdt2Gex, self).__init__()
self.input_ = nn.Linear(dim_mod1, 512)
self.dropout1 = nn.Dropout(p=0.0)
self.swish = Swish_module()
self.fc = nn.Linear(512, 512)
self.fc1 = nn.Linear(512, 512)
self.fc2 = nn.Linear(512, 512)
self.output = nn.Linear(512, dim_mod2)
def forward(self, x):
x = F.gelu(self.input_(x))
x = F.gelu(self.fc(x))
x = F.gelu(self.fc1(x))
x = F.gelu(self.fc2(x))
x = F.gelu(self.output(x))
return x

class ModelRegressionGex2Adt(nn.Module):
def __init__(self, dim_mod1, dim_mod2):
super(ModelRegressionGex2Adt, self).__init__()
self.input_ = nn.Linear(dim_mod1, 512)
self.dropout1 = nn.Dropout(p=0.20335661386636347)
self.dropout2 = nn.Dropout(p=0.15395289261127876)
self.dropout3 = nn.Dropout(p=0.16902655078832815)
self.fc = nn.Linear(512, 512)
self.fc1 = nn.Linear(512, 2048)
self.output = nn.Linear(2048, dim_mod2)
def forward(self, x):
# x = self.batchswap_noise(x)
x = F.gelu(self.input_(x))
x = self.dropout1(x)
x = F.gelu(self.fc(x))
x = self.dropout2(x)
x = F.gelu(self.fc1(x))
x = self.dropout3(x)
x = F.gelu(self.output(x))
return x

def rmse(y, y_pred):
return np.sqrt(np.mean(np.square(y - y_pred)))

def train_and_valid(model, optimizer, loss_fn, dataloader_train, dataloader_test, name_model, device):
best_score = 100000
for i in range(100):
train_losses = []
test_losses = []
model.train()

for x, y in dataloader_train:
optimizer.zero_grad()
output = model(x.float().to(device))
loss = torch.sqrt(loss_fn(output, y.float().to(device)))
loss.backward()
train_losses.append(loss.item())
optimizer.step()

model.eval()
with torch.no_grad():
for x, y in dataloader_test:
output = model(x.float().to(device))
output[output<0] = 0.0
loss = torch.sqrt(loss_fn(output, y.float().to(device)))
test_losses.append(loss.item())

outputs = []
targets = []
model.eval()
with torch.no_grad():
for x, y in dataloader_test:
output = model(x.float().to(device))

outputs.append(output.detach().cpu().numpy())
targets.append(y.float().detach().cpu().numpy())
cat_outputs = np.concatenate(outputs)
cat_targets = np.concatenate(targets)
cat_outputs[cat_outputs<0.0] = 0

if best_score > rmse(cat_targets,cat_outputs):
torch.save(model.state_dict(), name_model)
best_score = rmse(cat_targets,cat_outputs)
print("best rmse: ", best_score)

25 changes: 25 additions & 0 deletions src/tasks/predict_modality/methods/novel/predict/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
__merge__: ../../../api/comp_method_predict.yaml
functionality:
name: novel_predict
arguments:
- name: "--input_transform"
type: file
direction: input
required: false
example: "lsi_transformer.pickle"
resources:
- type: python_script
path: script.py
- path: ../helper_functions.py
platforms:
- type: docker
image: ghcr.io/openproblems-bio/base_pytorch_nvidia:1.0.4
setup:
- type: python
packages:
- scikit-learn
- networkx
- type: nextflow
directives:
label: [highmem, hightime, midcpu, highsharedmem, gpu]

8 changes: 8 additions & 0 deletions src/tasks/predict_modality/methods/novel/predict/run_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/bin/bash

viash run src/tasks/predict_modality/methods/novel/predict/config.vsh.yaml -- \
--input_train_mod2 'resources/predict_modality/datasets/openproblems_neurips2021/bmmc_cite/normal/log_cp10k/train_mod2.h5ad' \
--input_test_mod1 'resources/predict_modality/datasets/openproblems_neurips2021/bmmc_cite/normal/log_cp10k/test_mod1.h5ad' \
--input_model output/novel/model.pt \
--input_transform output/novel/lsi_transform.pickle \
--output 'output/novel/novel_test.h5ad'
Loading

0 comments on commit 7cea6a5

Please sign in to comment.