
Commit

Merge branch 'features-sfluegel' into feature-pretraining
sfluegel05 authored Nov 23, 2023
2 parents bb2340f + 8ff3d5f, commit 5cfa50f
Showing 2 changed files with 1 addition and 32 deletions.
31 changes: 0 additions & 31 deletions chebai/trainer/InnerCVTrainer.py
@@ -26,7 +26,6 @@ def __init__(self, *args, **kwargs):
self.init_kwargs = kwargs
super().__init__(*args, **kwargs)
# instantiate custom logger connector
self._logger_connector = _LoggerConnectorCVSupport(self)
self._logger_connector.on_trainer_init(self.logger, 1)

def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwargs):
@@ -47,8 +46,6 @@ def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwargs):
version=self.logger.version, fold=fold)
init_kwargs['logger'] = new_logger
new_trainer = Trainer(*self.init_args, **init_kwargs)
self._logger_connector = _LoggerConnectorCVSupport(self)
self._logger_connector.on_trainer_init(self.logger, 1)
print(f'Logging this fold at {new_trainer.logger.log_dir}')
new_trainer.fit(train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, *args, **kwargs)
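
The cv_fit hunk above keeps the per-fold pattern untouched: for each split, a fresh logger and a fresh Trainer are built from the stored init args, and fit is run on that fold's dataloaders. As a reading aid, here is a minimal sketch of the same pattern built only from stock components; the repository uses its own fold-aware logger, so the plain CSVLogger with a per-fold name, the cv_fit_sketch and model_factory names, and the batch size are assumptions, not code from this commit.

# Minimal sketch: one logger and one Trainer per cross-validation fold.
# Assumptions: the `lightning.pytorch` namespace, a `model_factory` callable that
# builds a fresh model for each fold, and CSVLogger standing in for the repo's logger.
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset


def cv_fit_sketch(model_factory, dataset, n_splits: int = 5, **trainer_kwargs):
    """Train once per fold, logging each fold to its own directory."""
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=0)
    for fold, (train_idx, val_idx) in enumerate(kfold.split(list(range(len(dataset))))):
        train_loader = DataLoader(Subset(dataset, list(train_idx)), batch_size=32, shuffle=True)
        val_loader = DataLoader(Subset(dataset, list(val_idx)), batch_size=32)

        # A separate logger per fold keeps metrics and checkpoints from overwriting each other.
        fold_logger = CSVLogger(save_dir="logs", name=f"fold_{fold}")
        trainer = Trainer(logger=fold_logger, **trainer_kwargs)
        print(f"Logging this fold at {trainer.logger.log_dir}")
        trainer.fit(model_factory(), train_dataloaders=train_loader, val_dataloaders=val_loader)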

@@ -128,31 +125,3 @@ def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:

print(f'Now using checkpoint path {ckpt_path}')
return ckpt_path


warning_cache = WarningCache()


class _LoggerConnectorCVSupport(_LoggerConnector):
def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> None:
if not logger:
# logger is None or logger is False
self.trainer.loggers = []
elif logger is True:
# default logger
if _TENSORBOARD_AVAILABLE or _TENSORBOARDX_AVAILABLE:
logger_ = TensorBoardLogger(save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id())
else:
warning_cache.warn(
"Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch`"
" package, due to potential conflicts with other packages in the ML ecosystem. For this reason,"
" `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard`"
" or `tensorboardX` packages are found."
" Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default"
)
logger_ = CSVLogger(save_dir=self.trainer.default_root_dir) # type: ignore[assignment]
self.trainer.loggers = [logger_]
elif isinstance(logger, Iterable):
self.trainer.loggers = list(logger)
else:
self.trainer.loggers = [logger]
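
The block removed above overrode Lightning's internal _LoggerConnector only to reproduce the library's own logger-resolution rules, with one visible customization: when logger=True, the default TensorBoardLogger is versioned with SLURMEnvironment.job_id(). For reference, the same resolution can be restated as a standalone helper; the function name, the importlib-based availability check, and the lightning.pytorch import path below are assumptions rather than repository code.

# Sketch restating the removed configure_logger logic as a free function:
# False/None -> no loggers; True -> TensorBoard if available (optionally versioned
# by a SLURM job id), else CSVLogger; an iterable -> list; a single logger -> [logger].
import importlib.util
from collections.abc import Iterable
from typing import List, Optional, Union

from lightning.pytorch.loggers import CSVLogger, Logger, TensorBoardLogger

_TENSORBOARD_AVAILABLE = (
    importlib.util.find_spec("tensorboard") is not None
    or importlib.util.find_spec("tensorboardX") is not None
)


def resolve_loggers(
    logger: Union[bool, Logger, Iterable],
    default_root_dir: str,
    slurm_job_id: Optional[str] = None,
) -> List[Logger]:
    if not logger:  # None or False: logging disabled
        return []
    if logger is True:  # default logger requested
        if _TENSORBOARD_AVAILABLE:
            return [TensorBoardLogger(save_dir=default_root_dir, version=slurm_job_id)]
        # Fallback mirrors the warning above: CSVLogger when tensorboard is absent.
        return [CSVLogger(save_dir=default_root_dir)]
    if isinstance(logger, Iterable):
        return list(logger)
    return [logger]
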
2 changes: 1 addition & 1 deletion configs/training/default_callbacks.yml
@@ -1,4 +1,4 @@
- class_path: chebai.trainer.InnerCVTrainer.ModelCheckpointCVSupport
- class_path: chebai.trainer.InnerCVTrainer.ModelCheckpoint
init_args:
monitor: val_micro-f1
mode: 'max'
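
The config change above only swaps which checkpoint class the default callbacks file points at; entries of this class_path/init_args shape follow the LightningCLI (jsonargparse) convention and, with the shown arguments, yield a checkpoint callback that monitors val_micro-f1 with mode 'max'. Below is a hedged sketch of resolving such an entry by hand, with an illustrative helper name and without depending on the CLI:

# Sketch: turn a `class_path` / `init_args` entry into an object by importing the
# class and calling it with its keyword arguments. Illustrative helper, not repo code.
from importlib import import_module

import yaml


def instantiate_from_config(entry: dict):
    """Import entry['class_path'] and call it with entry.get('init_args', {})."""
    module_name, _, class_name = entry["class_path"].rpartition(".")
    cls = getattr(import_module(module_name), class_name)
    return cls(**entry.get("init_args", {}))


with open("configs/training/default_callbacks.yml") as f:
    callbacks = [instantiate_from_config(entry) for entry in yaml.safe_load(f)]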
