Skip to content

Commit

Permalink
Simple MLP attempt #2 (#441)
Browse files Browse the repository at this point in the history
* make gitignore more strict

* update api files

Co-authored-by: Kai Waldrant <[email protected]>

* create separate file for pretrained model

* add train component

* clean up resources

* wip components

* wip refactor

* fix issues with training component

* make input_train optional in prediction methods

* clean up predict method

* fix wf

* clean up train

* add helper test script

* update configs

* add to wf

* always store ymean.npy

* add shmsize to simplemlp

* Update src/tasks/predict_modality/methods/simple_mlp/train/script.py

* bigger shm

* Add nextflow workaround

* lower cpu label

---------

Co-authored-by: Kai Waldrant <[email protected]>
  • Loading branch information
rcannood and KaiWaldrant authored Aug 26, 2024
1 parent 7cea6a5 commit 41fc027
Show file tree
Hide file tree
Showing 16 changed files with 580 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/tasks/predict_modality/api/task_info.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,15 @@ authors:
roles: [ contributor ]
info:
email: [email protected]
github: nonztalk
github: nonztalk
- name: Xueer Chen
roles: [ contributor ]
info:
github: xuerchen
email: [email protected]
- name: Jiwei Liu
roles: [ contributor ]
info:
github: daxiongshu
email: [email protected]
orcid: "0000-0002-8799-9763"
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
__merge__: ../../../api/comp_method_predict.yaml
functionality:
name: simplemlp_predict
resources:
- type: python_script
path: script.py
- path: ../resources/
platforms:
- type: docker
# image: pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime
image: ghcr.io/openproblems-bio/base_pytorch_nvidia:1.0.4
# run_args: ["--gpus all --ipc=host"]
setup:
- type: python
pypi:
- scikit-learn
- scanpy
- pytorch-lightning
- type: nextflow
directives:
label: [highmem, hightime, midcpu, gpu, highsharedmem]
104 changes: 104 additions & 0 deletions src/tasks/predict_modality/methods/simple_mlp/predict/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from glob import glob
import sys
import numpy as np
from scipy.sparse import csc_matrix
import anndata as ad
import torch
from torch.utils.data import TensorDataset,DataLoader

## VIASH START
par = {
'input_train_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/train_mod1.h5ad',
'input_train_mod2': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/train_mod2.h5ad',
'input_test_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/test_mod1.h5ad',
'input_model': 'output/model',
'output': 'output/prediction'
}
meta = {
'resources_dir': 'src/tasks/predict_modality/methods/simple_mlp',
'cpus': 10
}
## VIASH END

resources_dir = f"{meta['resources_dir']}/resources"
sys.path.append(resources_dir)
from models import MLP
import utils

def _predict(model,dl):
model = model.cuda()
model.eval()
yps = []
for x in dl:
with torch.no_grad():
yp = model(x[0].cuda())
yps.append(yp.detach().cpu().numpy())
yp = np.vstack(yps)
return yp


print('Load data', flush=True)
input_train_mod2 = ad.read_h5ad(par['input_train_mod2'])
input_test_mod1 = ad.read_h5ad(par['input_test_mod1'])

# determine variables
mod_1 = input_test_mod1.uns['modality']
mod_2 = input_train_mod2.uns['modality']

task = f'{mod_1}2{mod_2}'

print('Load ymean', flush=True)
ymean_path = f"{par['input_model']}/{task}_ymean.npy"
ymean = np.load(ymean_path)

print('Start predict', flush=True)
if task == 'GEX2ATAC':
y_pred = ymean*np.ones([input_test_mod1.n_obs, input_test_mod1.n_vars])
else:
folds = [0, 1, 2]

ymean = torch.from_numpy(ymean).float()
yaml_path=f"{resources_dir}/yaml/mlp_{task}.yaml"
config = utils.load_yaml(yaml_path)
X = input_test_mod1.layers["normalized"].toarray()
X = torch.from_numpy(X).float()

te_ds = TensorDataset(X)

yp = 0
for fold in folds:
# load_path = f"{par['input_model']}/{task}_fold_{fold}/version_0/checkpoints/*"
load_path = f"{par['input_model']}/{task}_fold_{fold}/**.ckpt"
print(load_path)
ckpt = glob(load_path)[0]
model_inf = MLP.load_from_checkpoint(
ckpt,
in_dim=X.shape[1],
out_dim=input_test_mod1.n_vars,
ymean=ymean,
config=config
)
te_loader = DataLoader(
te_ds,
batch_size=config.batch_size,
num_workers=0,
shuffle=False,
drop_last=False
)
yp = yp + _predict(model_inf, te_loader)

y_pred = yp/len(folds)

y_pred = csc_matrix(y_pred)

adata = ad.AnnData(
layers={"normalized": y_pred},
shape=y_pred.shape,
uns={
'dataset_id': input_test_mod1.uns['dataset_id'],
'method_id': meta['functionality_name'],
},
)

print('Write data', flush=True)
adata.write_h5ad(par['output'], compression = "gzip")
68 changes: 68 additions & 0 deletions src/tasks/predict_modality/methods/simple_mlp/resources/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import torch
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F

class MLP(pl.LightningModule):
def __init__(self,in_dim,out_dim,ymean,config):
super(MLP, self).__init__()
self.ymean = ymean.cuda()
H1 = config.H1
H2 = config.H2
p = config.dropout
self.config = config
self.fc1 = nn.Linear(in_dim, H1)
self.fc2 = nn.Linear(H1,H2)
self.fc3 = nn.Linear(H1+H2, out_dim)
self.dp2 = nn.Dropout(p=p)

def forward(self, x):
x0 = x
x1 = F.relu(self.fc1(x))
x1 = self.dp2(x1)
x = F.relu(self.fc2(x1))
x = torch.cat([x,x1],dim=1)
x = self.fc3(x)
x = self.apply_mask(x)
return x

def apply_mask(self,yp):
tmp = torch.ones_like(yp).float()*self.ymean
mask = tmp<self.config.threshold
mask = mask.float()
return yp*(1-mask) + tmp*mask

def training_step(self, batch, batch_nb):
x,y = batch
yp = self(x)
criterion = nn.MSELoss()
loss = criterion(yp, y)
self.log('train_loss', loss, prog_bar=True)
return loss

def validation_step(self, batch, batch_idx):
x,y = batch
yp = self(x)
criterion = nn.MSELoss()
loss = criterion(yp, y)
self.log('valid_RMSE', loss**0.5, prog_bar=True)
return loss

def predict_step(self, batch, batch_idx):
if len(batch) == 2:
x,_ = batch
else:
x = batch
return self(x)

def configure_optimizers(self):
lr = self.config.lr
wd = float(self.config.wd)
adam = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=wd)
if self.config.lr_schedule == 'adam':
return adam
elif self.config.lr_schedule == 'adam_cosin':
slr = torch.optim.lr_scheduler.CosineAnnealingLR(adam, self.config.epochs)
return [adam], [slr]
else:
assert 0
37 changes: 37 additions & 0 deletions src/tasks/predict_modality/methods/simple_mlp/resources/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import yaml
from collections import namedtuple


def to_site_donor(data):
df = data.obs['batch'].copy().to_frame().reset_index()
df.columns = ['index','batch']
df['site'] = df['batch'].apply(lambda x: x[:2])
df['donor'] = df['batch'].apply(lambda x: x[2:])
return df


def split(tr1, tr2, fold):
df = to_site_donor(tr1)
mask = df['site'] == f's{fold+1}'
maskr = ~mask

Xt = tr1[mask].layers["normalized"].toarray()
X = tr1[maskr].layers["normalized"].toarray()

yt = tr2[mask].layers["normalized"].toarray()
y = tr2[maskr].layers["normalized"].toarray()

print(f"{X.shape}, {y.shape}, {Xt.shape}, {yt.shape}")

return X,y,Xt,yt


