diff --git a/chebai/models/base.py b/chebai/models/base.py
index b5080b6d..3a149832 100644
--- a/chebai/models/base.py
+++ b/chebai/models/base.py
@@ -122,8 +122,10 @@ def _log_metrics(self, prefix, metrics, batch_size):
         # don't use sync_dist=True if the metric is a torchmetrics-metric
         # (see https://github.com/Lightning-AI/pytorch-lightning/discussions/6501#discussioncomment-569757)
         for metric_name, metric in metrics.items():
-            m = metric.compute()
+            m = None  # m = metric.compute()
             if isinstance(m, dict):
+                # todo: is this case needed? it requires logging values directly
+                # which does not give accurate results with the current metric-setup
                 for k, m2 in m.items():
                     self.log(
                         f"{prefix}{metric_name}{k}",
@@ -137,7 +139,7 @@ def _log_metrics(self, prefix, metrics, batch_size):
             else:
                 self.log(
                     f"{prefix}{metric_name}",
-                    m,
+                    metric,
                     batch_size=batch_size,
                     on_step=False,
                     on_epoch=True,
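
The net effect of the change: instead of logging a value obtained from metric.compute(), the Metric object itself is passed to self.log(), so Lightning handles compute(), reset(), and cross-process state aggregation (which is why sync_dist=True must not be set, per the linked discussion). Below is a minimal sketch of that pattern under stated assumptions; the module and metric names are illustrative and not taken from chebai.

# Illustrative sketch, not chebai code: logging a torchmetrics Metric object
# directly so Lightning computes and resets it at epoch end.
import torch
import pytorch_lightning as pl
import torchmetrics


class SketchModule(pl.LightningModule):
    def __init__(self, in_dim: int = 16, num_labels: int = 10):
        super().__init__()
        self.classifier = torch.nn.Linear(in_dim, num_labels)
        self.train_f1 = torchmetrics.classification.MultilabelF1Score(
            num_labels=num_labels
        )

    def forward(self, x):
        return self.classifier(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # Update the metric state with this batch's predictions.
        self.train_f1(torch.sigmoid(logits), y)
        # Log the Metric object itself, not self.train_f1.compute();
        # Lightning aggregates state across devices and resets per epoch.
        self.log("train_f1", self.train_f1, on_step=False, on_epoch=True)
        return torch.nn.functional.binary_cross_entropy_with_logits(
            logits, y.float()
        )

This also explains the todo in the dict branch: metrics whose compute() returns a dict cannot be logged as objects and would have to be logged as raw values, which forfeits the aggregation guarantees above.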