From 85da2f49221a4ddebebfc28144abe17717eb1e72 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 11 Apr 2023 17:08:06 +0200
Subject: [PATCH] Removed custom checkpointer, saving missing files

---
 src/ptbench/engine/trainer.py     | 91 +++++++---------------------
 src/ptbench/utils/checkpointer.py | 99 -------------------------------
 2 files changed, 21 insertions(+), 169 deletions(-)
 delete mode 100644 src/ptbench/utils/checkpointer.py

diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 41ade3f3..ddacf63c 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -129,9 +129,9 @@ def save_model_summary(output_folder, model):
     summary_path = os.path.join(output_folder, "model_summary.txt")
     logger.info(f"Saving model summary at {summary_path}...")
     with open(summary_path, "w") as f:
-        summary = str(ModelSummary(model, max_depth=-1))
-        f.write(summary)
-    return summary
+        summary = ModelSummary(model, max_depth=-1)
+        f.write(str(summary))
+    return summary, ModelSummary(model).total_parameters
 
 
 def static_information_to_csv(static_logfile_name, device, n):
@@ -374,62 +374,6 @@ def validate_epoch(loader, model, device, criterion, pbar_desc):
     return numpy.average(batch_losses, weights=samples_in_batch)
 
 
-def checkpointer_process(
-    checkpointer,
-    checkpoint_period,
-    valid_loss,
-    lowest_validation_loss,
-    arguments,
-    epoch,
-    max_epoch,
-):
-    """Process the checkpointer, save the final model and keep track of the
-    best model.
-
-    Parameters
-    ----------
-
-    checkpointer : :py:class:`ptbench.utils.checkpointer.Checkpointer`
-        checkpointer implementation
-
-    checkpoint_period : int
-        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
-        not save intermediary checkpoints
-
-    valid_loss : float
-        Current epoch validation loss
-
-    lowest_validation_loss : float
-        Keeps track of the best (lowest) validation loss
-
-    arguments : dict
-        start and end epochs
-
-    max_epoch : int
-        end_potch
-
-    Returns
-    -------
-
-    lowest_validation_loss : float
-        The lowest validation loss currently observed
-    """
-    if checkpoint_period and (epoch % checkpoint_period == 0):
-        checkpointer.save("model_periodic_save", **arguments)
-
-    if valid_loss is not None and valid_loss < lowest_validation_loss:
-        lowest_validation_loss = valid_loss
-        logger.info(
-            f"Found new low on validation set:" f" {lowest_validation_loss:.6f}"
-        )
-        checkpointer.save("model_lowest_valid_loss", **arguments)
-
-    if epoch >= max_epoch:
-        checkpointer.save("model_final_epoch", **arguments)
-
-    return lowest_validation_loss
-
-
 def write_log_info(
     epoch,
     current_time,
@@ -578,7 +522,7 @@ def run(
     os.makedirs(output_folder, exist_ok=True)
 
     # Save model summary
-    _ = save_model_summary(output_folder, model)
+    r, n = save_model_summary(output_folder, model)
 
     csv_logger = CSVLogger(output_folder, "logs_csv")
     tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
@@ -590,6 +534,22 @@ def run(
         logging_level=logging.ERROR,
     )
 
+    checkpoint_callback = ModelCheckpoint(
+        output_folder,
+        "model_lowest_valid_loss",
+        save_last=True,
+        monitor="validation_loss",
+        mode="min",
+        save_on_train_epoch_end=False,
+        every_n_epochs=checkpoint_period,
+    )
+
+    checkpoint_callback.CHECKPOINT_NAME_LAST = "model_final_epoch"
+
+    # write static information to a CSV file
+    static_logfile_name = os.path.join(output_folder, "constants.csv")
+    static_information_to_csv(static_logfile_name, device, n)
+
     with resource_monitor:
         trainer = Trainer(
             accelerator="auto",
@@ -597,16 +557,7 @@ def run(
             max_epochs=max_epoch,
             logger=[csv_logger, tensorboard_logger],
             check_val_every_n_epoch=1,
-            callbacks=[
-                LoggingCallback(resource_monitor),
-                ModelCheckpoint(
-                    output_folder,
-                    monitor="validation_loss",
-                    mode="min",
-                    save_on_train_epoch_end=False,
-                    every_n_epochs=checkpoint_period,
-                ),
-            ],
+            callbacks=[LoggingCallback(resource_monitor), checkpoint_callback],
         )
 
         _ = trainer.fit(model, data_loader, valid_loader)
diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py
deleted file mode 100644
index 3e839b0e..00000000
--- a/src/ptbench/utils/checkpointer.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-import logging
-import os
-
-import torch
-
-logger = logging.getLogger(__name__)
-
-
-class Checkpointer:
-    """A simple pytorch checkpointer.
-
-    Parameters
-    ----------
-
-    model : torch.nn.Module
-        Network model, eventually loaded from a checkpointed file
-
-    optimizer : :py:mod:`torch.optim`, Optional
-        Optimizer
-
-    scheduler : :py:mod:`torch.optim`, Optional
-        Learning rate scheduler
-
-    path : :py:class:`str`, Optional
-        Directory where to save checkpoints.
-    """
-
-    def __init__(self, model, optimizer=None, scheduler=None, path="."):
-        self.model = model
-        self.optimizer = optimizer
-        self.scheduler = scheduler
-        self.path = os.path.realpath(path)
-
-    def save(self, name, **kwargs):
-        data = {}
-        data["model"] = self.model.state_dict()
-        if self.optimizer is not None:
-            data["optimizer"] = self.optimizer.state_dict()
-        if self.scheduler is not None:
-            data["scheduler"] = self.scheduler.state_dict()
-        data.update(kwargs)
-
-        name = f"{name}.pth"
-        outf = os.path.join(self.path, name)
-        logger.info(f"Saving checkpoint to {outf}")
-        torch.save(data, outf)
-        with open(self._last_checkpoint_filename, "w") as f:
-            f.write(name)
-
-    def load(self, f=None):
-        """Loads model, optimizer and scheduler from file.
-
-        Parameters
-        ==========
-
-        f : :py:class:`str`, Optional
-            Name of a file (absolute or relative to ``self.path``), that
-            contains the checkpoint data to load into the model, and optionally
-            into the optimizer and the scheduler. If not specified, loads data
-            from current path.
-        """
-        if f is None:
-            f = self.last_checkpoint()
-
-        if f is None:
-            # no checkpoint could be found
-            logger.warning("No checkpoint found (and none passed)")
-            return {}
-
-        # loads file data into memory
-        logger.info(f"Loading checkpoint from {f}...")
-        checkpoint = torch.load(f, map_location=torch.device("cpu"))
-
-        # converts model entry to model parameters
-        self.model.load_state_dict(checkpoint.pop("model"))
-
-        if self.optimizer is not None:
-            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
-        if self.scheduler is not None:
-            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
-
-        return checkpoint
-
-    @property
-    def _last_checkpoint_filename(self):
-        return os.path.join(self.path, "last_checkpoint")
-
-    def has_checkpoint(self):
-        return os.path.exists(self._last_checkpoint_filename)
-
-    def last_checkpoint(self):
-        if self.has_checkpoint():
-            with open(self._last_checkpoint_filename) as fobj:
-                return os.path.join(self.path, fobj.read().strip())
-        return None
-- 
GitLab
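
For anyone who wants to exercise the new checkpointing scheme outside of ptbench, the snippet below is a minimal, self-contained sketch of the same ModelCheckpoint configuration this patch introduces. It is not ptbench code: TinyModel, the synthetic data, and the example_output folder are placeholders, and every_n_epochs=1 stands in for ptbench's checkpoint_period; only the standard pytorch_lightning API (ModelCheckpoint, Trainer, LightningModule.load_from_checkpoint) is assumed.

# Sketch only: reproduces the checkpointing setup adopted by this patch
# (best model by lowest validation loss + a "last" checkpoint renamed to
# model_final_epoch), using placeholder model/data, not ptbench code.
import os

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModel(LightningModule):
    """Stand-in for a ptbench model; any LightningModule that logs
    "validation_loss" works with the callback below."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(16, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        # the logged key must match ModelCheckpoint(monitor=...)
        self.log("validation_loss", loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def make_loader():
    # random regression data, just to make the example runnable
    x = torch.randn(64, 16)
    y = torch.randn(64, 1)
    return DataLoader(TensorDataset(x, y), batch_size=8)


output_folder = "example_output"  # placeholder path
os.makedirs(output_folder, exist_ok=True)

# Same configuration as in the patched run(): keep the checkpoint with the
# lowest validation loss, and also save the state at the end of training.
checkpoint_callback = ModelCheckpoint(
    output_folder,
    "model_lowest_valid_loss",
    save_last=True,
    monitor="validation_loss",
    mode="min",
    save_on_train_epoch_end=False,
    every_n_epochs=1,  # ptbench passes checkpoint_period here
)
checkpoint_callback.CHECKPOINT_NAME_LAST = "model_final_epoch"

trainer = Trainer(
    accelerator="auto",
    max_epochs=3,
    callbacks=[checkpoint_callback],
    logger=False,
    enable_progress_bar=False,
)
trainer.fit(TinyModel(), make_loader(), make_loader())

# Lightning appends ".ckpt" to the configured names.
print(checkpoint_callback.best_model_path)  # .../model_lowest_valid_loss.ckpt
print(checkpoint_callback.last_model_path)  # .../model_final_epoch.ckpt

# Reloading the best weights, which replaces the removed Checkpointer.load():
best = TinyModel.load_from_checkpoint(checkpoint_callback.best_model_path)
best.eval()

Keeping CHECKPOINT_NAME_LAST as "model_final_epoch" preserves the base file names the removed Checkpointer wrote ("model_final_epoch", "model_lowest_valid_loss"); anything downstream that looks for those names should only need to expect the ".ckpt" suffix Lightning appends instead of the old ".pth".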