Add logging for target classes' F1 scores
dawerner committed Nov 21, 2023
1 parent ea0b0ef commit cbc5930
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions chebai/models/base.py
@@ -2,6 +2,7 @@
from lightning.pytorch.core.module import LightningModule
import torch
from typing import Optional, Dict, Any
from torchmetrics.classification import MultilabelF1Score

logging.getLogger("pysmiles").setLevel(logging.CRITICAL)

@@ -28,6 +29,7 @@ def __init__(
        self.train_metrics = metrics["train"]
        self.validation_metrics = metrics["validation"]
        self.test_metrics = metrics["test"]
        # Per-class F1 (average='none'): one score per target class, length out_dim
        self.vector_metric = MultilabelF1Score(num_labels=out_dim, average='none')
        self.pass_loss_kwargs = pass_loss_kwargs

    def __init_subclass__(cls, **kwargs):
@@ -69,6 +71,11 @@ def test_step(self, batch, batch_idx):
    def predict_step(self, batch, batch_idx, **kwargs):
        return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False)

    def print_vector_metric(self, file_path, classes_scores):
        # Append one comma-separated row of per-class scores to file_path
        with open(file_path, 'a') as file:
            line = ', '.join([str(score.item()) for score in classes_scores])
            file.write(f'{line}\n')

    def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
        data = self._process_batch(batch, batch_idx)
        labels = data["labels"]
@@ -96,6 +103,8 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
        )
        if metrics and labels is not None:
            pr, tar = self._get_prediction_and_labels(data, labels, model_output)

            # handle scalar metrics
            for metric_name, metric in metrics.items():
                m = metric(pr, tar)
                if isinstance(m, dict):
@@ -121,6 +130,12 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
                        logger=True,
                        sync_dist=sync_dist,
                    )

            # handle F1 score vector of target classes
            v_m = self.vector_metric(pr, tar)
            if isinstance(v_m, torch.Tensor) and v_m.dim() > 0:
                # path should be set dynamically
                self.print_vector_metric(f'logs/electra_chebi_roles100/{prefix}vector_metric.csv', v_m)
        return d

    def forward(self, x):
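For context, a minimal sketch (not part of the commit) of what the new vector_metric produces and how print_vector_metric serializes it: MultilabelF1Score with average='none' returns one F1 score per label instead of a single aggregate. The num_labels=4 and the tensors below are illustrative stand-ins for out_dim and a real batch.

import torch
from torchmetrics.classification import MultilabelF1Score

num_labels = 4  # stands in for the model's out_dim
metric = MultilabelF1Score(num_labels=num_labels, average='none')

# Sigmoid outputs (thresholded at 0.5 by default) and binary multilabel targets
preds = torch.tensor([[0.9, 0.1, 0.8, 0.3],
                      [0.6, 0.7, 0.4, 0.9]])
target = torch.tensor([[1, 0, 1, 0],
                       [0, 1, 1, 1]])

scores = metric(preds, target)  # shape (num_labels,): tensor([0.6667, 1.0000, 0.6667, 1.0000])

# Same formatting as print_vector_metric: one comma-separated CSV row per call
line = ', '.join([str(score.item()) for score in scores])
print(line)

Because average='none' keeps the metric unaggregated, the CSV grows by one row per _execute call, with one column per target class; the hardcoded logs/electra_chebi_roles100/ directory is the part the in-code comment flags as still needing to be set dynamically.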