def load_yaml(path):
with open(path) as f:
x = yaml.safe_load(f)
res = {}
for i in x:
res[i] = x[i]['value']
config = namedtuple('Config', res.keys())(**res)
print(config)
return config
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# sample config defaults file
epochs:
desc: Number of epochs to train over
value: 10
batch_size:
desc: Size of each mini-batch
value: 512
H1:
desc: Number of hidden neurons in 1st layer of MLP
value: 256
H2:
desc: Number of hidden neurons in 2nd layer of MLP
value: 128
dropout:
desc: probs of zeroing values
value: 0
lr:
desc: learning rate
value: 0.001
wd:
desc: weight decay
value: 1e-5
threshold:
desc: threshold to set values to zero
value: 0
lr_schedule:
desc: learning rate scheduler
value: adam
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# sample config defaults file
epochs:
desc: Number of epochs to train over
value: 10
batch_size:
desc: Size of each mini-batch
value: 512
H1:
desc: Number of hidden neurons in 1st layer of MLP
value: 256
H2:
desc: Number of hidden neurons in 2nd layer of MLP
value: 128
dropout:
desc: probs of zeroing values
value: 0.5
lr:
desc: learning rate
value: 0.001
wd:
desc: weight decay
value: 1e-5
threshold:
desc: threshold to set values to zero
value: 0
lr_schedule:
desc: learning rate scheduler
value: adam
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# sample config defaults file
epochs:
desc: Number of epochs to train over
value: 10
batch_size:
desc: Size of each mini-batch
value: 512
H1:
desc: Number of hidden neurons in 1st layer of MLP
value: 1024
H2:
desc: Number of hidden neurons in 2nd layer of MLP
value: 512
dropout:
desc: probs of zeroing values
value: 0
lr:
desc: learning rate
value: 0.001
wd:
desc: weight decay
value: 1e-5
threshold:
desc: threshold to set values to zero
value: 0.05
lr_schedule:
desc: learning rate scheduler
value: adam_cosin
26 changes: 26 additions & 0 deletions src/tasks/predict_modality/methods/simple_mlp/run/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
__merge__: ../../../api/comp_method_train.yaml
functionality:
name: simplemlp
info:
label: Simple MLP
summary: Ensemble of MLPs trained on different sites (team AXX)
description: |
This folder contains the AXX solution to the OpenProblems-NeurIPS2021 Single-Cell Multimodal Data Integration.
Team took the 4th place of the modality prediction task in terms of overall ranking of 4 subtasks: namely GEX
to ADT, ADT to GEX, GEX to ATAC and ATAC to GEX. Specifically, our methods ranked 3rd in GEX to ATAC and 4th
in GEX to ADT. More details about the task can be found in the
[competition webpage](https://openproblems.bio/events/2021-09_neurips/documentation/about_tasks/task1_modality_prediction).
documentation_url: https://github.com/openproblems-bio/neurips2021_multimodal_topmethods/tree/main/src/predict_modality/methods/AXX
repository_url: https://github.com/openproblems-bio/neurips2021_multimodal_topmethods/tree/main/src/predict_modality/methods/AXX
reference: lance2022multimodal
preferred_normalization: log_cp10k
competition_submission_id: 170812
resources:
- path: main.nf
type: nextflow_script
entrypoint: run_wf
dependencies:
- name: predict_modality/methods/simplemlp_train
- name: predict_modality/methods/simplemlp_predict
platforms:
- type: nextflow
21 changes: 21 additions & 0 deletions src/tasks/predict_modality/methods/simple_mlp/run/main.nf
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
workflow run_wf {
take: input_ch
main:
output_ch = input_ch

| simplemlp_train.run(
fromState: ["input_train_mod1", "input_train_mod2"],
toState: ["input_model": "output"]
)

| simplemlp_predict.run(
fromState: ["input_train_mod2", "input_test_mod1", "input_model", "input_transform"],
toState: ["output": "output"]
)

| map { tup ->
[tup[0], [output: tup[1].output]]
}

emit: output_ch
}
Loading

0 comments on commit 41fc027

Please sign in to comment.