From f4ceb41bc4d694ea08ae07d066f5a96eb4c4bfcd Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 12 May 2020 22:25:29 +0200
Subject: [PATCH] [engine.trainer] Implement optional per-epoch validation
 with checkpointing

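Both the supervised and semi-supervised trainers can now run a
validation pass at the end of every epoch: if the input dataset
dictionary provides a ``__valid__`` subset, average and median
validation losses are appended to ``trainlog.csv``, and the model with
the lowest validation loss observed so far is checkpointed as
``model_lowest_valid_loss.pth``.  ``make_dataset()`` fills ``__valid__``
from a ``valid`` subset when one is declared, falling back to the
unaugmented ``train`` subset otherwise.

A rough usage sketch (``subsets`` and ``transforms`` are placeholders
for a concrete protocol definition; only the key names and output files
come from this change):

    from bob.ip.binseg.configs.datasets import make_dataset

    dataset = make_dataset(subsets, transforms)
    # dataset["__valid__"] is a copy of the "valid" split if one exists,
    # or of the unaugmented dataset["train"] otherwise; passing this
    # dictionary to the ``train``/``experiment`` scripts enables
    # per-epoch validation logging and best-model checkpointing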
---
 bob/ip/binseg/configs/datasets/__init__.py | 15 ++++-
 bob/ip/binseg/engine/ssltrainer.py         | 67 +++++++++++++++++++++-
 bob/ip/binseg/engine/trainer.py            | 54 ++++++++++++++++-
 bob/ip/binseg/script/experiment.py         |  4 ++
 bob/ip/binseg/script/train.py              | 22 ++++++-
 bob/ip/binseg/test/test_cli.py             | 56 ++++++++++++------
 6 files changed, 194 insertions(+), 24 deletions(-)

diff --git a/bob/ip/binseg/configs/datasets/__init__.py b/bob/ip/binseg/configs/datasets/__init__.py
index f2621a72..4cdf1973 100644
--- a/bob/ip/binseg/configs/datasets/__init__.py
+++ b/bob/ip/binseg/configs/datasets/__init__.py
@@ -141,7 +141,11 @@ def make_dataset(subsets, transforms):
         A dictionary that contains the delayed sample lists for a number of
         named lists.  If one of the keys is ``train``, our standard dataset
         augmentation transforms are appended to the definition of that subset.
-        All other subsets remain un-augmented.
+        All other subsets remain un-augmented.  If one of the keys is
+        ``valid``, then this dataset will also be copied to the ``__valid__``
+        hidden dataset and used for validation during training.  Otherwise,
+        if no ``valid`` subset is available, ``__valid__`` is set to the
+        unaugmented ``train`` subset, if one is available.
 
     transforms : list
         A list of transforms that needs to be applied to all samples in the set
@@ -166,5 +170,14 @@ def make_dataset(subsets, transforms):
                     transforms=transforms,
                     suffixes=(RANDOM_ROTATION + RANDOM_FLIP_JITTER),
                     )
+        if key == "valid":
+            # also use it for validation during training
+            retval["__valid__"] = retval[key]
+
+    if ("__train__" in retval) and ("train" in retval) \
+            and ("__valid__" not in retval):
+        # if the dataset does not have a validation set, we use the unaugmented
+        # training set as validation set
+        retval["__valid__"] = retval["train"]
 
     return retval
diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 9db427a7..6905a645 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 
 import os
+import sys
 import csv
 import time
 import shutil
@@ -167,6 +168,7 @@ def guess_labels(unlabelled_images, model):
 def run(
     model,
     data_loader,
+    valid_loader,
     optimizer,
     criterion,
     scheduler,
@@ -192,6 +194,11 @@ def run(
         Network (e.g. driu, hed, unet)
 
     data_loader : :py:class:`torch.utils.data.DataLoader`
+        To be used to train the model
+
+    valid_loader : :py:class:`torch.utils.data.DataLoader`
+        To be used to validate the model and enable automatic checkpointing.
+        If set to ``None``, validation is skipped.
 
     optimizer : :py:mod:`torch.optim`
 
@@ -277,10 +284,16 @@ def run(
         "median_unlabelled_loss",
         "learning_rate",
     )
+    if valid_loader is not None:
+        logfile_fields += ("validation_average_loss", "validation_median_loss")
     logfile_fields += tuple([k[0] for k in cpu_log()])
     if device != "cpu":
         logfile_fields += tuple([k[0] for k in gpu_log()])
 
+    # the lowest validation loss obtained so far - this value is updated only
+    # if a validation set is available
+    lowest_validation_loss = sys.float_info.max
+
     with open(logfile_name, "a+", newline="") as logfile:
         logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
 
@@ -313,6 +326,7 @@ def run(
             # Epoch time
             start_epoch_time = time.time()
 
+            # progress bar only on interactive jobs
             for samples in tqdm(
                 data_loader, desc="batch", leave=False, disable=None
             ):
@@ -330,8 +344,6 @@ def run(
                 unlabelled_ground_truths = guess_labels(
                     unlabelled_images, model
                 )
-                # unlabelled_ground_truths = sharpen(unlabelled_ground_truths,0.5)
-                # images, ground_truths, unlabelled_images, unlabelled_ground_truths = mix_up(0.75, images, ground_truths, unlabelled_images, unlabelled_ground_truths)
 
                 # loss evaluation and learning (backward step)
                 ramp_up_factor = square_rampup(
@@ -356,9 +368,52 @@
             if PYTORCH_GE_110:
                 scheduler.step()
 
+            # calculates the validation loss if necessary
+            valid_losses = None
+            if valid_loader is not None:
+                valid_losses = SmoothedValue(len(valid_loader))
+                model.eval()  # set layers to evaluation mode
+                with torch.no_grad():  # no gradients needed for validation
+                    for samples in tqdm(
+                        valid_loader, desc="valid", leave=False, disable=None
+                    ):
+
+                        # labelled
+                        images = samples[1].to(device)
+                        ground_truths = samples[2].to(device)
+                        unlabelled_images = samples[4].to(device)
+                        # labelled outputs
+                        outputs = model(images)
+                        unlabelled_outputs = model(unlabelled_images)
+                        # guessed unlabelled outputs
+                        unlabelled_ground_truths = guess_labels(
+                            unlabelled_images, model
+                        )
+                        loss, ll, ul = criterion(
+                            outputs,
+                            ground_truths,
+                            unlabelled_outputs,
+                            unlabelled_ground_truths,
+                            ramp_up_factor,
+                        )
+
+                        valid_losses.update(loss)
+                model.train()  # restore training mode
+
             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save(f"model_{epoch:03d}", **arguments)
 
+            if (
+                valid_losses is not None
+                and valid_losses.avg < lowest_validation_loss
+            ):
+                lowest_validation_loss = valid_losses.avg
+                logger.info(
+                    "Found new low on validation set:"
+                    f" {lowest_validation_loss:.6f}"
+                )
+                checkpointer.save("model_lowest_valid_loss", **arguments)
+
             if epoch >= max_epoch:
                 checkpointer.save("model_final", **arguments)
 
@@ -380,7 +435,13 @@
                 ("median_labelled_loss", f"{labelled_loss.median:.6f}"),
                 ("median_unlabelled_loss", f"{unlabelled_loss.median:.6f}"),
                 ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
-            ) + cpu_log()
+            )
+            if valid_losses is not None:
+                logdata += (
+                    ("validation_average_loss", f"{valid_losses.avg:.6f}"),
+                    ("validation_median_loss", f"{valid_losses.median:.6f}"),
+                )
+            logdata += cpu_log()
             if device != "cpu":
                 logdata += gpu_log()
 
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 81b191dd..7d1f841b 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 
 import os
+import sys
 import csv
 import time
 import shutil
@@ -25,6 +26,7 @@ PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
 def run(
     model,
     data_loader,
+    valid_loader,
     optimizer,
     criterion,
     scheduler,
@@ -48,6 +50,11 @@ def run(
         Network (e.g. driu, hed, unet)
 
     data_loader : :py:class:`torch.utils.data.DataLoader`
+        To be used to train the model
+
+    valid_loader : :py:class:`torch.utils.data.DataLoader`
+        To be used to validate the model and enable automatic checkpointing.
+        If set to ``None``, validation is skipped.
 
     optimizer : :py:mod:`torch.optim`
 
@@ -127,10 +134,16 @@ def run(
         "median_loss",
         "learning_rate",
     )
+    if valid_loader is not None:
+        logfile_fields += ("validation_average_loss", "validation_median_loss")
     logfile_fields += tuple([k[0] for k in cpu_log()])
     if device != "cpu":
         logfile_fields += tuple([k[0] for k in gpu_log()])
 
+    # the lowest validation loss obtained so far - this value is updated only
+    # if a validation set is available
+    lowest_validation_loss = sys.float_info.max
+
     with open(logfile_name, "a+", newline="") as logfile:
         logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
 
@@ -187,9 +200,42 @@
             if PYTORCH_GE_110:
                 scheduler.step()
 
+            # calculates the validation loss if necessary
+            valid_losses = None
+            if valid_loader is not None:
+                valid_losses = SmoothedValue(len(valid_loader))
+                model.eval()  # set layers to evaluation mode
+                with torch.no_grad():  # no gradients needed for validation
+                    for samples in tqdm(
+                        valid_loader, desc="valid", leave=False, disable=None
+                    ):
+                        # data forwarding on the existing network
+                        images = samples[1].to(device)
+                        ground_truths = samples[2].to(device)
+                        masks = None
+                        if len(samples) == 4:
+                            masks = samples[-1].to(device)
+
+                        outputs = model(images)
+
+                        loss = criterion(outputs, ground_truths, masks)
+                        valid_losses.update(loss)
+                model.train()  # restore training mode
+
             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save(f"model_{epoch:03d}", **arguments)
 
+            if (
+                valid_losses is not None
+                and valid_losses.avg < lowest_validation_loss
+            ):
+                lowest_validation_loss = valid_losses.avg
+                logger.info(
+                    "Found new low on validation set:"
+                    f" {lowest_validation_loss:.6f}"
+                )
+                checkpointer.save("model_lowest_valid_loss", **arguments)
+
             if epoch >= max_epoch:
                 checkpointer.save("model_final", **arguments)
 
@@ -209,7 +255,13 @@
                 ("average_loss", f"{losses.avg:.6f}"),
                 ("median_loss", f"{losses.median:.6f}"),
                 ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
-            ) + cpu_log()
+            )
+            if valid_losses is not None:
+                logdata += (
+                    ("validation_average_loss", f"{valid_losses.avg:.6f}"),
+                    ("validation_median_loss", f"{valid_losses.median:.6f}"),
+                )
+            logdata += cpu_log()
             if device != "cpu":
                 logdata += gpu_log()
 
diff --git a/bob/ip/binseg/script/experiment.py b/bob/ip/binseg/script/experiment.py
index 9afa6f5b..30d14e2c 100644
--- a/bob/ip/binseg/script/experiment.py
+++ b/bob/ip/binseg/script/experiment.py
@@ -274,6 +274,10 @@ def experiment(
 
     * ``__train__``: dataset used for training, prioritarily.  It is typically
       the dataset containing data augmentation pipelines.
+    * ``__valid__``: dataset used for validation.  It is typically disjoint
+      from the training and test sets.  If present, the model with the lowest
+      validation loss observed during training is checkpointed, in addition
+      to the model saved at the end of training.
     * ``train`` (optional): a copy of the ``__train__`` dataset, without data
       augmentation, that will be evaluated alongside other sets available
     * ``*``: any other name, not starting with an underscore character (``_``),
diff --git a/bob/ip/binseg/script/train.py b/bob/ip/binseg/script/train.py
index 3076aae4..da56a3bd 100644
--- a/bob/ip/binseg/script/train.py
+++ b/bob/ip/binseg/script/train.py
@@ -71,7 +71,9 @@ logger = logging.getLogger(__name__)
     "training the network model.  The dataset description must include all "
     "required pre-processing, including eventual data augmentation.  If a "
     "dataset named ``__train__`` is available, it is used prioritarily for "
-    "training instead of ``train``.",
+    "training instead of ``train``.  If a dataset named ``__valid__`` is "
+    "available, it is used for model validation (and automatic "
+    "checkpointing) at each epoch.",
     required=True,
     cls=ResourceOption,
 )
@@ -227,6 +229,7 @@ def train(
     torch.manual_seed(seed)
 
     use_dataset = dataset
+    validation_dataset = None
     if isinstance(dataset, dict):
         if "__train__" in dataset:
             logger.info("Found (dedicated) '__train__' set for training")
@@ -234,6 +237,11 @@ def train(
         else:
             use_dataset = dataset["train"]
 
+        if "__valid__" in dataset:
+            logger.info("Found (dedicated) '__valid__' set for validation")
+            logger.info("Will checkpoint lowest loss model on validation set")
+            validation_dataset = dataset["__valid__"]
+
     # PyTorch dataloader
     data_loader = DataLoader(
         dataset=use_dataset,
@@ -243,6 +251,16 @@ def train(
         pin_memory=torch.cuda.is_available(),
     )
 
+    valid_loader = None
+    if validation_dataset is not None:
+        valid_loader = DataLoader(
+                dataset=validation_dataset,
+                batch_size=batch_size,
+                shuffle=False,
+                drop_last=False,
+                pin_memory=torch.cuda.is_available(),
+                )
+
     # Checkpointer
     checkpointer = DetectronCheckpointer(
         model, optimizer, scheduler, save_dir=output_folder, save_to_disk=True
@@ -262,6 +280,7 @@ def train(
         run(
             model,
             data_loader,
+            valid_loader,
             optimizer,
             criterion,
             scheduler,
@@ -277,6 +296,7 @@ def train(
         run(
             model,
             data_loader,
+            valid_loader,
             optimizer,
             criterion,
             scheduler,
diff --git a/bob/ip/binseg/test/test_cli.py b/bob/ip/binseg/test/test_cli.py
index 3d3f728f..9a7a9847 100644
--- a/bob/ip/binseg/test/test_cli.py
+++ b/bob/ip/binseg/test/test_cli.py
@@ -88,14 +88,14 @@ def _check_experiment_stare(overlay):
 
         output_folder = "results"
         options = [
-                "m2unet",
-                config.name,
-                "-vv",
-                "--epochs=1",
-                "--batch-size=1",
-                "--steps=10",
-                f"--output-folder={output_folder}",
-                ]
+            "m2unet",
+            config.name,
+            "-vv",
+            "--epochs=1",
+            "--batch-size=1",
+            "--steps=10",
+            f"--output-folder={output_folder}",
+        ]
         if overlay:
             options += ["--overlayed"]
         result = runner.invoke(experiment, options)
@@ -107,6 +107,9 @@ def _check_experiment_stare(overlay):
         # check model was saved
         train_folder = os.path.join(output_folder, "model")
         assert os.path.exists(os.path.join(train_folder, "model_final.pth"))
+        assert os.path.exists(
+            os.path.join(train_folder, "model_lowest_valid_loss.pth")
+        )
         assert os.path.exists(os.path.join(train_folder, "last_checkpoint"))
         assert os.path.exists(os.path.join(train_folder, "constants.csv"))
         assert os.path.exists(os.path.join(train_folder, "trainlog.csv"))
@@ -123,7 +126,9 @@ def _check_experiment_stare(overlay):
         if overlay:
             # check overlayed images are there (since we requested them)
             assert os.path.exists(basedir)
-            nose.tools.eq_(len(fnmatch.filter(os.listdir(basedir), "*.png")), 20)
+            nose.tools.eq_(
+                len(fnmatch.filter(os.listdir(basedir), "*.png")), 20
+            )
         else:
             assert not os.path.exists(basedir)
 
@@ -135,7 +140,7 @@ def _check_experiment_stare(overlay):
             os.path.join(eval_folder, "second-annotator", "train.csv")
         )
         assert os.path.exists(
-            os.path.join(eval_folder, "second-annotator" , "test.csv")
+            os.path.join(eval_folder, "second-annotator", "test.csv")
         )
 
         overlay_folder = os.path.join(output_folder, "overlayed", "analysis")
@@ -143,18 +148,23 @@ def _check_experiment_stare(overlay):
         if overlay:
             # check overlayed images are there (since we requested them)
             assert os.path.exists(basedir)
-            nose.tools.eq_(len(fnmatch.filter(os.listdir(basedir), "*.png")), 20)
+            nose.tools.eq_(
+                len(fnmatch.filter(os.listdir(basedir), "*.png")), 20
+            )
         else:
             assert not os.path.exists(basedir)
 
         # check overlayed images from first-to-second annotator comparisons
         # are there (since we requested them)
-        overlay_folder = os.path.join(output_folder, "overlayed", "analysis",
-                "second-annotator")
+        overlay_folder = os.path.join(
+            output_folder, "overlayed", "analysis", "second-annotator"
+        )
         basedir = os.path.join(overlay_folder, "stare-images")
         if overlay:
             assert os.path.exists(basedir)
-            nose.tools.eq_(len(fnmatch.filter(os.listdir(basedir), "*.png")), 20)
+            nose.tools.eq_(
+                len(fnmatch.filter(os.listdir(basedir), "*.png")), 20
+            )
         else:
             assert not os.path.exists(basedir)
 
@@ -165,10 +175,13 @@ def _check_experiment_stare(overlay):
         keywords = {
             r"^Started training$": 1,
             r"^Found \(dedicated\) '__train__' set for training$": 1,
+            r"^Found \(dedicated\) '__valid__' set for validation$": 1,
+            r"^Will checkpoint lowest loss model on validation set$": 1,
             r"^Continuing from epoch 0$": 1,
             r"^Saving model summary at.*$": 1,
             r"^Model has.*$": 1,
-            r"^Saving checkpoint": 1,
+            r"^Found new low on validation set.*$": 1,
+            r"^Saving checkpoint": 2,
             r"^Ended training$": 1,
             r"^Started prediction$": 1,
             r"^Loading checkpoint from": 2,
@@ -223,7 +236,7 @@ def _check_train(runner):
         config.write(
             "from bob.ip.binseg.configs.datasets.stare import _maker\n"
         )
-        config.write("dataset = _maker('ah', _raw)['train']\n")
+        config.write("dataset = _maker('ah', _raw)\n")
         config.flush()
 
         output_folder = "results"
@@ -241,15 +254,21 @@ def _check_train(runner):
         _assert_exit_0(result)
 
         assert os.path.exists(os.path.join(output_folder, "model_final.pth"))
+        assert os.path.exists(
+            os.path.join(output_folder, "model_lowest_valid_loss.pth")
+        )
         assert os.path.exists(os.path.join(output_folder, "last_checkpoint"))
         assert os.path.exists(os.path.join(output_folder, "constants.csv"))
         assert os.path.exists(os.path.join(output_folder, "trainlog.csv"))
         assert os.path.exists(os.path.join(output_folder, "model_summary.txt"))
 
         keywords = {
+            r"^Found \(dedicated\) '__train__' set for training$": 1,
+            r"^Found \(dedicated\) '__valid__' set for validation$": 1,
             r"^Continuing from epoch 0$": 1,
             r"^Saving model summary at.*$": 1,
             r"^Model has.*$": 1,
+            rf"^Saving checkpoint to {output_folder}/model_lowest_valid_loss.pth$": 1,
             rf"^Saving checkpoint to {output_folder}/model_final.pth$": 1,
             r"^Total training time:": 1,
         }
@@ -364,8 +383,9 @@ def _check_evaluate(runner):
         _assert_exit_0(result)
 
         assert os.path.exists(os.path.join(output_folder, "test.csv"))
-        assert os.path.exists(os.path.join(output_folder,
-            "second-annotator", "test.csv"))
+        assert os.path.exists(
+            os.path.join(output_folder, "second-annotator", "test.csv")
+        )
 
         # check overlayed images are there (since we requested them)
         basedir = os.path.join(overlay_folder, "stare-images")
-- 
GitLab