# Patch for python/train.py: reset running metrics that have become non-finite.
#
# Hunk 1 (-909,6 +909,14) adds a helper after the end of log_metrics and
# before the `if rank == 0:` block that opens the metrics json files:
#
#   clear_metric_nonfinite(metric_sums, metric_weights)
#     Iterates over the keys of metric_sums. For every metric whose
#     accumulated sum fails math.isfinite(), it logs a warning naming the
#     metric and sets both metric_sums[metric] and metric_weights[metric]
#     to 0.0. Both dicts are modified in place; nothing is returned.
#
# Hunk 2 (-975,6 +983,8) calls the helper on running_metrics["sums"] and
# running_metrics["weights"] right after gc.collect() and before the
# "BEGINNING NEXT EPOCH" banner is logged.
#
# The `def log_metrics(...)` text following each `@@ ... @@` marker is the
# hunk section heading emitted by git, not part of the change itself.
#
# NOTE(review): only the values in metric_sums are tested for finiteness; a
#   non-finite value that exists solely in metric_weights would not be
#   detected or cleared -- confirm that this is intended.
# NOTE(review): the helper uses `math` and `logging`; their imports are not
#   visible in this patch -- presumably train.py already imports both, verify.
# NOTE(review): the patch text below appears to have been flattened onto a
#   single line (the hunk headers declare 6/14 and 6/8 lines). The original
#   line breaks and leading indentation must be restored from the source
#   commit before `git apply` / `patch` can use it.
diff --git a/python/train.py b/python/train.py index 5ba4ab7c4..7928234f4 100755 --- a/python/train.py +++ b/python/train.py @@ -909,6 +909,14 @@ def log_metrics(metric_sums, metric_weights, metrics, metrics_out): metrics_out.write(json.dumps(metrics_to_print) + "\n") metrics_out.flush() + def clear_metric_nonfinite(metric_sums, metric_weights): + for metric in metric_sums: + if not math.isfinite(metric_sums[metric]): + logging.warning(f"NONFINITE VALUE OF METRIC {metric}, CLEARING IT BACK TO EMPTY") + metric_sums[metric] = 0.0 + metric_weights[metric] = 0.0 + + if rank == 0: train_metrics_out = open(os.path.join(traindir,"metrics_train.json"),"a") val_metrics_out = open(os.path.join(traindir,"metrics_val.json"),"a") @@ -975,6 +983,8 @@ def log_metrics(metric_sums, metric_weights, metrics, metrics_out): logging.info("GC collect") gc.collect() + clear_metric_nonfinite(running_metrics["sums"], running_metrics["weights"]) + logging.info("=========================================================================") logging.info("BEGINNING NEXT EPOCH " + str(num_epochs_this_instance)) logging.info("=========================================================================")