From 44e40482f03d6e8ee742b45cda7bee67c9392455 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 23 May 2018 15:18:47 +0200
Subject: [PATCH] Implement model saving in bob tf eval. Fixes #54

---
 bob/learn/tensorflow/script/eval.py           | 105 ++++++++++++++----
 .../tensorflow/test/test_estimator_scripts.py |  37 +++++-
 2 files changed, 121 insertions(+), 21 deletions(-)

diff --git a/bob/learn/tensorflow/script/eval.py b/bob/learn/tensorflow/script/eval.py
index 4a80f24a..1d106fff 100644
--- a/bob/learn/tensorflow/script/eval.py
+++ b/bob/learn/tensorflow/script/eval.py
@@ -4,20 +4,86 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+import click
 import logging
 import os
-import time
 import six
+import shutil
 import sys
 import tensorflow as tf
+import time
+from glob import glob
+from collections import defaultdict
 from ..utils.eval import get_global_step
-import click
 from bob.extension.scripts.click_helper import (
     verbosity_option, ConfigCommand, ResourceOption)
+from bob.io.base import create_directories_safe
 
 logger = logging.getLogger(__name__)
 
 
+def save_best_n_models(train_dir, save_dir, evaluated_file,
+                       keep_n_best_models):
+    create_directories_safe(save_dir)
+    evaluated = read_evaluated_file(evaluated_file)
+
+    def _key(x):
+        x = x[1]
+        ac = x.get('accuracy')
+        lo = x.get('loss') or 0
+        return ac * -1 if ac is not None else lo
+
+    best_models = dict(sorted(
+        evaluated.items(), key=_key)[:keep_n_best_models])
+
+    # delete the old saved models that are not in top N best anymore
+    saved_models = defaultdict(list)
+    for path in glob('{}/model.ckpt-*'.format(save_dir)):
+        global_step = path.split('model.ckpt-')[1].split('.')[0]
+        saved_models[global_step].append(path)
+
+    for global_step, paths in saved_models.items():
+        if global_step not in best_models:
+            for path in paths:
+                logger.info("Deleting `%s'", path)
+                os.remove(path)
+
+    # copy over the best models if not already there
+    for global_step in best_models:
+        for path in glob('{}/model.ckpt-{}.*'.format(train_dir, global_step)):
+            dst = os.path.join(save_dir, os.path.basename(path))
+            if os.path.isfile(dst):
+                continue
+            logger.info("Copying `%s' over to `%s'", path, dst)
+            shutil.copy(path, dst)
+
+
+def read_evaluated_file(path):
+    evaluated = {}
+    with open(path) as f:
+        for line in f:
+            global_step, line = line.split(' ', 1)
+            temp = {}
+            for k_v in line.strip().split(', '):
+                k, v = k_v.split(' = ')
+                v = float(v)
+                if 'global_step' in k:
+                    v = int(v)
+                temp[k] = v
+            evaluated[global_step] = temp
+    return evaluated
+
+
+def append_evaluated_file(path, evaluations):
+    str_evaluations = ', '.join(
+        '%s = %s' % (k, v)
+        for k, v in sorted(six.iteritems(evaluations)))
+    with open(path, 'a') as f:
+        f.write('{} {}\n'.format(evaluations['global_step'],
+                                 str_evaluations))
+    return str_evaluations
+
+
 @click.command(entry_point_group='bob.learn.tensorflow.config',
                cls=ConfigCommand)
 @click.option('--estimator', '-e', required=True, cls=ResourceOption,
@@ -28,12 +94,14 @@ logger = logging.getLogger(__name__)
               entry_point_group='bob.learn.tensorflow.hook')
 @click.option('--run-once', cls=ResourceOption, default=False,
               show_default=True)
