Commit

add custom logger
sfluegel committed Dec 4, 2023
1 parent b30c97e commit 9e0bbd5
Showing 5 changed files with 85 additions and 31 deletions.
3 changes: 2 additions & 1 deletion chebai/cli.py
@@ -32,4 +32,5 @@ def subcommands() -> Dict[str, Set[str]]:


 def cli():
-    r = ChebaiCLI(save_config_callback=None, parser_kwargs={"parser_mode": "omegaconf"})
+    r = ChebaiCLI(save_config_kwargs={"config_filename": "lightning_config.yaml"},
+                  parser_kwargs={"parser_mode": "omegaconf"})
57 changes: 57 additions & 0 deletions chebai/loggers/custom.py
@@ -0,0 +1,57 @@
from datetime import datetime
from typing import Optional, Union, Literal

import wandb
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.loggers import WandbLogger
import os


class CustomLogger(WandbLogger):
    """Adds support for custom naming of runs and cross-validation"""

    def __init__(self, save_dir: _PATH, name: str = "logs", version: Optional[Union[int, str]] = None, prefix: str = "",
                 fold: Optional[int] = None, project: Optional[str] = None, entity: Optional[str] = None,
                 offline: bool = False,
                 log_model: Union[Literal["all"], bool] = False, **kwargs):
        if version is None:
            version = f'{datetime.now():%y%m%d-%H%M}'
        self._version = version
        self._name = name
        self._fold = fold
        super().__init__(name=self.name, save_dir=save_dir, version=None, prefix=prefix,
                         log_model=log_model, entity=entity, project=project, offline=offline, **kwargs)

    @property
    def name(self) -> Optional[str]:
        name = f'{self._name}_{self.version}'
        if self._fold is not None:
            name += f'_fold{self._fold}'
        return name

    @property
    def version(self) -> Optional[str]:
        return self._version

    @property
    def root_dir(self) -> Optional[str]:
        return os.path.join(self.save_dir, self.name)

    @property
    def log_dir(self) -> str:
        version = self.version if isinstance(self.version, str) else f"version_{self.version}"
        if self._fold is None:
            return os.path.join(self.root_dir, version)
        return os.path.join(self.root_dir, version, f'fold_{self._fold}')

    def set_fold(self, fold: int):
        if fold != self._fold:
            self._fold = fold
            # start new experiment
            wandb.finish()
            self._experiment = None
            _ = self.experiment

    @property
    def fold(self):
        return self._fold
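Editor's note, not part of the commit: a usage sketch of the naming logic above. The run name and paths are hypothetical; wandb only needs to be installed, since no experiment is started here.

from chebai.loggers.custom import CustomLogger

# With version=None the constructor falls back to a timestamp, e.g. "231204-1530".
logger = CustomLogger(save_dir="logs", name="electra", fold=0, project="chebai", entity="chebai")
print(logger.name)     # e.g. electra_231204-1530_fold0
print(logger.log_dir)  # e.g. logs/electra_231204-1530_fold0/231204-1530/fold_0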
46 changes: 20 additions & 26 deletions chebai/trainer/InnerCVTrainer.py
@@ -6,14 +6,15 @@
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.trainer.connectors.logger_connector import _LoggerConnector
 from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
-from lightning.pytorch.loggers import CSVLogger, Logger, TensorBoardLogger
+from lightning.pytorch.loggers import CSVLogger, Logger, TensorBoardLogger, WandbLogger
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.fabric.plugins.environments import SLURMEnvironment
 from lightning_utilities.core.rank_zero import WarningCache
 
 from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
 from lightning.pytorch.callbacks.model_checkpoint import _is_dir, rank_zero_warn
 
+from chebai.loggers.custom import CustomLogger
 from chebai.preprocessing.datasets.base import XYBaseDataModule
 
 log = logging.getLogger(__name__)
@@ -42,35 +43,28 @@ def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwar
             train_dataloader = datamodule.train_dataloader(ids=train_ids)
             val_dataloader = datamodule.val_dataloader(ids=val_ids)
             init_kwargs = self.init_kwargs
-            new_logger = CSVLoggerCVSupport(save_dir=self.logger.save_dir, name=self.logger.name,
-                                            version=self.logger.version, fold=fold)
-            init_kwargs['logger'] = new_logger
             new_trainer = Trainer(*self.init_args, **init_kwargs)
-            print(f'Logging this fold at {new_trainer.logger.log_dir}')
+            logger = new_trainer.logger
+            if isinstance(logger, CustomLogger):
+                logger.set_fold(fold)
+                print(f'Logging this fold at {logger.experiment.dir}')
+            else:
+                rank_zero_warn(f"Using k-fold cross-validation without an adapted logger class")
             new_trainer.fit(train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, *args, **kwargs)


-# extend CSVLogger to include fold number in log path
-class CSVLoggerCVSupport(CSVLogger):
-
-    def __init__(self, save_dir: _PATH, name: str = "lightning_logs", version: Optional[Union[int, str]] = None,
-                 prefix: str = "", flush_logs_every_n_steps: int = 100, fold: int = None):
-        super().__init__(save_dir, name, version, prefix, flush_logs_every_n_steps)
-        self.fold = fold
-
     @property
-    def log_dir(self) -> str:
-        """The log directory for this run.
-
-        By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the
-        constructor's version parameter instead of ``None`` or an int.
-        Additionally: Save data for each fold separately
-        """
-        # create a pseudo standard path
-        version = self.version if isinstance(self.version, str) else f"version_{self.version}"
-        if self.fold is None:
-            return os.path.join(self.root_dir, version)
-        return os.path.join(self.root_dir, version, f'fold_{self.fold}')
+    def log_dir(self) -> Optional[str]:
+        if len(self.loggers) > 0:
+            logger = self.loggers[0]
+            if isinstance(logger, WandbLogger):
+                dirpath = logger.experiment.dir
+            else:
+                dirpath = self.loggers[0].log_dir
+        else:
+            dirpath = self.default_root_dir
+
+        dirpath = self.strategy.broadcast(dirpath)
+        return dirpath


class ModelCheckpointCVSupport(ModelCheckpoint):
@@ -114,7 +108,7 @@ def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
         version = trainer.loggers[0].version
         version = version if isinstance(version, str) else f"version_{version}"
         cv_logger = trainer.loggers[0]
-        if isinstance(cv_logger, CSVLoggerCVSupport) and cv_logger.fold is not None:
+        if isinstance(cv_logger, CustomLogger) and cv_logger.fold is not None:
             # log_dir includes fold
             ckpt_path = os.path.join(cv_logger.log_dir, "checkpoints")
         else:
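Editor's note, not part of the commit: a hedged sketch of the per-fold run lifecycle that cv_fit relies on. set_fold finishes the current W&B run and clears the cached experiment, so the next access to logger.experiment lazily starts a fresh run for the new fold. A logged-in wandb session is assumed; names are illustrative.

from chebai.loggers.custom import CustomLogger

logger = CustomLogger(save_dir="logs", name="cv_demo", project="chebai")  # hypothetical run name

for fold in range(3):
    logger.set_fold(fold)             # finishes the previous wandb run (no-op before the first fold)
    fold_dir = logger.experiment.dir  # lazily starts a new run for this fold
    print(f"fold {fold} logs to {fold_dir}")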
7 changes: 4 additions & 3 deletions configs/training/default_callbacks.yml
@@ -1,13 +1,14 @@
-- class_path: chebai.trainer.InnerCVTrainer.ModelCheckpoint
+- class_path: lightning.pytorch.callbacks.ModelCheckpoint
   init_args:
     monitor: val_micro-f1
     mode: 'max'
     filename: 'best_{epoch}_{val_loss:.4f}_{val_micro-f1:.2f}'
     every_n_epochs: 1
     save_top_k: 5
-- class_path: chebai.trainer.InnerCVTrainer.ModelCheckpointCVSupport
+- class_path: lightning.pytorch.callbacks.ModelCheckpoint
   init_args:
     filename: 'per_{epoch}_{val_loss:.4f}_{val_micro-f1:.2f}'
     every_n_epochs: 5
     save_top_k: -1
 - class_path: chebai.callbacks.epoch_metrics.EpochLevelMacroF1
+#class_path: chebai.callbacks.save_config_callback.CustomSaveConfigCallback
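Editor's note, not part of the commit: a sketch of what the two checkpoint entries above construct. With auto_insert_metric_name left at its default, the filename templates expand to names such as best_epoch=12_val_loss=0.0321_val_micro-f1=0.89.ckpt.

from lightning.pytorch.callbacks import ModelCheckpoint

# Keeps the five best checkpoints by validation micro-F1, checked every epoch.
best_k = ModelCheckpoint(monitor="val_micro-f1", mode="max", save_top_k=5, every_n_epochs=1,
                         filename="best_{epoch}_{val_loss:.4f}_{val_micro-f1:.2f}")
# Additionally saves a checkpoint every fifth epoch; save_top_k=-1 keeps them all.
periodic = ModelCheckpoint(every_n_epochs=5, save_top_k=-1,
                           filename="per_{epoch}_{val_loss:.4f}_{val_micro-f1:.2f}")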
3 changes: 2 additions & 1 deletion configs/training/default_trainer.yml
@@ -2,9 +2,10 @@ min_epochs: 100
 max_epochs: 100
 default_root_dir: &default_root_dir logs
 logger:
-  class_path: lightning.pytorch.loggers.WandbLogger
+  class_path: chebai.loggers.custom.CustomLogger
   init_args:
     save_dir: *default_root_dir
     project: 'chebai'
     entity: 'chebai'
     log_model: 'all'
 callbacks: default_callbacks.yml
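Editor's note, not part of the commit: a sketch of the object this logger section resolves to when the CLI instantiates it. With log_model: 'all', the underlying WandbLogger also uploads every checkpoint produced by the ModelCheckpoint callbacks to W&B as artifacts.

from chebai.loggers.custom import CustomLogger

logger = CustomLogger(save_dir="logs", project="chebai", entity="chebai", log_model="all")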
