From 85da2f49221a4ddebebfc28144abe17717eb1e72 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 11 Apr 2023 17:08:06 +0200
Subject: [PATCH] Removed custom checkpointer, added saving of previously missing files

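Checkpointing in the training loop now relies on PyTorch Lightning's
ModelCheckpoint callback instead of the custom Checkpointer class, which is
removed together with its checkpointer_process() helper.  The callback
monitors "validation_loss", writes model_lowest_valid_loss whenever a new
minimum is observed and, via save_last / CHECKPOINT_NAME_LAST, keeps the
final epoch as model_final_epoch; periodic saving is preserved through
every_n_epochs=checkpoint_period.  save_model_summary() additionally returns
the model's total parameter count, which run() uses to write the static
information file (constants.csv).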
---
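For reference, a minimal sketch (illustrative only, not part of the patch) of
how the checkpoints produced by this callback could be picked up again.  It
assumes Lightning appends its default ".ckpt" suffix to the configured
filenames and a Lightning release in which Trainer.fit() accepts a ckpt_path
argument; the output_folder value below is a placeholder.

    # Illustrative sketch: Lightning checkpoints are regular torch pickles,
    # so their contents can be inspected directly.
    import os

    import torch

    output_folder = "results"  # the same folder passed to run()
    best = os.path.join(output_folder, "model_lowest_valid_loss.ckpt")
    last = os.path.join(output_folder, "model_final_epoch.ckpt")

    checkpoint = torch.load(best, map_location="cpu")
    print(checkpoint["epoch"], sorted(checkpoint["state_dict"])[:3])

    # Resuming an interrupted run only requires handing the last checkpoint
    # back to Lightning, e.g. with the objects already available in run():
    #
    #     trainer.fit(model, data_loader, valid_loader, ckpt_path=last)
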
 src/ptbench/engine/trainer.py     | 91 +++++++---------------------
 src/ptbench/utils/checkpointer.py | 99 -------------------------------
 2 files changed, 21 insertions(+), 169 deletions(-)
 delete mode 100644 src/ptbench/utils/checkpointer.py

diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 41ade3f3..ddacf63c 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -129,9 +129,9 @@ def save_model_summary(output_folder, model):
     summary_path = os.path.join(output_folder, "model_summary.txt")
     logger.info(f"Saving model summary at {summary_path}...")
     with open(summary_path, "w") as f:
-        summary = str(ModelSummary(model, max_depth=-1))
-        f.write(summary)
-    return summary
+        summary = ModelSummary(model, max_depth=-1)
+        f.write(str(summary))
+    return summary, summary.total_parameters
 
 
 def static_information_to_csv(static_logfile_name, device, n):
@@ -374,62 +374,6 @@ def validate_epoch(loader, model, device, criterion, pbar_desc):
     return numpy.average(batch_losses, weights=samples_in_batch)
 
 
-def checkpointer_process(
-    checkpointer,
-    checkpoint_period,
-    valid_loss,
-    lowest_validation_loss,
-    arguments,
-    epoch,
-    max_epoch,
-):
-    """Process the checkpointer, save the final model and keep track of the
-    best model.
-
-    Parameters
-    ----------
-
-    checkpointer : :py:class:`ptbench.utils.checkpointer.Checkpointer`
-        checkpointer implementation
-
-    checkpoint_period : int
-        save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
-        not save intermediary checkpoints
-
-    valid_loss : float
-        Current epoch validation loss
-
-    lowest_validation_loss : float
-        Keeps track of the best (lowest) validation loss
-
-    arguments : dict
-        start and end epochs
-
-    max_epoch : int
-        end_potch
-
-    Returns
-    -------
-
-    lowest_validation_loss : float
-        The lowest validation loss currently observed
-    """
-    if checkpoint_period and (epoch % checkpoint_period == 0):
-        checkpointer.save("model_periodic_save", **arguments)
-
-    if valid_loss is not None and valid_loss < lowest_validation_loss:
-        lowest_validation_loss = valid_loss
-        logger.info(
-            f"Found new low on validation set:" f" {lowest_validation_loss:.6f}"
-        )
-        checkpointer.save("model_lowest_valid_loss", **arguments)
-
-    if epoch >= max_epoch:
-        checkpointer.save("model_final_epoch", **arguments)
-
-    return lowest_validation_loss
-
-
 def write_log_info(
     epoch,
     current_time,
@@ -578,7 +522,7 @@ def run(
     os.makedirs(output_folder, exist_ok=True)
 
     # Save model summary
-    _ = save_model_summary(output_folder, model)
+    r, n = save_model_summary(output_folder, model)
 
     csv_logger = CSVLogger(output_folder, "logs_csv")
     tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
@@ -590,6 +534,22 @@ def run(
         logging_level=logging.ERROR,
     )
 
+    checkpoint_callback = ModelCheckpoint(
+        output_folder,
+        "model_lowest_valid_loss",
+        save_last=True,
+        monitor="validation_loss",
+        mode="min",
+        save_on_train_epoch_end=False,
+        every_n_epochs=checkpoint_period,
+    )
+
+    checkpoint_callback.CHECKPOINT_NAME_LAST = "model_final_epoch"
+
+    # write static information to a CSV file
+    static_logfile_name = os.path.join(output_folder, "constants.csv")
+    static_information_to_csv(static_logfile_name, device, n)
+
     with resource_monitor:
         trainer = Trainer(
             accelerator="auto",
@@ -597,16 +557,7 @@ def run(
             max_epochs=max_epoch,
             logger=[csv_logger, tensorboard_logger],
             check_val_every_n_epoch=1,
-            callbacks=[
-                LoggingCallback(resource_monitor),
-                ModelCheckpoint(
-                    output_folder,
-                    monitor="validation_loss",
-                    mode="min",
-                    save_on_train_epoch_end=False,
-                    every_n_epochs=checkpoint_period,
-                ),
-            ],
+            callbacks=[LoggingCallback(resource_monitor), checkpoint_callback],
         )
 
         _ = trainer.fit(model, data_loader, valid_loader)
diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py
deleted file mode 100644
index 3e839b0e..00000000
--- a/src/ptbench/utils/checkpointer.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-import logging
-import os
-
-import torch
-
-logger = logging.getLogger(__name__)
-
-
-class Checkpointer:
-    """A simple pytorch checkpointer.
-
-    Parameters
-    ----------
-
-    model : torch.nn.Module
-        Network model, eventually loaded from a checkpointed file
-
-    optimizer : :py:mod:`torch.optim`, Optional
-        Optimizer
-
-    scheduler : :py:mod:`torch.optim`, Optional
-        Learning rate scheduler
-
-    path : :py:class:`str`, Optional
-        Directory where to save checkpoints.
-    """
-
-    def __init__(self, model, optimizer=None, scheduler=None, path="."):
-        self.model = model
-        self.optimizer = optimizer
-        self.scheduler = scheduler
-        self.path = os.path.realpath(path)
-
-    def save(self, name, **kwargs):
-        data = {}
-        data["model"] = self.model.state_dict()
-        if self.optimizer is not None:
-            data["optimizer"] = self.optimizer.state_dict()
-        if self.scheduler is not None:
-            data["scheduler"] = self.scheduler.state_dict()
-        data.update(kwargs)
-
-        name = f"{name}.pth"
-        outf = os.path.join(self.path, name)
-        logger.info(f"Saving checkpoint to {outf}")
-        torch.save(data, outf)
-        with open(self._last_checkpoint_filename, "w") as f:
-            f.write(name)
-
-    def load(self, f=None):
-        """Loads model, optimizer and scheduler from file.
-
-        Parameters
-        ==========
-
-        f : :py:class:`str`, Optional
-            Name of a file (absolute or relative to ``self.path``), that
-            contains the checkpoint data to load into the model, and optionally
-            into the optimizer and the scheduler.  If not specified, loads data
-            from current path.
-        """
-        if f is None:
-            f = self.last_checkpoint()
-
-        if f is None:
-            # no checkpoint could be found
-            logger.warning("No checkpoint found (and none passed)")
-            return {}
-
-        # loads file data into memory
-        logger.info(f"Loading checkpoint from {f}...")
-        checkpoint = torch.load(f, map_location=torch.device("cpu"))
-
-        # converts model entry to model parameters
-        self.model.load_state_dict(checkpoint.pop("model"))
-
-        if self.optimizer is not None:
-            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
-        if self.scheduler is not None:
-            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
-
-        return checkpoint
-
-    @property
-    def _last_checkpoint_filename(self):
-        return os.path.join(self.path, "last_checkpoint")
-
-    def has_checkpoint(self):
-        return os.path.exists(self._last_checkpoint_filename)
-
-    def last_checkpoint(self):
-        if self.has_checkpoint():
-            with open(self._last_checkpoint_filename) as fobj:
-                return os.path.join(self.path, fobj.read().strip())
-        return None
-- 
GitLab