-@click.option('--eval-interval-secs', cls=ResourceOption, type=click.types.INT,
+@click.option('--eval-interval-secs', cls=ResourceOption, type=click.INT,
               default=60, show_default=True)
 @click.option('--name', cls=ResourceOption)
+@click.option('--keep-n-best-models', '-K', type=click.INT, cls=ResourceOption,
+              default=0, show_default=True)
 @verbosity_option(cls=ResourceOption)
 def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
-         **kwargs):
+         keep_n_best_models, **kwargs):
     """Evaluates networks using Tensorflow estimators.
 
     \b
@@ -76,18 +144,16 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
     logger.debug('run_once: %s', run_once)
     logger.debug('eval_interval_secs: %s', eval_interval_secs)
     logger.debug('name: %s', name)
+    logger.debug('keep_n_best_models: %s', keep_n_best_models)
     logger.debug('kwargs: %s', kwargs)
 
-    if name:
-        real_name = 'eval_' + name
-    else:
-        real_name = 'eval'
-    evaluated_file = os.path.join(estimator.model_dir, real_name, 'evaluated')
+    real_name = 'eval_' + name if name else 'eval'
+    eval_dir = os.path.join(estimator.model_dir, real_name)
+    evaluated_file = os.path.join(eval_dir, 'evaluated')
     while True:
-        evaluated_steps = []
+        evaluated_steps = {}
         if os.path.exists(evaluated_file):
-            with open(evaluated_file) as f:
-                evaluated_steps = [line.split()[0] for line in f]
+            evaluated_steps = read_evaluated_file(evaluated_file)
 
         ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
         if (not ckpt) or (not ckpt.model_checkpoint_path):
@@ -113,14 +179,15 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
                 name=name,
             )
 
-            str_evaluations = ', '.join(
-                '%s = %s' % (k, v)
-                for k, v in sorted(six.iteritems(evaluations)))
-            print(str_evaluations)
+            str_evaluations = append_evaluated_file(
+                evaluated_file, evaluations)
+            click.echo(str_evaluations)
             sys.stdout.flush()
-            with open(evaluated_file, 'a') as f:
-                f.write('{} {}\n'.format(evaluations['global_step'],
-                                         str_evaluations))
+
+            # Save the best N models into the eval directory
+            save_best_n_models(estimator.model_dir, eval_dir, evaluated_file,
+                               keep_n_best_models)
+
         if run_once:
             break
         time.sleep(eval_interval_secs)
diff --git a/bob/learn/tensorflow/test/test_estimator_scripts.py b/bob/learn/tensorflow/test/test_estimator_scripts.py
index 7e69f8de..625c6962 100644
--- a/bob/learn/tensorflow/test/test_estimator_scripts.py
+++ b/bob/learn/tensorflow/test/test_estimator_scripts.py
@@ -1,6 +1,7 @@
 from __future__ import print_function
 import os
 import shutil
+from glob import glob
 from tempfile import mkdtemp
 from click.testing import CliRunner
 from bob.io.base.test_utils import datafile
@@ -122,7 +123,7 @@ def _create_checkpoint(tmpdir, model_dir, dummy_tfrecord):
         result.exc_info, result.output, result.exception)
 
 
-def _eval(tmpdir, model_dir, dummy_tfrecord):
+def _eval(tmpdir, model_dir, dummy_tfrecord, extra_args=[]):
     config = CONFIG % {
         'model_dir': model_dir,
         'tfrecord_filenames': dummy_tfrecord
@@ -131,7 +132,7 @@ def _eval(tmpdir, model_dir, dummy_tfrecord):
     with open(config_path, 'w') as f:
         f.write(config)
     runner = CliRunner()
-    result = runner.invoke(eval_script, args=[config_path])
+    result = runner.invoke(eval_script, args=[config_path] + extra_args)
     assert result.exit_code == 0, '%s\n%s\n%s' % (
         result.exc_info, result.output, result.exception)
 
@@ -179,3 +180,35 @@ def test_eval():
             shutil.rmtree(tmpdir)
         except Exception:
             pass
+
+
+def test_eval_keep_n_model():
+    tmpdir = mkdtemp(prefix='bob_')
+    try:
+        model_dir = os.path.join(tmpdir, 'model_dir')
+        eval_dir = os.path.join(model_dir, 'eval')
+
+        print('\nCreating a dummy tfrecord')
+        dummy_tfrecord = _create_tfrecord(tmpdir)
+
+        print('Training a dummy network')
+        _create_checkpoint(tmpdir, model_dir, dummy_tfrecord)
+
+        print('Evaluating a dummy network')
+        _eval(tmpdir, model_dir, dummy_tfrecord, ['-K', '1'])
+
+        evaluated_path = os.path.join(eval_dir, 'evaluated')
+        assert os.path.exists(evaluated_path), evaluated_path
+        with open(evaluated_path) as f:
+            doc = f.read()
+
+        assert '1 ' in doc, doc
+        assert '200 ' in doc, doc
+        assert len(glob('{}/model.ckpt-*'.format(eval_dir))) == 3, \
+            os.listdir(eval_dir)
+
+    finally:
+        try:
+            shutil.rmtree(tmpdir)
+        except Exception:
+            pass
-- 
GitLab