Skip to content

Commit

Permalink
Clear metric nonfinite
Browse files Browse the repository at this point in the history
  • Loading branch information
lightvector committed Mar 10, 2024
1 parent 44b5447 commit aea2420
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,14 @@ def log_metrics(metric_sums, metric_weights, metrics, metrics_out):
metrics_out.write(json.dumps(metrics_to_print) + "\n")
metrics_out.flush()

def clear_metric_nonfinite(metric_sums, metric_weights):
    """Reset every metric whose accumulated sum is NaN or infinite.

    For each key in ``metric_sums`` whose value fails ``math.isfinite``,
    a warning is logged and both ``metric_sums[key]`` and
    ``metric_weights[key]`` are set to ``0.0``. Metrics with a finite sum
    are left untouched. Both mappings are modified in place; nothing is
    returned.

    Args:
        metric_sums: Mapping from metric name to its accumulated sum.
        metric_weights: Mapping from metric name to its accumulated weight.
    """
    # Only values of already-present keys are overwritten in metric_sums,
    # so assigning while iterating does not change the mapping's size.
    for metric, total in metric_sums.items():
        if math.isfinite(total):
            continue
        logging.warning(f"NONFINITE VALUE OF METRIC {metric}, CLEARING IT BACK TO EMPTY")
        metric_sums[metric] = 0.0
        metric_weights[metric] = 0.0


if rank == 0:
train_metrics_out = open(os.path.join(traindir,"metrics_train.json"),"a")
val_metrics_out = open(os.path.join(traindir,"metrics_val.json"),"a")
Expand Down Expand Up @@ -975,6 +983,8 @@ def log_metrics(metric_sums, metric_weights, metrics, metrics_out):
logging.info("GC collect")
gc.collect()

clear_metric_nonfinite(running_metrics["sums"], running_metrics["weights"])

logging.info("=========================================================================")
logging.info("BEGINNING NEXT EPOCH " + str(num_epochs_this_instance))
logging.info("=========================================================================")
Expand Down

0 comments on commit aea2420

Please sign in to comment.