Various features #6

Merged
123 commits merged on Jan 5, 2024
Commits
3404c86
add support for training set from different chebi version
sfluegel05 Nov 8, 2023
2b0ba42
optimise filter in train/val split
sfluegel05 Nov 8, 2023
5d63cc6
Merge branch 'ChEB-AI:dev' into feature-old-chebi
sfluegel05 Nov 8, 2023
36872b1
save splits and classes individually
sfluegel05 Nov 9, 2023
d1b2d29
fix circular imports
sfluegel05 Nov 9, 2023
f312e9b
add SMILES extraction for old chebi version
sfluegel05 Nov 9, 2023
25def45
add extendable tokens.txt to cover new tokens in different chebi version
sfluegel05 Nov 9, 2023
3b57c6c
minor fixes
sfluegel05 Nov 9, 2023
a2e4dfd
add test data set adapted to old version
sfluegel05 Nov 10, 2023
16b9dfa
add DeepSMILES support
sfluegel05 Nov 10, 2023
2d9ab05
fix processed file names dict
sfluegel05 Nov 13, 2023
d032e75
make creation of missing processed data sets individual for train/val…
sfluegel05 Nov 15, 2023
53a020a
add script for evaluation
sfluegel05 Nov 15, 2023
158c229
restructure data folder, include version parameter
sfluegel05 Nov 16, 2023
c4caedd
fix tokens.txt
sfluegel05 Nov 16, 2023
3833a6e
Merge pull request #1 from sfluegel05/feature-old-chebi
sfluegel05 Nov 16, 2023
e0cf796
introduce inner-cross-validation to data module
sfluegel05 Nov 16, 2023
09ccd6c
add trainer for inner-cv, update cli and dataloaders
sfluegel05 Nov 17, 2023
fc72b71
fix import
sfluegel05 Nov 17, 2023
c3a4dd6
process results for v227, instantiate new trainers for each fold
sfluegel05 Nov 17, 2023
9696cc1
fix trainer instantiation
sfluegel05 Nov 17, 2023
1e3ddc9
save checkpoints for different folds separately
sfluegel05 Nov 20, 2023
89f1851
debug log path
sfluegel05 Nov 20, 2023
575382e
debug log path
sfluegel05 Nov 20, 2023
15be9aa
fix setting init args
sfluegel05 Nov 20, 2023
72a3c64
fix logger save dir
sfluegel05 Nov 20, 2023
52acdb6
extend csvlogger
sfluegel05 Nov 20, 2023
481265c
improve console output
sfluegel05 Nov 20, 2023
d370384
add pretraining yamls, update pubchem download
sfluegel05 Nov 20, 2023
0d70db6
couple checkpoint path to log path
sfluegel05 Nov 20, 2023
8ec1ff4
fix processed data dir
sfluegel05 Nov 20, 2023
9fc70c1
fix callback class path
sfluegel05 Nov 20, 2023
0e53c50
fix class paths for callbacks
sfluegel05 Nov 20, 2023
2a11c69
add infrastructure for custom logger
sfluegel05 Nov 21, 2023
f65f096
update gitignore
sfluegel05 Nov 21, 2023
c9d9b82
debug model checkpoint
sfluegel05 Nov 21, 2023
2c6368b
add custom logger connector
sfluegel05 Nov 21, 2023
4bde6d9
add custom logger connector
sfluegel05 Nov 21, 2023
98d22a1
debug ckpt path
sfluegel05 Nov 21, 2023
0d6132f
debug ckpt path
sfluegel05 Nov 21, 2023
08cb61e
Merge pull request #3 from sfluegel05/feature-crossvalidation
sfluegel05 Nov 22, 2023
8ff3d5f
clean up trainer, use common path for best checkpoints
Nov 22, 2023
0806f9a
fix conflicts
Nov 22, 2023
24e45fa
Merge branch 'sfluegel05-feature-pretraining' into feature-pretraining
Nov 22, 2023
c73d0f8
fix loss class path
Nov 22, 2023
950b6f0
re-add ElectraPreLoss to model module to keep compatibility with pret…
Nov 22, 2023
80e1748
add loss-based checkpointing for pretraining, fix checkpointing based…
Nov 23, 2023
bb2340f
switch to stratified data splits
Nov 23, 2023
5cfa50f
Merge branch 'features-sfluegel' into feature-pretraining
sfluegel05 Nov 23, 2023
61fef64
Merge pull request #5 from ChEB-AI/feature-pretraining
sfluegel05 Nov 23, 2023
c615f17
add error handling to deepSMILES reader
Nov 24, 2023
46852e4
improve error handling in deepSMILES reader
Nov 24, 2023
2aae072
add evaluation functions for pretraining
Nov 24, 2023
684bdf7
add epoch-level macro-f1
Nov 24, 2023
ad107a9
fix conversion error in forward
Nov 24, 2023
868cd75
debug epoch-level macro-f1
Nov 24, 2023
8040ad7
update readme
Nov 24, 2023
3de3992
change logger class
Nov 24, 2023
934c960
fix device handling for macro-f1
Nov 24, 2023
6bff764
Merge branch 'features-sfluegel' into feature-wandb-integration
Nov 24, 2023
b894986
set entity
Nov 24, 2023
c066813
process results for classification model
Nov 29, 2023
c9a1b76
add support for SELFIES
Nov 29, 2023
bd91161
Merge branch 'features-sfluegel' into feature-wandb-integration
Nov 29, 2023
27868ee
adapt selfies reader
Nov 29, 2023
ac4017b
add dir creation
Nov 29, 2023
17a8760
add file creation
Nov 29, 2023
5f50d8f
fix selfies reader exception handling
Nov 29, 2023
fe6c3f0
fix selfies reader file creation
Nov 29, 2023
8bf268a
fix selfies reader
Nov 29, 2023
59a57b2
fix electra model
Nov 29, 2023
0868b67
fix handling of errors in reader
Nov 30, 2023
5a3dd3d
set constraints in selfies reader
Nov 30, 2023
b30c97e
fix filtering for none-value in setup_preprocessed()
Nov 30, 2023
0ef01cd
add script for evaluating models trained on v148 and v200
Dec 1, 2023
9e0bbd5
add custom logger
Dec 4, 2023
afb4fbb
add subcommand predict_from_file
Dec 4, 2023
21937f6
deactivate dropout
Dec 4, 2023
a21e3f8
update checkpoints for wandb logging, refactor
Dec 5, 2023
e0f72ae
fix data preparation
Dec 5, 2023
9c75553
fix wandb run naming with cv
Dec 5, 2023
9be20cb
Merge branch 'feature-wandb-integration' into features-sfluegel
Dec 5, 2023
9fc9030
comment default_trainer.yml
Dec 5, 2023
9d88a41
fix typehint
Dec 5, 2023
7229249
fix typehint
Dec 5, 2023
524724a
fix wandb run name
Dec 6, 2023
e9f6553
reformat using black
Dec 7, 2023
b498f00
generalise token path separation, include tokens.txts, merge dev (exc…
Dec 8, 2023
4f932ef
reformat using black, update tokens.txt and improve token file handling
Dec 8, 2023
cb5d6b4
fix parallelisation problem when resolving checkpoint path
Dec 11, 2023
643ffe8
add support for integration of gnn module
Dec 11, 2023
d01a67a
add / improve classification task evaluation functions
Dec 13, 2023
ecabf79
fix model evaluation
Dec 13, 2023
b197932
fix model evaluation
Dec 13, 2023
8a3b87a
add adjustment factor for macro-f1
Dec 14, 2023
462379b
add support for single class classification
Dec 14, 2023
e90d66b
add support for single class classification
Dec 14, 2023
5df0060
fix macro-adjust
Dec 14, 2023
fbc22f4
adapt checkpoints for single-class f1
Dec 14, 2023
132b5f4
update to-device conversion for gnn data
Dec 15, 2023
4491834
add trainer strategy, needed for gat model
Dec 18, 2023
c03bb36
sync epoch-level metrics and outputs across devices
Dec 18, 2023
72f47bd
add cosmetic logging improvements
Dec 18, 2023
1d6f590
improve dataloader worker handling
Dec 19, 2023
61ab18d
generalise reader hook at the end of preprocessing
Dec 19, 2023
c8bb0da
readjust macro-f1 calculation for each epoch
Dec 19, 2023
421ed68
adapt for gnn properties
Dec 20, 2023
950804c
improve data path handling, improve evaluation customisability
Dec 22, 2023
e458b86
fix load_processed_data
Dec 22, 2023
78ae996
implement macro-f1 as metric, fix reduction across devices
Dec 22, 2023
15c0a8a
fix macro-f1 dimensions
Dec 22, 2023
6ffa25b
split up logger config, link num_labels argument to macro-f1
Dec 22, 2023
8e55600
log metrics only at epoch end, fix macrof1 metric
Jan 2, 2024
c108686
reimplement cross-validation
Jan 3, 2024
a832015
reformat with isort and black
Jan 3, 2024
07a3a0d
add chebi-subset-based dataset
Jan 3, 2024
520d4bb
fix macro-f1
Jan 3, 2024
c131481
prevent creation of wandb artifacts of checkpoints in addition to che…
Jan 4, 2024
bddd271
fix device of macro-f1 in evaluation
Jan 4, 2024
5fc02d1
fix metric name in checkpoint callback config
Jan 4, 2024
7bbe5c1
fix masking in macro-f1-score
Jan 4, 2024
27ec370
add data_limit to evaluation
Jan 5, 2024
6c72805
fix passing of metric to log function
Jan 5, 2024
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -1,8 +1,8 @@
repos:
- repo: https://github.com/PyCQA/isort
rev: "5.12.0"
hooks:
- id: isort
#- repo: https://github.com/PyCQA/isort
# rev: "5.12.0"
# hooks:
# - id: isort
- repo: https://github.com/psf/black
rev: "22.10.0"
hooks:
51 changes: 50 additions & 1 deletion README.md
@@ -24,4 +24,53 @@ python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=co

```
python -m chebai train --config=[path-to-your-tox21-config] --trainer.callbacks=configs/training/default_callbacks.yml --ckpt_path=[path-to-model-with-ontology-pretraining]
```

## Predicting classes given SMILES strings

```
python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
```
The input file should contain a list of SMILES strings, one per line. The command generates a CSV file with
one row for each SMILES string and one column for each class.
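
As an illustration, a minimal input file could look like this (the two SMILES strings below are arbitrary examples, not taken from the PR):

```
CC(=O)OC1=CC=CC=C1C(=O)O
C1=CC=CC=C1
```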


## Cross-validation
You can do inner k-fold cross-validation, i.e., train models on k train-validation splits that all use the same test
set. For that, you need to specify the total number of folds as
```
--data.init_args.inner_k_folds=K
```
and the fold to be used in the current optimisation run as
```
--data.init_args.fold_index=I
```
To train K models, you need to make K such calls, each with a different `fold_index` (see the example below). On the first call with a given
`inner_k_folds`, all folds are created and stored in the data directory.
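
For example, a 5-fold setup could be trained with five separate calls of the following form (the config path is a hypothetical placeholder):

```
python -m chebai fit --config=[path-to-your-config] --data.init_args.inner_k_folds=5 --data.init_args.fold_index=0
python -m chebai fit --config=[path-to-your-config] --data.init_args.inner_k_folds=5 --data.init_args.fold_index=1
# ... and so on, up to --data.init_args.fold_index=4
```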

## ChEBI versions
Change the ChEBI version used for all sets (default: 200):
```
--data.init_args.chebi_version=VERSION
```
To change only the version of the train and validation sets independently of the test set, use
```
--data.init_args.chebi_version_train=VERSION
```

## Data folder structure
Data is stored in and retrieved from the raw and processed folders
```
data/${dataset_name}/${chebi_version}/raw/
```
and
```
data/${dataset_name}/${chebi_version}/processed/${reader_name}/
```
where `${dataset_name}` is the `_name` attribute of the `DataModule` used,
`${chebi_version}` refers to the ChEBI version used (only for ChEBI datasets), and
`${reader_name}` is the `name` attribute of the `Reader` class associated with the dataset.

For cross-validation, the folds are stored as `cv_${n_folds}_fold/fold_{fold_index}_train.pkl`
and `cv_${n_folds}_fold/fold_{fold_index}_validation.pkl` in the raw directory.
In the processed directory, `.pt` is used instead of `.pkl`.
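
As a concrete illustration with purely hypothetical placeholder values (`my_dataset` for the dataset name, `my_reader` for the reader, ChEBI version 200, 5 folds), the fold files would sit at paths of the form:

```
data/my_dataset/200/raw/cv_5_fold/fold_0_train.pkl
data/my_dataset/200/processed/my_reader/cv_5_fold/fold_0_train.pt
```
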
1 change: 1 addition & 0 deletions chebai/__init__.py
@@ -1,4 +1,5 @@
import os

import torch

MODULE_PATH = os.path.abspath(os.path.dirname(__file__))
5 changes: 3 additions & 2 deletions chebai/callbacks.py
@@ -1,7 +1,8 @@
import json
import os

from lightning.pytorch.callbacks import BasePredictionWriter
import torch
import os
import json


class ChebaiPredictionWriter(BasePredictionWriter):
Empty file added chebai/callbacks/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions chebai/callbacks/epoch_metrics.py
@@ -0,0 +1,49 @@
import torch
import torchmetrics


def custom_reduce_fx(input):
print(f"called reduce (device: {input.device})")
return torch.sum(input, dim=0)


class MacroF1(torchmetrics.Metric):
def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
super().__init__(dist_sync_on_step=dist_sync_on_step)

self.add_state(
"true_positives",
default=torch.zeros(num_labels, dtype=torch.int),
dist_reduce_fx="sum",
)
self.add_state(
"positive_predictions",
default=torch.zeros(num_labels, dtype=torch.int),
dist_reduce_fx="sum",
)
self.add_state(
"positive_labels",
default=torch.zeros(num_labels, dtype=torch.int),
dist_reduce_fx="sum",
)
self.threshold = threshold

def update(self, preds: torch.Tensor, labels: torch.Tensor):
tps = torch.sum(
torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0
)
self.true_positives += tps
self.positive_predictions += torch.sum(preds > self.threshold, dim=0)
self.positive_labels += torch.sum(labels, dim=0)

def compute(self):
# ignore classes without positive labels
# classes with positive labels, but no positive predictions will get a precision of "nan" (0 divided by 0),
# which is propagated to the classwise_f1 and then turned into 0
mask = self.positive_labels != 0
precision = self.true_positives[mask] / self.positive_predictions[mask]
recall = self.true_positives[mask] / self.positive_labels[mask]
classwise_f1 = 2 * precision * recall / (precision + recall)
# if (precision and recall are 0) or (precision is nan), set f1 to 0
classwise_f1 = classwise_f1.nan_to_num()
return torch.mean(classwise_f1)
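
To make the metric's behaviour concrete, here is a minimal usage sketch (not part of the PR; it assumes `torch` and `torchmetrics` are installed and imports `MacroF1` as defined above):

```python
import torch

from chebai.callbacks.epoch_metrics import MacroF1

# Three labels; the third label never occurs, so it is excluded by the mask in compute().
metric = MacroF1(num_labels=3)

preds = torch.tensor([[0.9, 0.6, 0.1], [0.8, 0.2, 0.3]])  # per-label scores in [0, 1]
labels = torch.tensor([[1, 0, 0], [1, 1, 0]])

metric.update(preds, labels)

# Label 0: two true positives, no false positives/negatives -> f1 = 1.0
# Label 1: no true positives -> precision and recall are 0, the 0/0 f1 is mapped to 0
# Label 2: no positive labels -> ignored
print(metric.compute())  # tensor(0.5000)
```
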
63 changes: 63 additions & 0 deletions chebai/callbacks/model_checkpoint.py
@@ -0,0 +1,63 @@
import os

from lightning.fabric.utilities.cloud_io import _is_dir
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.utilities.rank_zero import rank_zero_info
from lightning_utilities.core.rank_zero import rank_zero_warn


class CustomModelCheckpoint(ModelCheckpoint):
"""Checkpoint class that resolves checkpoint paths s.t. for the CustomLogger, checkpoints get saved to the
same directory as the other logs"""

def setup(
self, trainer: "Trainer", pl_module: "LightningModule", stage: str
) -> None:
"""Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir"""
if self.dirpath is not None:
self.dirpath = None
dirpath = self.__resolve_ckpt_dir(trainer)
dirpath = trainer.strategy.broadcast(dirpath)
self.dirpath = dirpath
if trainer.is_global_zero and stage == "fit":
self.__warn_if_dir_not_empty(self.dirpath)

def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
"""Same as in parent class, duplicated because method in parent class is not accessible"""
if (
self.save_top_k != 0
and _is_dir(self._fs, dirpath, strict=True)
and len(self._fs.ls(dirpath)) > 0
):
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
"""Overwritten for compatibility with wandb -> saves checkpoints in same dir as wandb logs"""
rank_zero_info(f"Resolving checkpoint dir (custom)")
if self.dirpath is not None:
# short circuit if dirpath was passed to ModelCheckpoint
return self.dirpath
if len(trainer.loggers) > 0:
if trainer.loggers[0].save_dir is not None:
save_dir = trainer.loggers[0].save_dir
else:
save_dir = trainer.default_root_dir
name = trainer.loggers[0].name
version = trainer.loggers[0].version
version = version if isinstance(version, str) else f"version_{version}"
logger = trainer.loggers[0]
if isinstance(logger, WandbLogger) and isinstance(
logger.experiment.dir, str
):
ckpt_path = os.path.join(logger.experiment.dir, "checkpoints")
else:
ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
else:
# if no loggers, use default_root_dir
ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")

rank_zero_info(f"Now using checkpoint path {ckpt_path}")
return ckpt_path
30 changes: 27 additions & 3 deletions chebai/cli.py
@@ -1,15 +1,39 @@
from lightning.pytorch.cli import LightningCLI
from typing import Dict, Set

from lightning.pytorch.cli import LightningArgumentParser, LightningCLI

from chebai.trainer.CustomTrainer import CustomTrainer


class ChebaiCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
def __init__(self, *args, **kwargs):
super().__init__(trainer_class=CustomTrainer, *args, **kwargs)

def add_arguments_to_parser(self, parser: LightningArgumentParser):
for kind in ("train", "val", "test"):
for average in ("micro", "macro"):
parser.link_arguments(
"model.init_args.out_dim",
f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels",
)
parser.link_arguments(
"model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
)

@staticmethod
def subcommands() -> Dict[str, Set[str]]:
"""Defines the list of available subcommands and the arguments to skip."""
return {
"fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
"validate": {"model", "dataloaders", "datamodule"},
"test": {"model", "dataloaders", "datamodule"},
"predict": {"model", "dataloaders", "datamodule"},
"predict_from_file": {"model"},
}


def cli():
r = ChebaiCLI(save_config_callback=None, parser_kwargs={"parser_mode": "omegaconf"})
r = ChebaiCLI(
save_config_kwargs={"config_filename": "lightning_config.yaml"},
parser_kwargs={"parser_mode": "omegaconf"},
)
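
As a usage illustration (the config path and output dimension are hypothetical placeholders), the linked arguments mean that a single `out_dim` value propagates to the micro/macro f1 metrics of the train, validation and test stages as well as to the callback's `num_labels`:

```
python -m chebai fit --config=configs/my_experiment.yml --model.init_args.out_dim=1000  # hypothetical number of classes
```
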
Empty file added chebai/loggers/__init__.py
Empty file.
83 changes: 83 additions & 0 deletions chebai/loggers/custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from datetime import datetime
from typing import Literal, Optional, Union
import os

from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
import wandb


class CustomLogger(WandbLogger):
"""Adds support for custom naming of runs and cross-validation"""

def __init__(
self,
save_dir: _PATH,
name: str = "logs",
version: Optional[Union[int, str]] = None,
prefix: str = "",
fold: Optional[int] = None,
project: Optional[str] = None,
entity: Optional[str] = None,
offline: bool = False,
log_model: Union[Literal["all"], bool] = False,
**kwargs,
):
if version is None:
version = f"{datetime.now():%y%m%d-%H%M}"
self._version = version
self._name = name
self._fold = fold
super().__init__(
name=self.name,
save_dir=save_dir,
version=None,
prefix=prefix,
log_model=log_model,
entity=entity,
project=project,
offline=offline,
**kwargs,
)

@property
def name(self) -> Optional[str]:
name = f"{self._name}_{self.version}"
if self._fold is not None:
name += f"_fold{self._fold}"
return name

@property
def version(self) -> Optional[str]:
return self._version

@property
def root_dir(self) -> Optional[str]:
return os.path.join(self.save_dir, self.name)

@property
def log_dir(self) -> str:
version = (
self.version if isinstance(self.version, str) else f"version_{self.version}"
)
if self._fold is None:
return os.path.join(self.root_dir, version)
return os.path.join(self.root_dir, version, f"fold_{self._fold}")

def set_fold(self, fold: int):
if fold != self._fold:
self._fold = fold
# start new experiment
wandb.finish()
self._wandb_init["name"] = self.name
self._experiment = None
_ = self.experiment

@property
def fold(self):
return self._fold

def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
# don't save checkpoint as wandb artifact
pass
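
The following is a minimal wiring sketch showing how `CustomLogger` and `CustomModelCheckpoint` fit together; it is not part of the PR, and the run name, project name, and monitored metric key are hypothetical placeholders:

```python
from lightning.pytorch import Trainer

from chebai.callbacks.model_checkpoint import CustomModelCheckpoint
from chebai.loggers.custom import CustomLogger

# Offline wandb logger with a timestamped version appended to the run name.
logger = CustomLogger(save_dir="logs", name="electra_chebi", project="chebai", offline=True)

checkpoint = CustomModelCheckpoint(
    monitor="val_macro-f1",  # hypothetical metric key
    mode="max",
    save_top_k=3,
    filename="best_{epoch}",
)

# With dirpath left unset, the checkpoint directory is resolved from the logger,
# so checkpoints end up inside the wandb run directory next to the other logs.
trainer = Trainer(logger=logger, callbacks=[checkpoint])
```
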
2 changes: 1 addition & 1 deletion chebai/loss/pretraining.py
@@ -1,5 +1,6 @@
import torch


class ElectraPreLoss(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -14,4 +15,3 @@ def forward(self, input, target, **loss_kwargs):
target=torch.argmax(disc_tar.int(), dim=-1), input=disc_pred
)
return gen_loss + disc_loss

8 changes: 5 additions & 3 deletions chebai/loss/semantic.py
@@ -1,9 +1,11 @@
import torch
from chebai.models.electra import extract_class_hierarchy
import os
import csv
import os
import pickle

import torch

from chebai.models.electra import extract_class_hierarchy

IMPLICATION_CACHE_FILE = "chebi.cache"

