From b81c10c52767578bf61bda61953e69d25cd4d3f7 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.anjos@idiap.ch>
Date: Sat, 4 Apr 2020 19:10:59 +0200
Subject: [PATCH] [checkpointer] Make it flexible to directory renames; Fix
 saving of trainlog in case we are interrupted

---
 bob/ip/binseg/engine/trainer.py            | 14 ++++++++++----
 bob/ip/binseg/script/train.py              |  6 +++---
 bob/ip/binseg/utils/checkpointer.py        | 22 ++++++++++++----------
 bob/ip/binseg/utils/model_serialization.py |  6 ++++--
 4 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 44ab76f1..8a4b2e50 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -75,12 +75,18 @@ def do_train(
 
     # Log to file
     logfile_name = os.path.join(output_folder, "trainlog.csv")
 
-    logfile_fields = ("epoch", "total-time", "eta", "average-loss",
-            "median-loss", "learning-rate", "memory-megabytes")
-    with open(logfile_name, "w", newline="") as logfile:
+    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
+        logger.info(f"Truncating {logfile_name} - training is restarting...")
+        os.unlink(logfile_name)
+
+    logfile_fields = ("epoch", "total-time", "eta", "average-loss",
+            "median-loss", "learning-rate", "gpu-memory-megabytes")
+    with open(logfile_name, "a+", newline="") as logfile:
         logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
-        logwriter.writeheader()
+
+        if arguments["epoch"] == 0:
+            logwriter.writeheader()
 
         model.train().to(device)
         for state in optimizer.state.values():
diff --git a/bob/ip/binseg/script/train.py b/bob/ip/binseg/script/train.py
index 5a2ce6f1..12a939e8 100644
--- a/bob/ip/binseg/script/train.py
+++ b/bob/ip/binseg/script/train.py
@@ -117,7 +117,7 @@ logger = logging.getLogger(__name__)
 @click.option(
     "--epochs",
     "-e",
-    help="Number of epochs used for training",
+    help="Number of epochs (complete training set passes) to train for",
     show_default=True,
     required=True,
     default=1000,
@@ -126,8 +126,8 @@ logger = logging.getLogger(__name__)
 @click.option(
     "--checkpoint-period",
     "-p",
-    help="Number of epochs after which a checkpoint is saved. "
-    "A value of zero will disable check-pointing. If checkpointing is "
+    help="Number of epochs after which a checkpoint is saved. "
+    "A value of zero will disable check-pointing. If checkpointing is "
     "enabled and training stops, it is automatically resumed from the "
     "last saved checkpoint if training is restarted with the same "
     "configuration.",
diff --git a/bob/ip/binseg/utils/checkpointer.py b/bob/ip/binseg/utils/checkpointer.py
index 8c0def2e..4ae57e5c 100644
--- a/bob/ip/binseg/utils/checkpointer.py
+++ b/bob/ip/binseg/utils/checkpointer.py
@@ -46,10 +46,11 @@ class Checkpointer:
             data["scheduler"] = self.scheduler.state_dict()
         data.update(kwargs)
 
-        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
-        logger.info("Saving checkpoint to {}".format(save_file))
+        dest_filename = f"{name}.pth"
+        save_file = os.path.join(self.save_dir, dest_filename)
+        logger.info(f"Saving checkpoint to {save_file}")
         torch.save(data, save_file)
-        self.tag_last_checkpoint(save_file)
+        self.tag_last_checkpoint(dest_filename)
 
     def load(self, f=None):
         if self.has_checkpoint():
@@ -59,14 +60,14 @@ class Checkpointer:
             # no checkpoint could be found
             logger.warn("No checkpoint found. Initializing model from scratch")
             return {}
-        logger.info("Loading checkpoint from {}".format(f))
         checkpoint = self._load_file(f)
         self._load_model(checkpoint)
+        actual_file = os.path.join(self.save_dir, f)
         if "optimizer" in checkpoint and self.optimizer:
-            logger.info("Loading optimizer from {}".format(f))
+            logger.info(f"Loading optimizer from {actual_file}")
             self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
         if "scheduler" in checkpoint and self.scheduler:
-            logger.info("Loading scheduler from {}".format(f))
+            logger.info(f"Loading scheduler from {actual_file}")
             self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
 
         # return any further checkpoint data
@@ -94,7 +95,9 @@ class Checkpointer:
             f.write(last_filename)
 
     def _load_file(self, f):
-        return torch.load(f, map_location=torch.device("cpu"))
+        actual_file = os.path.join(self.save_dir, f)
+        logger.info(f"Loading checkpoint from {actual_file}")
+        return torch.load(actual_file, map_location=torch.device("cpu"))
 
     def _load_model(self, checkpoint):
         load_state_dict(self.model, checkpoint.pop("model"))
@@ -108,10 +111,9 @@ class DetectronCheckpointer(Checkpointer):
         scheduler=None,
         save_dir="",
         save_to_disk=None,
-        logger=None,
     ):
         super(DetectronCheckpointer, self).__init__(
-            model, optimizer, scheduler, save_dir, save_to_disk, logger
+            model, optimizer, scheduler, save_dir, save_to_disk
         )
 
     def _load_file(self, f):
@@ -119,7 +121,7 @@ class DetectronCheckpointer(Checkpointer):
         if f.startswith("http"):
             # if the file is a url path, download it and cache it
             cached_f = cache_url(f)
-            logger.info("url {} cached in {}".format(f, cached_f))
+            logger.info(f"url {f} cached in {cached_f}")
             f = cached_f
         # load checkpoint
         loaded = super(DetectronCheckpointer, self)._load_file(f)
diff --git a/bob/ip/binseg/utils/model_serialization.py b/bob/ip/binseg/utils/model_serialization.py
index 016f085e..4c84e84f 100644
--- a/bob/ip/binseg/utils/model_serialization.py
+++ b/bob/ip/binseg/utils/model_serialization.py
@@ -1,7 +1,10 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 # https://github.com/facebookresearch/maskrcnn-benchmark
+
 from collections import OrderedDict
+
 import logging
+logger = logging.getLogger(__name__)
 
 import torch
 
@@ -39,14 +42,13 @@ def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
     max_size = max([len(key) for key in current_keys]) if current_keys else 1
     max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1
     log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
-    logger = logging.getLogger(__name__)
     for idx_new, idx_old in enumerate(idxs.tolist()):
         if idx_old == -1:
             continue
         key = current_keys[idx_new]
         key_old = loaded_keys[idx_old]
         model_state_dict[key] = loaded_state_dict[key_old]
-        logger.info(
+        logger.debug(
             log_str_template.format(
                 key,
                 max_size,
--
GitLab
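
Note (not part of the commit): the sketch below illustrates the directory-rename behaviour the checkpointer hunks introduce. save() now tags the last checkpoint with a relative filename and _load_file() re-joins it with the current save_dir, so resumption still works after the output folder is renamed. The tag-file name, directory names and helper functions here are hypothetical stand-ins, not the bob.ip.binseg API.

    import os

    def tag_last_checkpoint(save_dir, filename):
        # mimic the patched Checkpointer.save(): record only the basename
        with open(os.path.join(save_dir, "last_checkpoint"), "w") as f:
            f.write(filename)

    def last_checkpoint_path(save_dir):
        # mimic the patched Checkpointer._load_file(): re-join with the *current* save_dir
        with open(os.path.join(save_dir, "last_checkpoint")) as f:
            return os.path.join(save_dir, f.read().strip())

    os.makedirs("results-v1", exist_ok=True)
    open(os.path.join("results-v1", "model_final.pth"), "wb").close()  # fake checkpoint
    tag_last_checkpoint("results-v1", "model_final.pth")

    os.rename("results-v1", "results-v2")   # user renames the output directory
    print(last_checkpoint_path("results-v2"))  # -> results-v2/model_final.pth, still resolvable

Recording the basename instead of an absolute path is what lets the tag file relocate together with the rest of the output folder.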