Commit 85da2f49 authored by Daniel CARRON

Removed custom checkpointer, saving missing files

parent d20d2311
1 merge request: !4 Moved code to lightning
@@ -129,9 +129,9 @@ def save_model_summary(output_folder, model):
     summary_path = os.path.join(output_folder, "model_summary.txt")
     logger.info(f"Saving model summary at {summary_path}...")
     with open(summary_path, "w") as f:
-        summary = str(ModelSummary(model, max_depth=-1))
-        f.write(summary)
-    return summary
+        summary = ModelSummary(model, max_depth=-1)
+        f.write(str(summary))
+    return summary, ModelSummary(model).total_parameters


 def static_information_to_csv(static_logfile_name, device, n):
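The rewritten helper leans on pytorch-lightning's ModelSummary object, which renders a per-layer table via str() and also exposes parameter counts; that is what the new second return value (the n consumed below by static_information_to_csv) carries. A minimal sketch of that behaviour, using a stand-in LightningModule rather than the project's real model:

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.model_summary import ModelSummary

class TinyModel(LightningModule):
    # stand-in for the project's model, just to exercise ModelSummary
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(16, 2)

summary = ModelSummary(TinyModel(), max_depth=-1)  # -1 expands all submodules
print(str(summary))              # the layer table written to model_summary.txt
print(summary.total_parameters)  # the integer returned as ``n`` above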
@@ -374,62 +374,6 @@ def validate_epoch(loader, model, device, criterion, pbar_desc):
     return numpy.average(batch_losses, weights=samples_in_batch)


-def checkpointer_process(
-    checkpointer,
-    checkpoint_period,
-    valid_loss,
-    lowest_validation_loss,
-    arguments,
-    epoch,
-    max_epoch,
-):
-    """Process the checkpointer, save the final model and keep track of the
-    best model.
-
-    Parameters
-    ----------
-
-    checkpointer : :py:class:`ptbench.utils.checkpointer.Checkpointer`
-        checkpointer implementation
-
-    checkpoint_period : int
-        save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
-        not save intermediary checkpoints
-
-    valid_loss : float
-        Current epoch validation loss
-
-    lowest_validation_loss : float
-        Keeps track of the best (lowest) validation loss
-
-    arguments : dict
-        start and end epochs
-
-    max_epoch : int
-        End epoch
-
-    Returns
-    -------
-
-    lowest_validation_loss : float
-        The lowest validation loss currently observed
-    """
-    if checkpoint_period and (epoch % checkpoint_period == 0):
-        checkpointer.save("model_periodic_save", **arguments)
-
-    if valid_loss is not None and valid_loss < lowest_validation_loss:
-        lowest_validation_loss = valid_loss
-        logger.info(
-            f"Found new low on validation set: {lowest_validation_loss:.6f}"
-        )
-        checkpointer.save("model_lowest_valid_loss", **arguments)
-
-    if epoch >= max_epoch:
-        checkpointer.save("model_final_epoch", **arguments)
-
-    return lowest_validation_loss
-
-
 def write_log_info(
     epoch,
     current_time,
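Everything this helper did (periodic saves, best-model tracking, final save) is taken over by the ModelCheckpoint callback configured further down. For reference, its old call site would have looked roughly like the following; this is a hypothetical reconstruction from the docstring, not code taken from the repository:

# Hypothetical reconstruction of the removed per-epoch call (names assumed).
lowest_validation_loss = float("inf")
for epoch in range(start_epoch, max_epoch + 1):
    valid_loss = validate_epoch(valid_loader, model, device, criterion, "valid")
    lowest_validation_loss = checkpointer_process(
        checkpointer,
        checkpoint_period=1,  # save "model_periodic_save" every epoch
        valid_loss=valid_loss,
        lowest_validation_loss=lowest_validation_loss,
        arguments={"epoch": epoch, "max_epoch": max_epoch},
        epoch=epoch,
        max_epoch=max_epoch,
    )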
@@ -578,7 +522,7 @@ def run(
     os.makedirs(output_folder, exist_ok=True)

     # Save model summary
-    _ = save_model_summary(output_folder, model)
+    r, n = save_model_summary(output_folder, model)

     csv_logger = CSVLogger(output_folder, "logs_csv")
     tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
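Both loggers ship with pytorch-lightning and take a save_dir plus a name, so all metrics end up side by side under the output folder. A short sketch of the resulting layout (paths assumed from CSVLogger's defaults):

from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

output_folder = "results"  # hypothetical output directory
csv_logger = CSVLogger(output_folder, "logs_csv")
tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
# scalar metrics land in results/logs_csv/version_0/metrics.csv;
# TensorBoard event files land under results/logs_tensorboard/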
@@ -590,6 +534,22 @@ def run(
         logging_level=logging.ERROR,
     )

+    checkpoint_callback = ModelCheckpoint(
+        output_folder,
+        "model_lowest_valid_loss",
+        save_last=True,
+        monitor="validation_loss",
+        mode="min",
+        save_on_train_epoch_end=False,
+        every_n_epochs=checkpoint_period,
+    )
+
+    checkpoint_callback.CHECKPOINT_NAME_LAST = "model_final_epoch"
+
+    # write static information to a CSV file
+    static_logfile_name = os.path.join(output_folder, "constants.csv")
+    static_information_to_csv(static_logfile_name, device, n)
+
     with resource_monitor:
         trainer = Trainer(
             accelerator="auto",
@@ -597,16 +557,7 @@ def run(
             max_epochs=max_epoch,
             logger=[csv_logger, tensorboard_logger],
             check_val_every_n_epoch=1,
-            callbacks=[
-                LoggingCallback(resource_monitor),
-                ModelCheckpoint(
-                    output_folder,
-                    monitor="validation_loss",
-                    mode="min",
-                    save_on_train_epoch_end=False,
-                    every_n_epochs=checkpoint_period,
-                ),
-            ],
+            callbacks=[LoggingCallback(resource_monitor), checkpoint_callback],
         )

         _ = trainer.fit(model, data_loader, valid_loader)
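Because the callback sets save_last=True and overrides CHECKPOINT_NAME_LAST, the newest checkpoint is always available as model_final_epoch.ckpt next to model_lowest_valid_loss.ckpt, which makes resuming straightforward. A sketch, with the file name assumed from the configuration above:

import os

# Resume from the last saved epoch; ``ckpt_path`` is pytorch-lightning's
# standard way to restore model, optimizer and scheduler state on fit().
last_ckpt = os.path.join(output_folder, "model_final_epoch.ckpt")
if os.path.exists(last_ckpt):
    trainer.fit(model, data_loader, valid_loader, ckpt_path=last_ckpt)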
deleted file: the custom checkpointer module (ptbench.utils.checkpointer), shown in full below
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import logging
import os

import torch

logger = logging.getLogger(__name__)


class Checkpointer:
    """A simple pytorch checkpointer.

    Parameters
    ----------

    model : torch.nn.Module
        Network model, eventually loaded from a checkpointed file

    optimizer : :py:mod:`torch.optim`, Optional
        Optimizer

    scheduler : :py:mod:`torch.optim`, Optional
        Learning rate scheduler

    path : :py:class:`str`, Optional
        Directory where to save checkpoints.
    """

    def __init__(self, model, optimizer=None, scheduler=None, path="."):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.path = os.path.realpath(path)

    def save(self, name, **kwargs):
        # collect the model state and, when present, the optimizer and
        # scheduler states, together with any extra keyword arguments
        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        name = f"{name}.pth"
        outf = os.path.join(self.path, name)
        logger.info(f"Saving checkpoint to {outf}")
        torch.save(data, outf)

        # remember the name of the most recent checkpoint
        with open(self._last_checkpoint_filename, "w") as f:
            f.write(name)

    def load(self, f=None):
        """Loads model, optimizer and scheduler from file.

        Parameters
        ----------

        f : :py:class:`str`, Optional
            Name of a file (absolute or relative to ``self.path``), that
            contains the checkpoint data to load into the model, and
            optionally into the optimizer and the scheduler.  If not
            specified, loads data from current path.
        """
        if f is None:
            f = self.last_checkpoint()

        if f is None:
            # no checkpoint could be found
            logger.warning("No checkpoint found (and none passed)")
            return {}

        # loads file data into memory
        logger.info(f"Loading checkpoint from {f}...")
        checkpoint = torch.load(f, map_location=torch.device("cpu"))

        # converts model entry to model parameters
        self.model.load_state_dict(checkpoint.pop("model"))

        if self.optimizer is not None:
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        return checkpoint

    @property
    def _last_checkpoint_filename(self):
        return os.path.join(self.path, "last_checkpoint")

    def has_checkpoint(self):
        return os.path.exists(self._last_checkpoint_filename)

    def last_checkpoint(self):
        if self.has_checkpoint():
            with open(self._last_checkpoint_filename) as fobj:
                return os.path.join(self.path, fobj.read().strip())
        return None
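For comparison with the Lightning replacement, a save/load round trip through this class needs only the methods above; the model, optimizer and path below are stand-ins:

import os
import torch

net = torch.nn.Linear(4, 2)                      # stand-in model
opt = torch.optim.SGD(net.parameters(), lr=0.1)  # stand-in optimizer

os.makedirs("/tmp/ckpts", exist_ok=True)
ckpt = Checkpointer(net, optimizer=opt, path="/tmp/ckpts")
ckpt.save("model_final_epoch", epoch=10)  # extra kwargs are stored verbatim

# Later: restores weights in-place and returns the leftover entries.
extras = ckpt.load()        # picks up /tmp/ckpts/model_final_epoch.pth
print(extras.get("epoch"))  # -> 10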