diff --git a/chebai/callbacks/epoch_metrics.py b/chebai/callbacks/epoch_metrics.py
index 30fde179..323797b1 100644
--- a/chebai/callbacks/epoch_metrics.py
+++ b/chebai/callbacks/epoch_metrics.py
@@ -149,19 +149,20 @@ def apply_metric(self, target, pred, mode="train"):
 
 
 class MacroF1(torchmetrics.Metric):
-    def __init__(self, n_labels, dist_sync_on_step=False, threshold=0.5):
+    def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
+        print(f"got a num_labels argument: {num_labels}")
         super().__init__(dist_sync_on_step=dist_sync_on_step)
         self.add_state(
-            "true_positives", default=torch.zeros((n_labels)), dist_reduce_fx="sum"
+            "true_positives", default=torch.zeros((num_labels)), dist_reduce_fx="sum"
         )
         self.add_state(
             "positive_predictions",
-            default=torch.empty((n_labels)),
+            default=torch.empty((num_labels)),
             dist_reduce_fx="sum",
         )
         self.add_state(
-            "positive_labels", default=torch.empty((n_labels)), dist_reduce_fx="sum"
+            "positive_labels", default=torch.empty((num_labels)), dist_reduce_fx="sum"
         )
         self.threshold = threshold
 
 
diff --git a/chebai/cli.py b/chebai/cli.py
index a3d8a4f7..6d53b9fe 100644
--- a/chebai/cli.py
+++ b/chebai/cli.py
@@ -1,14 +1,16 @@
 from typing import Dict, Set
 
-from lightning.pytorch.cli import LightningCLI
-from chebai.trainer.InnerCVTrainer import InnerCVTrainer
+from lightning.pytorch.cli import LightningCLI, LightningArgumentParser
+from chebai.trainer.CustomTrainer import CustomTrainer
+from chebai.preprocessing.datasets.base import XYBaseDataModule
+from chebai.models.base import ChebaiBaseNet
 
 
 class ChebaiCLI(LightningCLI):
     def __init__(self, *args, **kwargs):
-        super().__init__(trainer_class=InnerCVTrainer, *args, **kwargs)
+        super().__init__(trainer_class=CustomTrainer, *args, **kwargs)
 
-    def add_arguments_to_parser(self, parser):
+    def add_arguments_to_parser(self, parser: LightningArgumentParser):
         for kind in ("train", "val", "test"):
             for average in ("micro", "macro"):
                 parser.link_arguments(
@@ -19,11 +21,6 @@ def add_arguments_to_parser(self, parser):
             "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
         )
-        # does any of this work? i wasnt able to find any evidence of cases where linked arguments are actually used
-        # why doesnt it work?
-        # parser.link_arguments("model.out_dim", "model.init_args.n_atom_properties")
-        # parser.link_arguments('n_splits', 'data.init_args.inner_k_folds') # doesn't work but I don't know why
-
 
     @staticmethod
     def subcommands() -> Dict[str, Set[str]]:
         """Defines the list of available subcommands and the arguments to skip."""
diff --git a/chebai/trainer/InnerCVTrainer.py b/chebai/trainer/CustomTrainer.py
similarity index 95%
rename from chebai/trainer/InnerCVTrainer.py
rename to chebai/trainer/CustomTrainer.py
index ca947380..431be8b5 100644
--- a/chebai/trainer/InnerCVTrainer.py
+++ b/chebai/trainer/CustomTrainer.py
@@ -29,7 +29,7 @@
 log = logging.getLogger(__name__)
 
 
-class InnerCVTrainer(Trainer):
+class CustomTrainer(Trainer):
     def __init__(self, *args, **kwargs):
         self.init_args = args
         self.init_kwargs = kwargs
@@ -37,7 +37,8 @@ def __init__(self, *args, **kwargs):
         # instantiation custom logger connector
         self._logger_connector.on_trainer_init(self.logger, 1)
 
-    def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwargs):
+    def cv_fit(self, datamodule: XYBaseDataModule, *args, **kwargs):
+        n_splits = datamodule.inner_k_folds
         if n_splits < 2:
             self.fit(datamodule=datamodule, *args, **kwargs)
         else:
@@ -55,7 +56,7 @@ def cv_fit(self, datamodule: XYBaseDataModule, *args, **kwar
                 train_dataloader = datamodule.train_dataloader(ids=train_ids)
                 val_dataloader = datamodule.val_dataloader(ids=val_ids)
                 init_kwargs = self.init_kwargs
-                new_trainer = InnerCVTrainer(*self.init_args, **init_kwargs)
+                new_trainer = CustomTrainer(*self.init_args, **init_kwargs)
                 logger = new_trainer.logger
                 if isinstance(logger, CustomLogger):
                     logger.set_fold(fold)
diff --git a/configs/metrics/micro-macro-f1.yml b/configs/metrics/micro-macro-f1.yml
index 7472e740..7273bd4c 100644
--- a/configs/metrics/micro-macro-f1.yml
+++ b/configs/metrics/micro-macro-f1.yml
@@ -6,7 +6,4 @@ init_args:
     init_args:
       average: micro
   macro-f1:
-    class_path: chebai.callbacks.epoch_metrics.MacroF1
-    init_args:
-      n_labels: 1446
-    # macro-f1 is a callback
\ No newline at end of file
+    class_path: chebai.callbacks.epoch_metrics.MacroF1
\ No newline at end of file
diff --git a/configs/training/csv_logger.yml b/configs/training/csv_logger.yml
new file mode 100644
index 00000000..ed14c4e7
--- /dev/null
+++ b/configs/training/csv_logger.yml
@@ -0,0 +1,3 @@
+class_path: lightning.pytorch.loggers.CSVLogger
+init_args:
+  save_dir: logs
\ No newline at end of file
diff --git a/configs/training/default_trainer.yml b/configs/training/default_trainer.yml
index 392c2994..147c3500 100644
--- a/configs/training/default_trainer.yml
+++ b/configs/training/default_trainer.yml
@@ -1,15 +1,5 @@
 min_epochs: 100
 max_epochs: 100
 default_root_dir: &default_root_dir logs
-strategy: 'ddp_find_unused_parameters_true'
-logger:
-#  class_path: lightning.pytorch.loggers.CSVLogger
-#  init_args:
-#    save_dir: *default_root_dir
-  class_path: chebai.loggers.custom.CustomLogger # Extension of Wandb logger
-  init_args:
-    save_dir: *default_root_dir
-    project: 'chebai'
-    entity: 'chebai'
-    log_model: 'all'
+logger: csv_logger.yml
 callbacks: default_callbacks.yml
\ No newline at end of file
diff --git a/configs/training/wandb_logger.yml b/configs/training/wandb_logger.yml
new file mode 100644
index 00000000..b7c51418
--- /dev/null
+++ b/configs/training/wandb_logger.yml
@@ -0,0 +1,6 @@
+class_path: chebai.loggers.custom.CustomLogger # Extension of Wandb logger
+init_args:
+  save_dir: logs
+  project: 'chebai'
+  entity: 'chebai'
+  log_model: 'all'
\ No newline at end of file
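
Note on the dropped "n_labels: 1446": ChebaiCLI now links "model.init_args.out_dim" to "trainer.callbacks.init_args.num_labels", so the MacroF1 callback receives the label count from the parsed model config instead of a value duplicated in the metrics YAML. Below is a minimal, self-contained sketch of the jsonargparse argument-linking mechanism that LightningCLI builds on; TinyModel and TinyMetric are illustrative stand-ins, not chebai classes, and the default of 1446 is only an example value.

# Hypothetical sketch of jsonargparse argument linking (not chebai code).
from jsonargparse import ArgumentParser


class TinyModel:
    def __init__(self, out_dim: int = 1446):
        self.out_dim = out_dim


class TinyMetric:
    def __init__(self, num_labels: int):
        self.num_labels = num_labels


parser = ArgumentParser()
parser.add_class_arguments(TinyModel, "model")
parser.add_class_arguments(TinyMetric, "metric")
# Forward the parsed model.out_dim into metric.num_labels, so the metric
# config never has to repeat the label count.
parser.link_arguments("model.out_dim", "metric.num_labels")

cfg = parser.parse_args(["--model.out_dim=997"])
print(cfg.metric.num_labels)  # -> 997, supplied by the link

Once such a link is in place, the linked target is no longer accepted as direct input, which is why the metrics config only needs the class_path line after this change.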