diff --git a/best_checkpoint_copier/__init__.py b/best_checkpoint_copier/__init__.py
index 337c7b6..c710f8f 100644
--- a/best_checkpoint_copier/__init__.py
+++ b/best_checkpoint_copier/__init__.py
@@ -3,6 +3,7 @@
 import shutil
 import tensorflow as tf
 
+
 class Checkpoint(object):
   dir = None
   file = None
@@ -25,7 +26,13 @@ class BestCheckpointCopier(tf.estimator.Exporter):
   sort_key_fn = None
   sort_reverse = None
 
-  def __init__(self, name='best_checkpoints', checkpoints_to_keep=5, score_metric='Loss/total_loss', compare_fn=lambda x,y: x.score < y.score, sort_key_fn=lambda x: x.score, sort_reverse=False):
+  def __init__(self, name='best_checkpoints',
+               checkpoints_to_keep=5,
+               score_metric='Loss/total_loss',
+               compare_fn=lambda x, y: x.score < y.score,
+               sort_key_fn=lambda x: x.score,
+               sort_reverse=False,
+               patience=25):
     self.checkpoints = []
     self.checkpoints_to_keep = checkpoints_to_keep
     self.compare_fn = compare_fn
@@ -33,6 +40,8 @@ def __init__(self, name='best_checkpoints', checkpoints_to_keep=5, score_metric=
     self.score_metric = score_metric
     self.sort_key_fn = sort_key_fn
     self.sort_reverse = sort_reverse
+    self.early_stop_patient = 0
+    self.patience = patience
     super(BestCheckpointCopier, self).__init__()
 
   def _copyCheckpoint(self, checkpoint):
@@ -40,9 +49,13 @@ def _copyCheckpoint(self, checkpoint):
     os.makedirs(desination_dir, exist_ok=True)
 
     for file in glob.glob(r'{}*'.format(checkpoint.path)):
+      if self.name not in desination_dir: desination_dir = os.path.join(desination_dir, self.name)
       self._log('copying {} to {}'.format(file, desination_dir))
       shutil.copy(file, desination_dir)
+    with open(os.path.join(desination_dir, "meta.txt"), "a") as f:
+      f.write(f"{os.path.basename(file)} {round(checkpoint.score, 5)}\n")
+
 
   def _destinationDir(self, checkpoint):
     return os.path.join(checkpoint.dir, self.name)
@@ -53,6 +66,7 @@ def _keepCheckpoint(self, checkpoint):
     self.checkpoints = sorted(self.checkpoints, key=self.sort_key_fn, reverse=self.sort_reverse)
 
     self._copyCheckpoint(checkpoint)
+    self.early_stop_patient = 0
 
   def _log(self, statement):
     tf.logging.info('[{}] {}'.format(self.__class__.__name__, statement))
@@ -83,7 +97,10 @@ def export(self, estimator, export_path, checkpoint_path, eval_result, is_the_fi
     checkpoint = Checkpoint(path=checkpoint_path, score=score)
 
     if self._shouldKeep(checkpoint):
-      self._keepCheckpoint(checkpoint)
       self._pruneCheckpoints(checkpoint)
+      self._keepCheckpoint(checkpoint)
     else:
       self._log('skipping checkpoint {}'.format(checkpoint.path))
+      self.early_stop_patient += 1
+      if self.early_stop_patient > self.patience:
+        raise ValueError(f"Stopping training: {self.score_metric} didn't improve in the last {self.patience} evaluations.")
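
For context, a minimal usage sketch of the patched exporter with the TF 1.x Estimator API (not part of the diff; model_fn, train_input_fn, and eval_input_fn are hypothetical placeholders): the exporter is attached to the EvalSpec, and with the new patience argument it raises ValueError once the tracked metric has not improved for that many evaluation rounds, which ends train_and_evaluate.

# Usage sketch only -- assumes a TF 1.x Estimator setup with model_fn,
# train_input_fn and eval_input_fn defined elsewhere (hypothetical names).
import tensorflow as tf
from best_checkpoint_copier import BestCheckpointCopier

best_copier = BestCheckpointCopier(
  name='best_checkpoints',
  checkpoints_to_keep=5,
  score_metric='Loss/total_loss',             # key looked up in eval_result
  compare_fn=lambda x, y: x.score < y.score,  # lower score is better
  sort_key_fn=lambda x: x.score,
  sort_reverse=False,
  patience=25)                                # new: tolerated non-improving evals

estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir='model_dir')
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=100000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, exporters=best_copier)

try:
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
except ValueError as stop:
  # The exporter signals early stopping by raising ValueError from export().
  tf.logging.info(str(stop))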