switch to stratified data splits
sfluegel committed Nov 23, 2023
1 parent 80e1748 commit bb2340f
Showing 4 changed files with 100 additions and 65 deletions.
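
For orientation, here is a minimal sketch (not part of the commit) of the splitting approach this change adopts: instead of sklearn's random train_test_split, the label matrix is split with MultilabelStratifiedShuffleSplit from the iterative-stratification package (iterstrat), so that rare label combinations are distributed proportionally across train, validation and test. The toy DataFrame, column layout and ratio below are assumptions chosen to mirror _ChEBIDataExtractor, where label columns start at index 3.

import numpy as np
import pandas as pd
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

# toy stand-in for the raw ChEBI dataset: three metadata columns, then binary label columns
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "id": range(200),
    "name": [f"mol_{i}" for i in range(200)],
    "SMILES": ["C" * (i % 5 + 1) for i in range(200)],
    "label_a": rng.integers(0, 2, 200),
    "label_b": rng.integers(0, 2, 200),
    "label_c": rng.integers(0, 2, 200),
})

train_split = 0.85                      # assumed ratio, analogous to self.train_split
labels = df.iloc[:, 3:].values          # stratify on the label matrix only

# first split: train vs. rest
msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1 - train_split, random_state=0)
train_idx, rest_idx = next(msss.split(labels, labels))
df_train, df_rest = df.iloc[train_idx], df.iloc[rest_idx]

# second split: divide the held-out part again into test and validation
rest_labels = df_rest.iloc[:, 3:].values
msss2 = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1 - train_split, random_state=0)
test_idx, val_idx = next(msss2.split(rest_labels, rest_labels))
df_test, df_validation = df_rest.iloc[test_idx], df_rest.iloc[val_idx]

print(len(df_train), len(df_test), len(df_validation))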
6 changes: 4 additions & 2 deletions chebai/cli.py
@@ -3,10 +3,11 @@
from lightning.pytorch.cli import LightningCLI
from chebai.trainer.InnerCVTrainer import InnerCVTrainer


class ChebaiCLI(LightningCLI):

def __init__(self, *args, **kwargs):
super().__init__(trainer_class = InnerCVTrainer, *args, **kwargs)
super().__init__(trainer_class=InnerCVTrainer, *args, **kwargs)

def add_arguments_to_parser(self, parser):
for kind in ("train", "val", "test"):
@@ -15,7 +16,7 @@ def add_arguments_to_parser(self, parser):
"model.init_args.out_dim",
f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels",
)
#parser.link_arguments('n_splits', 'data.init_args.inner_k_folds') # doesn't work but I don't know why
# parser.link_arguments('n_splits', 'data.init_args.inner_k_folds') # doesn't work but I don't know why

@staticmethod
def subcommands() -> Dict[str, Set[str]]:
@@ -28,5 +29,6 @@ def subcommands() -> Dict[str, Set[str]]:
"cv_fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
}


def cli():
r = ChebaiCLI(save_config_callback=None, parser_kwargs={"parser_mode": "omegaconf"})
15 changes: 1 addition & 14 deletions chebai/models/electra.py
@@ -24,6 +24,7 @@
)
from chebai.preprocessing.reader import MASK_TOKEN_INDEX, CLS_TOKEN
from chebai.preprocessing.datasets.chebi import extract_class_hierarchy
from chebai.loss.pretraining import ElectraPreLoss # noqa
import torch
import csv

@@ -405,17 +406,3 @@ def __call__(self, target, input):
)
return loss

class ElectraPreLoss(torch.nn.Module):
def __init__(self):
super().__init__()
self.ce = torch.nn.CrossEntropyLoss()

def forward(self, input, target, **loss_kwargs):
t, p = input
gen_pred, disc_pred = t
gen_tar, disc_tar = p
gen_loss = self.ce(target=torch.argmax(gen_tar.int(), dim=-1), input=gen_pred)
disc_loss = self.ce(
target=torch.argmax(disc_tar.int(), dim=-1), input=disc_pred
)
return gen_loss + disc_loss
131 changes: 91 additions & 40 deletions chebai/preprocessing/datasets/chebi.py
@@ -15,7 +15,8 @@
import pickle
import random

from sklearn.model_selection import train_test_split
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

import fastobo
import networkx as nx
import pandas as pd
@@ -117,17 +118,19 @@ def __init__(self, chebi_version_train: int = None, **kwargs):
def select_classes(self, g, split_name, *args, **kwargs):
raise NotImplementedError

def save(self, g, split, split_name: str):
def graph_to_raw_dataset(self, g, split_name=None):
"""Preparation step before creating splits, uses graph created by extract_class_hierarchy()
split_name is only relevant, if a separate train_version is set"""
smiles = nx.get_node_attributes(g, "smiles")
names = nx.get_node_attributes(g, "name")

print("build labels")
print(f"Process {split_name}")
print(f"Process graph")

molecules, smiles_list = zip(
*(
(n, smiles)
for n, smiles in ((n, smiles.get(n)) for n in split)
for n, smiles in ((n, smiles.get(n)) for n in smiles.keys())
if smiles
)
)
@@ -142,6 +145,10 @@ def save(self, g, split, split_name: str):
data = pd.DataFrame(data)
data = data[~data["SMILES"].isnull()]
data = data[data.iloc[:, 3:].any(axis=1)]
return data

def save(self, data: pd.DataFrame, split_name: str):

pickle.dump(data, open(os.path.join(self.raw_dir, split_name), "wb"))

@staticmethod
@@ -192,37 +199,75 @@ def setup_processed(self):
self._setup_pruned_test_set()
self.reader.save_token_cache()

def get_splits(self, g):
fixed_nodes = list(g.nodes)
def get_splits(self, df: pd.DataFrame):
print("Split dataset")
random.shuffle(fixed_nodes)

train_split, test_split = train_test_split(
fixed_nodes, train_size=self.train_split, shuffle=True
)
df_list = df.values.tolist()
df_list = [row[3:] for row in df_list]

msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1-self.train_split, random_state=0)

train_split = []
test_split = []
for (train_split, test_split) in msss.split(
df_list, df_list,
):
train_split = train_split
test_split = test_split
break
df_train = df.iloc[train_split]
df_test = df.iloc[test_split]
if self.use_inner_cross_validation:
return train_split, test_split
return df_train, df_test

df_test_list = df_test.values.tolist()
df_test_list = [row[3:] for row in df_test_list]
validation_split = []
test_split = []
msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1-self.train_split, random_state=0)
for (test_split, validation_split) in msss.split(
df_test_list, df_test_list
):
test_split = test_split
validation_split = validation_split
break

test_split, validation_split = train_test_split(
test_split, train_size=self.train_split, shuffle=True
)
return train_split, test_split, validation_split
df_validation = df_test.iloc[validation_split]
df_test = df_test.iloc[test_split]
return df_train, df_test, df_validation

def get_splits_given_test(self, g, test_split):
def get_splits_given_test(self, df: pd.DataFrame, test_df: pd.DataFrame):
""" Use test set from another chebi version the model does not train on, avoid overlap"""
fixed_nodes = list(g.nodes)
print(f"Split dataset for chebi_v{self.chebi_version_train}")
for node in test_split:
if node in fixed_nodes:
fixed_nodes.remove(node)
random.shuffle(fixed_nodes)
df_trainval = df
test_smiles = test_df['SMILES'].tolist()
mask = []
for _, row in df_trainval.iterrows():
if row['SMILES'] in test_smiles:
mask.append(False)
else:
mask.append(True)
df_trainval = df_trainval[mask]

if self.use_inner_cross_validation:
return fixed_nodes
return df_trainval

# assume that size of validation split should relate to train split as in get_splits()
validation_split, train_split = train_test_split(
fixed_nodes, train_size=(1 - self.train_split) ** 2, shuffle=True
)
return train_split, validation_split
msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=(1 - self.train_split) ** 2, random_state=0)

df_trainval_list = df_trainval.values.tolist()
df_trainval_list = [row[3:] for row in df_trainval_list]
train_split = []
validation_split = []
for (train_split, validation_split) in msss.split(
df_trainval_list, df_trainval_list
):
train_split = train_split
validation_split = validation_split

df_validation = df_trainval.iloc[validation_split]
df_train = df_trainval.iloc[train_split]
return df_train, df_validation

@property
def processed_dir(self):
@@ -237,7 +282,7 @@ def processed_file_names_dict(self) -> dict:
train_v_str = f'_v{self.chebi_version_train}' if self.chebi_version_train else ''
res = {'test': f"test{train_v_str}.pt"}
if self.use_inner_cross_validation:
res['train_val'] = f'trainval{train_v_str}.pt' # for cv, split train/val on runtime
res['train_val'] = f'trainval{train_v_str}.pt' # for cv, split train/val on runtime
else:
res['train'] = f"train{train_v_str}.pt"
res['validation'] = f"validation{train_v_str}.pt"
@@ -246,10 +291,10 @@ def processed_file_names_dict(self) -> dict:
@property
def raw_file_names_dict(self) -> dict:
train_v_str = f'_v{self.chebi_version_train}' if self.chebi_version_train else ''
res = {'test': f"test.pkl"} # no extra raw test version for chebi_version_train - use default test set and only
# adapt processed file
res = {'test': f"test.pkl"} # no extra raw test version for chebi_version_train - use default test set and only
# adapt processed file
if self.use_inner_cross_validation:
res['train_val'] = f'trainval{train_v_str}.pkl' # for cv, split train/val on runtime
res['train_val'] = f'trainval{train_v_str}.pkl' # for cv, split train/val on runtime
else:
res['train'] = f"train{train_v_str}.pkl"
res['validation'] = f"validation{train_v_str}.pkl"
@@ -282,12 +327,13 @@ def prepare_data(self, *args, **kwargs):
open(chebi_path, "wb").write(r.content)
g = extract_class_hierarchy(chebi_path)
splits = {}
full_data = self.graph_to_raw_dataset(g)
if self.use_inner_cross_validation:
splits['train_val'], splits['test'] = self.get_splits(g)
splits['train_val'], splits['test'] = self.get_splits(full_data)
else:
splits['train'], splits['test'], splits['validation'] = self.get_splits(g)
splits['train'], splits['test'], splits['validation'] = self.get_splits(full_data)
for label, split in splits.items():
self.save(g, split, self.raw_file_names_dict[label])
self.save(split, self.raw_file_names_dict[label])
else:
# missing test set -> create
if not os.path.isfile(os.path.join(self.raw_dir, self.raw_file_names_dict['test'])):
@@ -298,8 +344,9 @@ def prepare_data(self, *args, **kwargs):
r = requests.get(url, allow_redirects=True)
open(chebi_path, "wb").write(r.content)
g = extract_class_hierarchy(chebi_path)
_, test_split, _ = self.get_splits(g)
self.save(g, test_split, self.raw_file_names_dict['test'])
df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['test'])
_, test_split, _ = self.get_splits(df)
self.save(df, self.raw_file_names_dict['test'])
else:
# load test_split from file
with open(os.path.join(self.raw_dir, self.raw_file_names_dict['test']), "rb") as input_file:
@@ -312,12 +359,14 @@ def prepare_data(self, *args, **kwargs):
open(chebi_path, "wb").write(r.content)
g = extract_class_hierarchy(chebi_path)
if self.use_inner_cross_validation:
train_val_data = self.get_splits_given_test(g, test_split)
self.save(g, train_val_data, self.raw_file_names_dict['train_val'])
df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train_val'])
train_val_df = self.get_splits_given_test(df, test_split)
self.save(train_val_df, self.raw_file_names_dict['train_val'])
else:
train_split, val_split = self.get_splits_given_test(g, test_split)
self.save(g, train_split, self.raw_file_names_dict['train'])
self.save(g, val_split, self.raw_file_names_dict['validation'])
df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train'])
train_split, val_split = self.get_splits_given_test(df, test_split)
self.save(train_split, self.raw_file_names_dict['train'])
self.save(val_split, self.raw_file_names_dict['validation'])


class JCIExtendedBase(_ChEBIDataExtractor):
@@ -371,6 +420,7 @@ def select_classes(self, g, split_name, *args, **kwargs):
fout.writelines(str(node) + "\n" for node in nodes)
return nodes


class ChEBIOverXDeepSMILES(ChEBIOverX):
READER = dr.DeepChemDataReader

@@ -388,6 +438,7 @@ class ChEBIOver50(ChEBIOverX):
def label_number(self):
return 1332


class ChEBIOver100DeepSMILES(ChEBIOverXDeepSMILES, ChEBIOver100):
pass

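A side note on the new get_splits_given_test(): the mask that removes molecules already present in the fixed test set is built row by row. Under the same assumptions (a 'SMILES' column in both frames), the sketch below shows an equivalent vectorized filter with pandas; it is an illustration, not code from the commit.

import pandas as pd

def drop_test_overlap(df_trainval: pd.DataFrame, test_df: pd.DataFrame) -> pd.DataFrame:
    # keep only molecules whose SMILES string does not occur in the fixed test set
    test_smiles = set(test_df["SMILES"])
    return df_trainval[~df_trainval["SMILES"].isin(test_smiles)]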
13 changes: 4 additions & 9 deletions chebai/trainer/InnerCVTrainer.py
@@ -11,7 +11,7 @@
from lightning.fabric.plugins.environments import SLURMEnvironment
from lightning_utilities.core.rank_zero import WarningCache

from sklearn import model_selection
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from lightning.pytorch.callbacks.model_checkpoint import _is_dir, rank_zero_warn

from chebai.preprocessing.datasets.base import XYBaseDataModule
@@ -29,17 +29,17 @@ def __init__(self, *args, **kwargs):
self._logger_connector = _LoggerConnectorCVSupport(self)
self._logger_connector.on_trainer_init(self.logger, 1)


def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwargs):
if n_splits < 2:
self.fit(datamodule=datamodule, *args, **kwargs)
else:
datamodule.prepare_data()
datamodule.setup()

kfold = model_selection.KFold(n_splits=n_splits)
kfold = MultilabelStratifiedKFold(n_splits=n_splits)

for fold, (train_ids, val_ids) in enumerate(kfold.split(datamodule.train_val_data)):
for fold, (train_ids, val_ids) in enumerate(
kfold.split(datamodule.train_val_data, [data['labels'] for data in datamodule.train_val_data])):
train_dataloader = datamodule.train_dataloader(ids=train_ids)
val_dataloader = datamodule.val_dataloader(ids=val_ids)
init_kwargs = self.init_kwargs
@@ -81,7 +81,6 @@ class ModelCheckpointCVSupport(ModelCheckpoint):
def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
"""Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir"""
if self.dirpath is not None:
print(f'Eliminating existing dirpath {self.dirpath} at ModelCheckpoint setup')
self.dirpath = None
dirpath = self.__resolve_ckpt_dir(trainer)
dirpath = trainer.strategy.broadcast(dirpath)
@@ -109,7 +108,6 @@ def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
if self.dirpath is not None:
# short circuit if dirpath was passed to ModelCheckpoint
return self.dirpath
print(f'Found {len(trainer.loggers)} loggers')
if len(trainer.loggers) > 0:
if trainer.loggers[0].save_dir is not None:
save_dir = trainer.loggers[0].save_dir
@@ -119,7 +117,6 @@ def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
version = trainer.loggers[0].version
version = version if isinstance(version, str) else f"version_{version}"
cv_logger = trainer.loggers[0]
print(f'Found logger {cv_logger.__class__}')
if isinstance(cv_logger, CSVLoggerCVSupport) and cv_logger.fold is not None:
# log_dir includes fold
ckpt_path = os.path.join(cv_logger.log_dir, "checkpoints")
@@ -138,7 +135,6 @@ def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:

class _LoggerConnectorCVSupport(_LoggerConnector):
def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> None:
print(f'called configure_logger')
if not logger:
# logger is None or logger is False
self.trainer.loggers = []
@@ -159,5 +155,4 @@ def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> None:
elif isinstance(logger, Iterable):
self.trainer.loggers = list(logger)
else:
print(f'setting trainer.loggers to [logger]')
self.trainer.loggers = [logger]
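
For completeness, a small usage sketch of MultilabelStratifiedKFold, which replaces sklearn's KFold in cv_fit(); the label matrix below is a random stand-in for [data['labels'] for data in datamodule.train_val_data], so shapes and values are assumptions.

import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=(60, 4))   # assumed multilabel targets

mskf = MultilabelStratifiedKFold(n_splits=5)
for fold, (train_ids, val_ids) in enumerate(mskf.split(np.zeros(len(labels)), labels)):
    # each fold yields stratified train/validation index arrays
    print(f"fold {fold}: {len(train_ids)} train / {len(val_ids)} val")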
