split up logger config, link num_labels argument to macro-f1
sfluegel committed Dec 22, 2023
1 parent 15c0a8a commit 6ffa25b
Showing 7 changed files with 26 additions and 31 deletions.
9 changes: 5 additions & 4 deletions chebai/callbacks/epoch_metrics.py
@@ -149,19 +149,20 @@ def apply_metric(self, target, pred, mode="train"):


 class MacroF1(torchmetrics.Metric):
-    def __init__(self, n_labels, dist_sync_on_step=False, threshold=0.5):
+    def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
+        print(f"got a num_labels argument: {num_labels}")
         super().__init__(dist_sync_on_step=dist_sync_on_step)
 
         self.add_state(
-            "true_positives", default=torch.zeros((n_labels)), dist_reduce_fx="sum"
+            "true_positives", default=torch.zeros((num_labels)), dist_reduce_fx="sum"
         )
         self.add_state(
             "positive_predictions",
-            default=torch.empty((n_labels)),
+            default=torch.empty((num_labels)),
             dist_reduce_fx="sum",
         )
         self.add_state(
-            "positive_labels", default=torch.empty((n_labels)), dist_reduce_fx="sum"
+            "positive_labels", default=torch.empty((num_labels)), dist_reduce_fx="sum"
         )
         self.threshold = threshold
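
For context on the states registered above: MacroF1 accumulates, per label, the number of true positives, positive predictions and positive labels, and the macro score is the unweighted mean of the per-label F1 values. A minimal sketch of that computation, assuming tensor inputs shaped like the states above (an illustration, not the repository's actual compute() method):

# Illustrative sketch of a macro-F1 computation from the three accumulated
# states; the real implementation may differ, e.g. in how it treats labels
# with zero support.
import torch

def macro_f1_from_states(true_positives, positive_predictions, positive_labels, eps=1e-12):
    precision = true_positives / (positive_predictions + eps)  # per-label precision
    recall = true_positives / (positive_labels + eps)          # per-label recall
    f1 = 2 * precision * recall / (precision + recall + eps)   # per-label F1
    return f1.mean()                                           # unweighted mean over labels

# Example with three labels:
# macro_f1_from_states(torch.tensor([5., 2., 0.]), torch.tensor([6., 4., 1.]), torch.tensor([7., 2., 3.]))
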
15 changes: 6 additions & 9 deletions chebai/cli.py
@@ -1,14 +1,16 @@
 from typing import Dict, Set
 
-from lightning.pytorch.cli import LightningCLI
-from chebai.trainer.InnerCVTrainer import InnerCVTrainer
+from lightning.pytorch.cli import LightningCLI, LightningArgumentParser
+from chebai.trainer.CustomTrainer import CustomTrainer
 from chebai.preprocessing.datasets.base import XYBaseDataModule
 from chebai.models.base import ChebaiBaseNet
 
 
 class ChebaiCLI(LightningCLI):
     def __init__(self, *args, **kwargs):
-        super().__init__(trainer_class=InnerCVTrainer, *args, **kwargs)
+        super().__init__(trainer_class=CustomTrainer, *args, **kwargs)
 
-    def add_arguments_to_parser(self, parser):
+    def add_arguments_to_parser(self, parser: LightningArgumentParser):
         for kind in ("train", "val", "test"):
             for average in ("micro", "macro"):
                 parser.link_arguments(
@@ -19,11 +21,6 @@ def add_arguments_to_parser(self, parser):
             "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
         )
 
-        # does any of this work? i wasnt able to find any evidence of cases where linked arguments are actually used
-        # why doesnt it work?
-        # parser.link_arguments("model.out_dim", "model.init_args.n_atom_properties")
-        # parser.link_arguments('n_splits', 'data.init_args.inner_k_folds') # doesn't work but I don't know why
-
     @staticmethod
     def subcommands() -> Dict[str, Set[str]]:
         """Defines the list of available subcommands and the arguments to skip."""
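
This hunk, together with the linking kept above it, is what the commit title refers to: add_arguments_to_parser ties the model's out_dim to the num_labels argument of the MacroF1 callback, so the label count no longer needs to be hard-coded in a metrics config (see the n_labels: 1446 line removed from micro-macro-f1.yml below). A self-contained sketch of the same mechanism; only parser.link_arguments and the two linked keys come from the diff, the class name SketchCLI is illustrative:

# Minimal sketch of LightningCLI argument linking, not the full ChebaiCLI.
from lightning.pytorch.cli import LightningCLI, LightningArgumentParser

class SketchCLI(LightningCLI):
    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        # Whatever value is configured for model.init_args.out_dim is propagated
        # to the callback's num_labels, so it only has to be set in one place.
        parser.link_arguments(
            "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
        )
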
7 changes: 4 additions & 3 deletions chebai/trainer/InnerCVTrainer.py → chebai/trainer/CustomTrainer.py
@@ -29,15 +29,16 @@
 log = logging.getLogger(__name__)
 
 
-class InnerCVTrainer(Trainer):
+class CustomTrainer(Trainer):
     def __init__(self, *args, **kwargs):
         self.init_args = args
         self.init_kwargs = kwargs
         super().__init__(*args, **kwargs)
         # instantiation custom logger connector
         self._logger_connector.on_trainer_init(self.logger, 1)
 
-    def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwargs):
+    def cv_fit(self, datamodule: XYBaseDataModule, *args, **kwargs):
+        n_splits = datamodule.inner_k_folds
         if n_splits < 2:
             self.fit(datamodule=datamodule, *args, **kwargs)
         else:
@@ -55,7 +56,7 @@ def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwar
                 train_dataloader = datamodule.train_dataloader(ids=train_ids)
                 val_dataloader = datamodule.val_dataloader(ids=val_ids)
                 init_kwargs = self.init_kwargs
-                new_trainer = InnerCVTrainer(*self.init_args, **init_kwargs)
+                new_trainer = CustomTrainer(*self.init_args, **init_kwargs)
                 logger = new_trainer.logger
                 if isinstance(logger, CustomLogger):
                     logger.set_fold(fold)
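
With this change the fold count comes from the datamodule's inner_k_folds attribute instead of a separate n_splits parameter, so cross-validation is configured in one place. A hedged usage sketch; model and datamodule stand in for a configured ChebaiBaseNet and XYBaseDataModule and are not defined here:

# Usage sketch only: `model` and `datamodule` are placeholders.
from chebai.trainer.CustomTrainer import CustomTrainer

trainer = CustomTrainer(min_epochs=100, max_epochs=100)
# If datamodule.inner_k_folds >= 2, cv_fit spawns a fresh CustomTrainer per fold
# (with per-fold dataloaders); otherwise it falls back to a single plain fit().
trainer.cv_fit(datamodule, model=model)
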
5 changes: 1 addition & 4 deletions configs/metrics/micro-macro-f1.yml
@@ -6,7 +6,4 @@ init_args:
   init_args:
     average: micro
 macro-f1:
-  class_path: chebai.callbacks.epoch_metrics.MacroF1
-  init_args:
-    n_labels: 1446
-  # macro-f1 is a callback
+  class_path: chebai.callbacks.epoch_metrics.MacroF1
3 changes: 3 additions & 0 deletions configs/training/csv_logger.yml
@@ -0,0 +1,3 @@
+class_path: lightning.pytorch.loggers.CSVLogger
+init_args:
+  save_dir: logs
12 changes: 1 addition & 11 deletions configs/training/default_trainer.yml
@@ -1,15 +1,5 @@
 min_epochs: 100
 max_epochs: 100
 default_root_dir: &default_root_dir logs
-strategy: 'ddp_find_unused_parameters_true'
-logger:
-  # class_path: lightning.pytorch.loggers.CSVLogger
-  # init_args:
-  #   save_dir: *default_root_dir
-  class_path: chebai.loggers.custom.CustomLogger # Extension of Wandb logger
-  init_args:
-    save_dir: *default_root_dir
-    project: 'chebai'
-    entity: 'chebai'
-    log_model: 'all'
+logger: csv_logger.yml
 callbacks: default_callbacks.yml
6 changes: 6 additions & 0 deletions configs/training/wandb_logger.yml
@@ -0,0 +1,6 @@
+class_path: chebai.loggers.custom.CustomLogger # Extension of Wandb logger
+init_args:
+  save_dir: logs
+  project: 'chebai'
+  entity: 'chebai'
+  log_model: 'all'
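
Net effect of the last three files: the logger block is moved out of default_trainer.yml into two small, interchangeable configs, with the CSV logger as the default and the Wandb-based CustomLogger as an alternative. For reference, the objects these configs describe would be constructed roughly as follows (a sketch; the keyword arguments simply mirror the init_args above):

# Sketch of what the two logger configs instantiate.
from lightning.pytorch.loggers import CSVLogger
from chebai.loggers.custom import CustomLogger  # chebai's WandbLogger extension

csv_logger = CSVLogger(save_dir="logs")
wandb_logger = CustomLogger(
    save_dir="logs", project="chebai", entity="chebai", log_model="all"
)
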
