diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 6003eaffbdd14960e0b4298e366a991c4dd33324..f5bd3a7afa0b19e8ad8650cb325c7fc5ba79a166 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -2,14 +2,14 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-import os
-
 import click
 
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
 from pytorch_lightning import seed_everything
 
+from ..utils.checkpointer import get_checkpoint
+
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
@@ -432,41 +432,17 @@ def train(
 
     arguments = {}
     arguments["max_epoch"] = epochs
+    arguments["epoch"] = 0
 
-    last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
-    best_checkpoint_path = os.path.join(
-        output_folder, "model_lowest_valid_loss.ckpt"
-    )
-
-    if resume_from == "last":
-        if os.path.isfile(last_checkpoint_path):
-            checkpoint = last_checkpoint_path
-            logger.info(f"Resuming training from {resume_from} checkpoint")
-        else:
-            raise FileNotFoundError(
-                f"Could not find checkpoint {last_checkpoint_path}"
-            )
-
-    elif resume_from == "best":
-        if os.path.isfile(best_checkpoint_path):
-            checkpoint = last_checkpoint_path
-            logger.info(f"Resuming training from {resume_from} checkpoint")
-        else:
-            raise FileNotFoundError(
-                f"Could not find checkpoint {best_checkpoint_path}"
-            )
-
-    elif resume_from is None:
-        checkpoint = None
+    checkpoint_file = get_checkpoint(output_folder, resume_from)
 
-    else:
-        if os.path.isfile(resume_from):
-            checkpoint = resume_from
-            logger.info(f"Resuming training from checkpoint {resume_from}")
-        else:
-            raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
+    # We only load the checkpoint to get some information about its state. The actual loading of the model is done in trainer.fit()
+    if checkpoint_file is not None:
+        checkpoint = torch.load(checkpoint_file)
+        arguments["epoch"] = checkpoint["epoch"]
 
     logger.info("Training for {} epochs".format(arguments["max_epoch"]))
+    logger.info("Continuing from epoch {}".format(arguments["epoch"]))
 
     run(
         model=model,
@@ -479,5 +455,5 @@ def train(
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
         batch_chunk_count=batch_chunk_count,
-        checkpoint=checkpoint,
+        checkpoint=checkpoint_file,
     )
diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a30cba9444504a2abb350f4f7d9ca513bde9dde3
--- /dev/null
+++ b/src/ptbench/utils/checkpointer.py
@@ -0,0 +1,41 @@
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+def get_checkpoint(output_folder, resume_from):
+    last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
+    best_checkpoint_path = os.path.join(
+        output_folder, "model_lowest_valid_loss.ckpt"
+    )
+
+    if resume_from == "last":
+        if os.path.isfile(last_checkpoint_path):
+            checkpoint_file = last_checkpoint_path
+            logger.info(f"Resuming training from {resume_from} checkpoint")
+        else:
+            raise FileNotFoundError(
+                f"Could not find checkpoint {last_checkpoint_path}"
+            )
+
+    elif resume_from == "best":
+        if os.path.isfile(best_checkpoint_path):
+            checkpoint_file = best_checkpoint_path
+            logger.info(f"Resuming training from {resume_from} checkpoint")
+        else:
+            raise FileNotFoundError(
+                f"Could not find checkpoint {best_checkpoint_path}"
+            )
+
+    elif resume_from is None:
+        checkpoint_file = None
+
+    else:
+        if os.path.isfile(resume_from):
+            checkpoint_file = resume_from
+            logger.info(f"Resuming training from checkpoint {resume_from}")
+        else:
+            raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
+
+    return checkpoint_file
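
For reference, a minimal sketch of how the new helper is meant to be exercised around the training entry point. Only get_checkpoint and the epoch peek mirror the diff above; the output folder name, the surrounding variable names, and the final hand-off to pytorch_lightning (Trainer.fit(..., ckpt_path=...) inside run()) are assumptions, since run() is not part of this change.

# Usage sketch (not part of the diff): resolve a checkpoint path, peek at its
# state, and hand it off to the trainer. The hand-off lines are assumptions
# about what run() does internally.
import torch

from ptbench.utils.checkpointer import get_checkpoint

output_folder = "results"  # hypothetical output directory
resume_from = "last"       # "last", "best", an explicit path, or None

checkpoint_file = get_checkpoint(output_folder, resume_from)

start_epoch = 0
if checkpoint_file is not None:
    # Only inspect the saved state here; model and optimizer weights are
    # restored later when the path reaches the Lightning trainer.
    start_epoch = torch.load(checkpoint_file)["epoch"]

# Assumed hand-off, mirroring run(..., checkpoint=checkpoint_file):
#   trainer = pytorch_lightning.Trainer(max_epochs=epochs)
#   trainer.fit(model, datamodule, ckpt_path=checkpoint_file)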