From e0f72aea6a8235f299ee5bdfac6eb868fe2bdd8e Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Tue, 5 Dec 2023 17:16:36 +0100
Subject: [PATCH] fix data preparation

---
 .gitignore                             |   3 -
 chebai/callbacks/model_checkpoint.py   |   6 +-
 chebai/preprocessing/datasets/chebi.py | 133 +++++++++----------
 chebai/trainer/InnerCVTrainer.py       |   2 +-
 4 files changed, 55 insertions(+), 89 deletions(-)

diff --git a/.gitignore b/.gitignore
index cd34d6b4..4c14111e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,6 +161,3 @@ cython_debug/
 #.idea/
 
 configs/
-# the notebook I put in the wrong folder
-chebai/preprocessing/datasets/demo_old_chebi.ipynb
-demo_examine_pretraining_data.ipynb
\ No newline at end of file
diff --git a/chebai/callbacks/model_checkpoint.py b/chebai/callbacks/model_checkpoint.py
index cf461384..b5740438 100644
--- a/chebai/callbacks/model_checkpoint.py
+++ b/chebai/callbacks/model_checkpoint.py
@@ -1,6 +1,10 @@
 from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
 from lightning.fabric.utilities.types import _PATH
-
+from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch import Trainer, LightningModule
+import os
+from lightning.fabric.utilities.cloud_io import _is_dir
+from lightning.pytorch.utilities.rank_zero import rank_zero_info
 
 class CustomModelCheckpoint(ModelCheckpoint):
     """Checkpoint class that resolves checkpoint paths s.t. for the CustomLogger, checkpoints get saved to the
diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index 22d7623a..01510972 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -119,12 +119,11 @@ def select_classes(self, g, split_name, *args, **kwargs):
         raise NotImplementedError
 
     def graph_to_raw_dataset(self, g, split_name=None):
-        """Preparation step before creating splits, uses graph created by extract_class_hierarchy()"""
+        """Preparation step before creating splits, uses graph created by extract_class_hierarchy(),
+        split_name is only relevant if a separate train_version is set"""
         smiles = nx.get_node_attributes(g, "smiles")
         names = nx.get_node_attributes(g, "name")
 
-        print("build labels")
         print(f"Process graph")
 
         molecules, smiles_list = zip(
@@ -199,68 +198,50 @@ def setup_processed(self):
         self._setup_pruned_test_set()
         self.reader.save_token_cache()
 
-    def get_splits(self, df: pd.DataFrame):
-        print("Split dataset")
+    def get_test_split(self, df: pd.DataFrame):
+        print("Split dataset into train (including val) / test")
         df_list = df.values.tolist()
         df_list = [row[3:] for row in df_list]
-        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1-self.train_split, random_state=0)
+        test_size = 1 - self.train_split - (1 - self.train_split) ** 2
+        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=0)
 
         train_split = []
         test_split = []
         for (train_split, test_split) in msss.split(
-                df_list, df_list,
+            df_list, df_list,
         ):
             train_split = train_split
             test_split = test_split
             break
 
         df_train = df.iloc[train_split]
         df_test = df.iloc[test_split]
-        if self.use_inner_cross_validation:
-            return df_train, df_test
+        return df_train, df_test
 
-        df_test_list = df_test.values.tolist()
-        df_test_list = [row[3:] for row in df_test_list]
-        validation_split = []
-        test_split = []
-        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1-self.train_split, random_state=0)
-        for (test_split, validation_split) in msss.split(
-            df_test_list, df_test_list
-        ):
-            test_split = test_split
-            validation_split = validation_split
-            break
+    def get_train_val_splits_given_test(self, df: pd.DataFrame, test_df: pd.DataFrame):
+        """Use test set (e.g., loaded from another chebi version or generated in get_test_split), avoid overlap"""
+        print("Split dataset into train / val with given test set")
 
-        df_validation = df_test.iloc[validation_split]
-        df_test = df_test.iloc[test_split]
-        return df_train, df_test, df_validation
-
-    def get_splits_given_test(self, df: pd.DataFrame, test_df: pd.DataFrame):
-        """ Use test set from another chebi version the model does not train on, avoid overlap"""
-        print(f"Split dataset for chebi_v{self.chebi_version_train}")
         df_trainval = df
         test_smiles = test_df['SMILES'].tolist()
-        mask = []
-        for row in df_trainval:
-            if row['SMILES'] in test_smiles:
-                mask.append(False)
-            else:
-                mask.append(True)
+        mask = [smiles not in test_smiles for smiles in df_trainval['SMILES']]
         df_trainval = df_trainval[mask]
+
         if self.use_inner_cross_validation:
             return df_trainval
 
-        # assume that size of validation split should relate to train split as in get_splits()
-        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=self.train_split ** 2, random_state=0)
+        # scale val set size by 1/self.train_split to compensate for the (hypothetical) test set size (1 - self.train_split)
+        test_size = ((1 - self.train_split) ** 2) / self.train_split
+        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=0)
 
-        df_trainval_list = df_trainval.tolist()
+        df_trainval_list = df_trainval.values.tolist()
         df_trainval_list = [row[3:] for row in df_trainval_list]
         train_split = []
         validation_split = []
         for (train_split, validation_split) in msss.split(
-                df_trainval_list, df_trainval_list
+            df_trainval_list, df_trainval_list
        ):
             train_split = train_split
             validation_split = validation_split
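Note on the two split methods above: they size the two stages so that test and validation keep sensible proportions of the full dataset. A minimal runnable sketch of the sizing arithmetic and the single-split idiom, assuming the iterative-stratification package (iterstrat.ml_stratifiers) that provides MultilabelStratifiedShuffleSplit, and a hypothetical train_split of 0.85 — the default value is not shown in this patch:

    import numpy as np
    from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

    train_split = 0.85  # assumption; not pinned down by this patch

    # get_test_split(): fraction of the full dataset held out for testing
    test_size = 1 - train_split - (1 - train_split) ** 2      # 0.15 - 0.0225 = 0.1275

    # get_train_val_splits_given_test(): validation fraction of the remaining
    # train/val pool; dividing by train_split compensates for the test set
    # having already been removed
    val_size = ((1 - train_split) ** 2) / train_split         # 0.0225 / 0.85 ~= 0.0265

    X = np.random.rand(100, 4)                # placeholder features
    y = np.random.randint(0, 2, (100, 5))     # placeholder multi-hot labels
    msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=0)
    train_idx, test_idx = next(msss.split(X, y))

Since n_splits=1, next(msss.split(...)) yields the same single pair of index arrays as the for/break loop used in the patch.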
@@ -309,6 +290,16 @@ def processed_file_names(self):
     def raw_file_names(self):
         return list(self.raw_file_names_dict.values())
 
+    def _load_chebi(self, version: int):
+        chebi_name = 'chebi.obo' if version == self.chebi_version else f'chebi_v{version}.obo'
+        chebi_path = os.path.join(self.raw_dir, chebi_name)
+        if not os.path.isfile(chebi_path):
+            print(f"Load ChEBI ontology (v_{version})")
+            url = f"http://purl.obolibrary.org/obo/chebi/{version}/chebi.obo"
+            r = requests.get(url, allow_redirects=True)
+            open(chebi_path, "wb").write(r.content)
+        return chebi_path
+
     def prepare_data(self, *args, **kwargs):
         print("Check for raw data in", self.raw_dir)
         if any(
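A possible follow-up for the _load_chebi helper above (not part of this patch): open(chebi_path, "wb").write(r.content) never checks the HTTP status and leaves closing the file to the garbage collector. A sketch of the same download using only standard requests/file idioms; the helper name _download_obo is hypothetical:

    import requests

    def _download_obo(url: str, dest: str) -> None:
        # Hypothetical helper: fail loudly on HTTP errors and close the
        # file deterministically instead of open(...).write(...).
        r = requests.get(url, allow_redirects=True)
        r.raise_for_status()
        with open(dest, "wb") as f:
            f.write(r.content)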
Go fetch...") - if self.chebi_version_train is None: - # load chebi_v{chebi_version}, create splits - chebi_path = os.path.join(self.raw_dir, f"chebi.obo") - if not os.path.isfile(chebi_path): - print("Load ChEBI ontology") - url = f"http://purl.obolibrary.org/obo/chebi/{self.chebi_version}/chebi.obo" - r = requests.get(url, allow_redirects=True) - open(chebi_path, "wb").write(r.content) + # missing test set -> create + if not os.path.isfile(os.path.join(self.raw_dir, self.raw_file_names_dict['test'])): + chebi_path = self._load_chebi(self.chebi_version) g = extract_class_hierarchy(chebi_path) - splits = {} - full_data = self.graph_to_raw_dataset(g) - if self.use_inner_cross_validation: - splits['train_val'], splits['test'] = self.get_splits(full_data) - else: - splits['train'], splits['test'], splits['validation'] = self.get_splits(full_data) - for label, split in splits.items(): - self.save(split, self.raw_file_names_dict[label]) + df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['test']) + _, test_df = self.get_test_split(df) + self.save(test_df, self.raw_file_names_dict['test']) + # load test_split from file else: - # missing test set -> create - if not os.path.isfile(os.path.join(self.raw_dir, self.raw_file_names_dict['test'])): - chebi_path = os.path.join(self.raw_dir, f"chebi.obo") - if not os.path.isfile(chebi_path): - print("Load ChEBI ontology") - url = f"http://purl.obolibrary.org/obo/chebi/{self.chebi_version}/chebi.obo" - r = requests.get(url, allow_redirects=True) - open(chebi_path, "wb").write(r.content) - g = extract_class_hierarchy(chebi_path) - df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['test']) - _, test_split, _ = self.get_splits(df) - self.save(df, self.raw_file_names_dict['test']) - else: - # load test_split from file - with open(os.path.join(self.raw_dir, self.raw_file_names_dict['test']), "rb") as input_file: - test_split = [row[0] for row in pickle.load(input_file).values] - chebi_path = os.path.join(self.raw_dir, f"chebi_v{self.chebi_version_train}.obo") - if not os.path.isfile(chebi_path): - print(f"Load ChEBI ontology (v_{self.chebi_version_train})") - url = f"http://purl.obolibrary.org/obo/chebi/{self.chebi_version_train}/chebi.obo" - r = requests.get(url, allow_redirects=True) - open(chebi_path, "wb").write(r.content) - g = extract_class_hierarchy(chebi_path) - if self.use_inner_cross_validation: - df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train_val']) - train_val_df = self.get_splits_given_test(df, test_split) - self.save(train_val_df, self.raw_file_names_dict['train_val']) - else: - df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train']) - train_split, val_split = self.get_splits_given_test(df, test_split) - self.save(train_split, self.raw_file_names_dict['train']) - self.save(val_split, self.raw_file_names_dict['validation']) + with open(os.path.join(self.raw_dir, self.raw_file_names_dict['test']), "rb") as input_file: + test_df = pickle.load(input_file) + # create train/val split based on test set + chebi_path = self._load_chebi( + self.chebi_version_train if self.chebi_version_train is not None else self.chebi_version) + g = extract_class_hierarchy(chebi_path) + if self.use_inner_cross_validation: + df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train_val']) + train_val_df = self.get_train_val_splits_given_test(df, test_df) + self.save(train_val_df, self.raw_file_names_dict['train_val']) + else: + df = self.graph_to_raw_dataset(g, self.raw_file_names_dict['train']) + train_split, 
diff --git a/chebai/trainer/InnerCVTrainer.py b/chebai/trainer/InnerCVTrainer.py
index 30f48e9c..ad06fe97 100644
--- a/chebai/trainer/InnerCVTrainer.py
+++ b/chebai/trainer/InnerCVTrainer.py
@@ -39,7 +39,7 @@ def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwar
             train_dataloader = datamodule.train_dataloader(ids=train_ids)
             val_dataloader = datamodule.val_dataloader(ids=val_ids)
             init_kwargs = self.init_kwargs
-            new_trainer = Trainer(*self.init_args, **init_kwargs)
+            new_trainer = InnerCVTrainer(*self.init_args, **init_kwargs)
             logger = new_trainer.logger
             if isinstance(logger, CustomLogger):
                 logger.set_fold(fold)
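The one-line trainer change is a real bug fix: cv_fit previously re-instantiated a plain Trainer per fold, so each fold ran without whatever behaviour InnerCVTrainer adds on top of Trainer. A self-contained sketch of the underlying pattern — constructing from the runtime class so any future subclass would also survive the refit; type(self) is a possible generalisation, not what the patch does:

    # Minimal sketch: type(self) preserves the subclass when cloning.
    class Base:
        def clone(self):
            return type(self)()  # returns the runtime subclass, not Base

    class Sub(Base):
        pass

    assert isinstance(Sub().clone(), Sub)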