-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
sfluegel
committed
Dec 4, 2023
1 parent
b30c97e
commit 9e0bbd5
Showing
5 changed files
with
85 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from datetime import datetime | ||
from typing import Optional, Union, Literal | ||
|
||
import wandb | ||
from lightning.fabric.utilities.types import _PATH | ||
from lightning.pytorch.loggers import WandbLogger | ||
import os | ||
|
||
|
||
class CustomLogger(WandbLogger): | ||
"""Adds support for custom naming of runs and cross-validation""" | ||
|
||
def __init__(self, save_dir: _PATH, name: str = "logs", version: Optional[Union[int, str]] = None, prefix: str = "", | ||
fold: Optional[int] = None, project: Optional[str] = None, entity: Optional[str] = None, | ||
offline: bool = False, | ||
log_model: Union[Literal["all"], bool] = False, **kwargs): | ||
if version is None: | ||
version = f'{datetime.now():%y%m%d-%H%M}' | ||
self._version = version | ||
self._name = name | ||
self._fold = fold | ||
super().__init__(name=self.name, save_dir=save_dir, version=None, prefix=prefix, | ||
log_model=log_model, entity=entity, project=project, offline=offline, **kwargs) | ||
|
||
@property | ||
def name(self) -> Optional[str]: | ||
name = f'{self._name}_{self.version}' | ||
if self._fold is not None: | ||
name += f'_fold{self._fold}' | ||
return name | ||
|
||
@property | ||
def version(self) -> Optional[str]: | ||
return self._version | ||
|
||
@property | ||
def root_dir(self) -> Optional[str]: | ||
return os.path.join(self.save_dir, self.name) | ||
|
||
@property | ||
def log_dir(self) -> str: | ||
version = self.version if isinstance(self.version, str) else f"version_{self.version}" | ||
if self._fold is None: | ||
return os.path.join(self.root_dir, version) | ||
return os.path.join(self.root_dir, version, f'fold_{self._fold}') | ||
|
||
def set_fold(self, fold: int): | ||
if fold != self._fold: | ||
self._fold = fold | ||
# start new experiment | ||
wandb.finish() | ||
self._experiment = None | ||
_ = self.experiment | ||
|
||
@property | ||
def fold(self): | ||
return self._fold |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,14 @@ | ||
- class_path: chebai.trainer.InnerCVTrainer.ModelCheckpoint | ||
- class_path: lightning.pytorch.callbacks.ModelCheckpoint | ||
init_args: | ||
monitor: val_micro-f1 | ||
mode: 'max' | ||
filename: 'best_{epoch}_{val_loss:.4f}_{val_micro-f1:.2f}' | ||
every_n_epochs: 1 | ||
save_top_k: 5 | ||
- class_path: chebai.trainer.InnerCVTrainer.ModelCheckpointCVSupport | ||
- class_path: lightning.pytorch.callbacks.ModelCheckpoint | ||
init_args: | ||
filename: 'per_{epoch}_{val_loss:.4f}_{val_micro-f1:.2f}' | ||
every_n_epochs: 5 | ||
save_top_k: -1 | ||
- class_path: chebai.callbacks.epoch_metrics.EpochLevelMacroF1 | ||
- class_path: chebai.callbacks.epoch_metrics.EpochLevelMacroF1 | ||
#class_path: chebai.callbacks.save_config_callback.CustomSaveConfigCallback |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters