From b81c10c52767578bf61bda61953e69d25cd4d3f7 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.anjos@idiap.ch>
Date: Sat, 4 Apr 2020 19:10:59 +0200
Subject: [PATCH] [checkpointer] Make it robust to directory renames; fix
 saving of the train log when training is interrupted

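The checkpointer now stores only the file name in the "last_checkpoint"
tag and resolves it against save_dir when loading, so the output
directory can be renamed or moved between training sessions. The
training log (trainlog.csv) is now opened in append mode and is only
truncated, with a fresh header, when training restarts from epoch zero;
an interrupted run therefore keeps its partial log and a resumed run
continues appending to the same file.

A minimal resume sketch, assuming an output directory that was renamed
after a first run (the directory name and the model, optimizer and
scheduler objects are illustrative only, not part of this patch):

    from bob.ip.binseg.utils.checkpointer import DetectronCheckpointer

    # "results-v2" used to be called "results"; the relative
    # "last_checkpoint" tag stored inside it still resolves correctly
    checkpointer = DetectronCheckpointer(
        model, optimizer, scheduler, save_dir="results-v2", save_to_disk=True
    )
    extra = checkpointer.load()  # restores weights/optimizer/scheduler and
                                 # returns remaining data (e.g. the last epoch)
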
---
 bob/ip/binseg/engine/trainer.py            | 14 ++++++++++----
 bob/ip/binseg/script/train.py              |  6 +++---
 bob/ip/binseg/utils/checkpointer.py        | 22 ++++++++++++----------
 bob/ip/binseg/utils/model_serialization.py |  6 ++++--
 4 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 44ab76f1..8a4b2e50 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -75,12 +75,18 @@ def do_train(
 
     # Log to file
     logfile_name = os.path.join(output_folder, "trainlog.csv")
-    logfile_fields = ("epoch", "total-time", "eta", "average-loss",
-            "median-loss", "learning-rate", "memory-megabytes")
 
-    with open(logfile_name, "w", newline="") as logfile:
+    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
+        logger.info(f"Truncating {logfile_name} - training is restarting...")
+        os.unlink(logfile_name)
+
+    logfile_fields = ("epoch", "total-time", "eta", "average-loss",
+            "median-loss", "learning-rate", "gpu-memory-megabytes")
+    with open(logfile_name, "a+", newline="") as logfile:
         logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
-        logwriter.writeheader()
+
+        if arguments["epoch"] == 0:
+            logwriter.writeheader()
 
         model.train().to(device)
         for state in optimizer.state.values():
diff --git a/bob/ip/binseg/script/train.py b/bob/ip/binseg/script/train.py
index 5a2ce6f1..12a939e8 100644
--- a/bob/ip/binseg/script/train.py
+++ b/bob/ip/binseg/script/train.py
@@ -117,7 +117,7 @@ logger = logging.getLogger(__name__)
 @click.option(
     "--epochs",
     "-e",
-    help="Number of epochs used for training",
+    help="Number of epochs (complete training set passes) to train for",
     show_default=True,
     required=True,
     default=1000,
@@ -126,8 +126,8 @@ logger = logging.getLogger(__name__)
 @click.option(
     "--checkpoint-period",
     "-p",
-    help="Number of epochs after which a checkpoint is saved.  "
-    "A value of zero will disable check-pointing.  If checkpointing is "
+    help="Number of epochs after which a checkpoint is saved. "
+    "A value of zero will disable check-pointing. If checkpointing is "
     "enabled and training stops, it is automatically resumed from the "
     "last saved checkpoint if training is restarted with the same "
     "configuration.",
diff --git a/bob/ip/binseg/utils/checkpointer.py b/bob/ip/binseg/utils/checkpointer.py
index 8c0def2e..4ae57e5c 100644
--- a/bob/ip/binseg/utils/checkpointer.py
+++ b/bob/ip/binseg/utils/checkpointer.py
@@ -46,10 +46,11 @@ class Checkpointer:
             data["scheduler"] = self.scheduler.state_dict()
         data.update(kwargs)
 
-        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
-        logger.info("Saving checkpoint to {}".format(save_file))
+        dest_filename = f"{name}.pth"
+        save_file = os.path.join(self.save_dir, dest_filename)
+        logger.info(f"Saving checkpoint to {save_file}")
         torch.save(data, save_file)
-        self.tag_last_checkpoint(save_file)
+        self.tag_last_checkpoint(dest_filename)
 
     def load(self, f=None):
         if self.has_checkpoint():
@@ -59,14 +60,14 @@ class Checkpointer:
             # no checkpoint could be found
             logger.warn("No checkpoint found. Initializing model from scratch")
             return {}
-        logger.info("Loading checkpoint from {}".format(f))
         checkpoint = self._load_file(f)
         self._load_model(checkpoint)
+        actual_file = os.path.join(self.save_dir, f)
         if "optimizer" in checkpoint and self.optimizer:
-            logger.info("Loading optimizer from {}".format(f))
+            logger.info(f"Loading optimizer from {actual_file}")
             self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
         if "scheduler" in checkpoint and self.scheduler:
-            logger.info("Loading scheduler from {}".format(f))
+            logger.info(f"Loading scheduler from {actual_file}")
             self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
 
         # return any further checkpoint data
@@ -94,7 +95,9 @@ class Checkpointer:
             f.write(last_filename)
 
     def _load_file(self, f):
-        return torch.load(f, map_location=torch.device("cpu"))
+        actual_file = os.path.join(self.save_dir, f)
+        logger.info(f"Loading checkpoint from {actual_file}")
+        return torch.load(actual_file, map_location=torch.device("cpu"))
 
     def _load_model(self, checkpoint):
         load_state_dict(self.model, checkpoint.pop("model"))
@@ -108,10 +111,9 @@ class DetectronCheckpointer(Checkpointer):
         scheduler=None,
         save_dir="",
         save_to_disk=None,
-        logger=None,
     ):
         super(DetectronCheckpointer, self).__init__(
-            model, optimizer, scheduler, save_dir, save_to_disk, logger
+            model, optimizer, scheduler, save_dir, save_to_disk
         )
 
     def _load_file(self, f):
@@ -119,7 +121,7 @@ class DetectronCheckpointer(Checkpointer):
         if f.startswith("http"):
             # if the file is a url path, download it and cache it
             cached_f = cache_url(f)
-            logger.info("url {} cached in {}".format(f, cached_f))
+            logger.info(f"url {f} cached in {cached_f}")
             f = cached_f
         # load checkpoint
         loaded = super(DetectronCheckpointer, self)._load_file(f)
diff --git a/bob/ip/binseg/utils/model_serialization.py b/bob/ip/binseg/utils/model_serialization.py
index 016f085e..4c84e84f 100644
--- a/bob/ip/binseg/utils/model_serialization.py
+++ b/bob/ip/binseg/utils/model_serialization.py
@@ -1,7 +1,10 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 # https://github.com/facebookresearch/maskrcnn-benchmark
+
 from collections import OrderedDict
+
 import logging

 import torch
+logger = logging.getLogger(__name__)
 
@@ -39,14 +42,13 @@ def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
     max_size = max([len(key) for key in current_keys]) if current_keys else 1
     max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1
     log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
-    logger = logging.getLogger(__name__)
     for idx_new, idx_old in enumerate(idxs.tolist()):
         if idx_old == -1:
             continue
         key = current_keys[idx_new]
         key_old = loaded_keys[idx_old]
         model_state_dict[key] = loaded_state_dict[key_old]
-        logger.info(
+        logger.debug(
             log_str_template.format(
                 key,
                 max_size,
-- 
GitLab