Commit b81c10c5 authored by André Anjos


[checkpointer] Make it flexible to directory renames; Fix saving of trainlog in case we are interrupted
parent c6f07e8d
1 merge request: !12 Streamlining
Pipeline #38719 failed
@@ -75,12 +75,18 @@ def do_train(
     # Log to file
     logfile_name = os.path.join(output_folder, "trainlog.csv")
-    logfile_fields = ("epoch", "total-time", "eta", "average-loss",
-            "median-loss", "learning-rate", "memory-megabytes")
-    with open(logfile_name, "w", newline="") as logfile:
+    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
+        logger.info(f"Truncating {logfile_name} - training is restarting...")
+        os.unlink(logfile_name)
+    logfile_fields = ("epoch", "total-time", "eta", "average-loss",
+            "median-loss", "learning-rate", "gpu-memory-megabytes")
+    with open(logfile_name, "a+", newline="") as logfile:
         logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
-        logwriter.writeheader()
+        if arguments["epoch"] == 0:
+            logwriter.writeheader()
         model.train().to(device)
         for state in optimizer.state.values():
......
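As context for the hunk above, here is a minimal, self-contained sketch of the resume-aware CSV logging pattern it introduces (the helper name and signature are assumptions for illustration, not code from this commit): the log file is removed only when training starts from epoch 0; otherwise it is opened in append mode and the header row is written just once.

import csv
import os

def open_trainlog(logfile_name, fieldnames, start_epoch):
    # Hypothetical helper illustrating the pattern above; not part of the diff.
    # A run that starts from scratch (epoch 0) discards any stale log file.
    if start_epoch == 0 and os.path.exists(logfile_name):
        os.unlink(logfile_name)
    # "a+" keeps the rows logged before an interruption when training resumes.
    logfile = open(logfile_name, "a+", newline="")
    logwriter = csv.DictWriter(logfile, fieldnames=fieldnames)
    # The header is only needed once, at the very start of training.
    if start_epoch == 0:
        logwriter.writeheader()
    return logfile, logwriter

With append mode, a restarted run no longer overwrites the rows that were already written before the interruption.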
@@ -117,7 +117,7 @@ logger = logging.getLogger(__name__)
 @click.option(
     "--epochs",
     "-e",
-    help="Number of epochs used for training",
+    help="Number of epochs (complete training set passes) to train for",
     show_default=True,
     required=True,
     default=1000,
@@ -126,8 +126,8 @@ logger = logging.getLogger(__name__)
 @click.option(
     "--checkpoint-period",
     "-p",
-    help="Number of epochs after which a checkpoint is saved. "
-    "A value of zero will disable check-pointing. If checkpointing is "
+    help="Number of epochs after which a checkpoint is saved. "
+    "A value of zero will disable check-pointing. If checkpointing is "
     "enabled and training stops, it is automatically resumed from the "
     "last saved checkpoint if training is restarted with the same "
     "configuration.",
......
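The --checkpoint-period semantics described in that help text could be consumed roughly as follows (a hypothetical sketch; should_checkpoint is not a function from this code base): a value of zero disables periodic checkpointing, otherwise a checkpoint is written every N epochs.

def should_checkpoint(epoch, checkpoint_period):
    # Hypothetical helper, not from the diff: decide whether to save after this epoch.
    # A period of zero disables periodic check-pointing entirely.
    return checkpoint_period > 0 and epoch % checkpoint_period == 0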
@@ -46,10 +46,11 @@ class Checkpointer:
             data["scheduler"] = self.scheduler.state_dict()
         data.update(kwargs)
-        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
-        logger.info("Saving checkpoint to {}".format(save_file))
+        dest_filename = f"{name}.pth"
+        save_file = os.path.join(self.save_dir, dest_filename)
+        logger.info(f"Saving checkpoint to {save_file}")
         torch.save(data, save_file)
-        self.tag_last_checkpoint(save_file)
+        self.tag_last_checkpoint(dest_filename)

     def load(self, f=None):
         if self.has_checkpoint():
@@ -59,14 +60,14 @@ class Checkpointer:
             # no checkpoint could be found
            logger.warn("No checkpoint found. Initializing model from scratch")
            return {}
-        logger.info("Loading checkpoint from {}".format(f))
         checkpoint = self._load_file(f)
         self._load_model(checkpoint)
+        actual_file = os.path.join(self.save_dir, f)
         if "optimizer" in checkpoint and self.optimizer:
-            logger.info("Loading optimizer from {}".format(f))
+            logger.info(f"Loading optimizer from {actual_file}")
             self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
         if "scheduler" in checkpoint and self.scheduler:
-            logger.info("Loading scheduler from {}".format(f))
+            logger.info(f"Loading scheduler from {actual_file}")
             self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

         # return any further checkpoint data
@@ -94,7 +95,9 @@ class Checkpointer:
             f.write(last_filename)

     def _load_file(self, f):
-        return torch.load(f, map_location=torch.device("cpu"))
+        actual_file = os.path.join(self.save_dir, f)
+        logger.info(f"Loading checkpoint from {actual_file}")
+        return torch.load(actual_file, map_location=torch.device("cpu"))

     def _load_model(self, checkpoint):
         load_state_dict(self.model, checkpoint.pop("model"))
@@ -108,10 +111,9 @@ class DetectronCheckpointer(Checkpointer):
         scheduler=None,
         save_dir="",
         save_to_disk=None,
-        logger=None,
     ):
         super(DetectronCheckpointer, self).__init__(
-            model, optimizer, scheduler, save_dir, save_to_disk, logger
+            model, optimizer, scheduler, save_dir, save_to_disk
         )

     def _load_file(self, f):
@@ -119,7 +121,7 @@ class DetectronCheckpointer(Checkpointer):
         if f.startswith("http"):
             # if the file is a url path, download it and cache it
             cached_f = cache_url(f)
-            logger.info("url {} cached in {}".format(f, cached_f))
+            logger.info(f"url {f} cached in {cached_f}")
             f = cached_f
         # load checkpoint
         loaded = super(DetectronCheckpointer, self)._load_file(f)
......
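The checkpointer changes above boil down to one idea: the "last_checkpoint" tag file stores only the checkpoint's basename, and the absolute path is rebuilt from the current save_dir at load time, so renaming or moving the output directory no longer breaks resumption. A minimal sketch of that scheme follows (class and method names here are illustrative assumptions, not the project's API):

import os
import torch

class MinimalCheckpointer:
    # Illustrative sketch of the relative-path scheme above; not the project's class.
    def __init__(self, model, save_dir):
        self.model = model
        self.save_dir = save_dir

    def save(self, name):
        # Only the basename goes into the tag file, never an absolute path.
        dest_filename = f"{name}.pth"
        torch.save({"model": self.model.state_dict()},
                   os.path.join(self.save_dir, dest_filename))
        with open(os.path.join(self.save_dir, "last_checkpoint"), "w") as f:
            f.write(dest_filename)

    def load_last(self):
        tag = os.path.join(self.save_dir, "last_checkpoint")
        if not os.path.exists(tag):
            return {}  # nothing to resume from
        with open(tag) as f:
            dest_filename = f.read().strip()
        # The absolute location is rebuilt from the *current* save_dir, so a
        # renamed or relocated output directory still resolves correctly.
        actual_file = os.path.join(self.save_dir, dest_filename)
        checkpoint = torch.load(actual_file, map_location=torch.device("cpu"))
        self.model.load_state_dict(checkpoint.pop("model"))
        return checkpoint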
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 # https://github.com/facebookresearch/maskrcnn-benchmark
 from collections import OrderedDict
 import logging
+logger = logging.getLogger(__name__)
 import torch
@@ -39,14 +42,13 @@ def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
     max_size = max([len(key) for key in current_keys]) if current_keys else 1
     max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1
     log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
-    logger = logging.getLogger(__name__)
     for idx_new, idx_old in enumerate(idxs.tolist()):
         if idx_old == -1:
             continue
         key = current_keys[idx_new]
         key_old = loaded_keys[idx_old]
         model_state_dict[key] = loaded_state_dict[key_old]
-        logger.info(
+        logger.debug(
             log_str_template.format(
                 key,
                 max_size,
......
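The last hunk replaces a logger created inside the function with the module-level one and demotes the per-parameter message from INFO to DEBUG. A minimal sketch of that pattern (function and variable names are illustrative only):

import logging

# One module-level logger, created once at import time, rather than inside the loop.
logger = logging.getLogger(__name__)

def report_matches(pairs):
    # Per-parameter detail is verbose, so it is emitted at DEBUG rather than INFO.
    for new_key, old_key in pairs:
        logger.debug("%s loaded from %s", new_key, old_key)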