From b42b73d7ab22d6812c26ed2e202453e91b59a734 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 12 Apr 2023 10:14:18 +0200
Subject: [PATCH] Resume training from checkpoints

---
 src/ptbench/engine/trainer.py |  3 ++-
 src/ptbench/scripts/train.py  | 51 +++++++++++++++++++++++++++++------
 2 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index ddacf63c..05b37feb 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -462,6 +462,7 @@ def run(
     output_folder,
     monitoring_interval,
     batch_chunk_count,
+    checkpoint,
 ):
     """Fits a CNN model using supervised learning and save it to disk.
 
@@ -560,7 +561,7 @@ def run(
         callbacks=[LoggingCallback(resource_monitor), checkpoint_callback],
     )
 
-    _ = trainer.fit(model, data_loader, valid_loader)
+    _ = trainer.fit(model, data_loader, valid_loader, ckpt_path=checkpoint)
 
     """# write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 12e80ae1..ea82f5a7 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -254,13 +254,6 @@ def set_reproducible_cuda():
     default=-1,
     cls=ResourceOption,
 )
-@click.option(
-    "--weight",
-    "-w",
-    help="Path or URL to pretrained model file (.pth extension)",
-    required=False,
-    cls=ResourceOption,
-)
 @click.option(
     "--normalization",
     "-n",
@@ -287,6 +280,14 @@ def set_reproducible_cuda():
     default=5.0,
     cls=ResourceOption,
 )
+@click.option(
+    "--resume-from",
+    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a ckpt file.",
+    type=str,
+    required=False,
+    default=None,
+    cls=ResourceOption,
+)
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def train(
     model,
@@ -293,18 +294,18 @@ def train(
     output_folder,
     epochs,
     batch_size,
     batch_chunk_count,
     drop_incomplete_batch,
     criterion,
     criterion_valid,
     dataset,
     checkpoint_period,
     device,
     seed,
     parallel,
-    weight,
     normalization,
     monitoring_interval,
+    resume_from,
     **_,
 ):
     """Trains an CNN to perform tuberculosis detection.
@@ -483,6 +484,39 @@ def train(
     arguments["epoch"] = 0
     arguments["max_epoch"] = epochs
 
+    last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
+    best_checkpoint_path = os.path.join(
+        output_folder, "model_lowest_valid_loss.ckpt"
+    )
+
+    if resume_from == "last":
+        if os.path.isfile(last_checkpoint_path):
+            checkpoint = last_checkpoint_path
+            logger.info(f"Resuming training from {resume_from} checkpoint")
+        else:
+            raise FileNotFoundError(
+                f"Could not find checkpoint {last_checkpoint_path}"
+            )
+
+    elif resume_from == "best":
+        if os.path.isfile(best_checkpoint_path):
+            checkpoint = best_checkpoint_path
+            logger.info(f"Resuming training from {resume_from} checkpoint")
+        else:
+            raise FileNotFoundError(
+                f"Could not find checkpoint {best_checkpoint_path}"
+            )
+
+    elif resume_from is None:
+        checkpoint = None
+
+    else:
+        if os.path.isfile(resume_from):
+            checkpoint = resume_from
+            logger.info(f"Resuming training from checkpoint {resume_from}")
+        else:
+            raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
+
     logger.info("Training for {} epochs".format(arguments["max_epoch"]))
     logger.info("Continuing from epoch {}".format(arguments["epoch"]))
 
@@ -498,4 +532,5 @@ def train(
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
         batch_chunk_count=batch_chunk_count,
+        checkpoint=checkpoint,
     )
--
GitLab
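
Note on the mechanism: the patch resolves "--resume-from" to a plain path (or None) in the CLI layer and forwards it to PyTorch Lightning's Trainer.fit() via its ckpt_path argument, which restores the weights, optimizer state and epoch counter before training continues. The sketch below illustrates only that Lightning behaviour and is not ptbench code: TinyModel, the synthetic data, the "results" folder and the use of ModelCheckpoint(save_last=True) with its "last.ckpt" file (standing in for ptbench's "model_final_epoch.ckpt") are all assumptions made for the demo.

# Illustrative sketch only -- not ptbench code.  Shows how Lightning resumes a
# run when a checkpoint path is passed to Trainer.fit(ckpt_path=...), which is
# what the patched run() forwards as ``checkpoint``.
import os

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    # Hypothetical stand-in for ptbench's CNN models.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


def make_loader():
    # Synthetic data, assumed for the demo.
    x = torch.randn(64, 8)
    y = torch.randn(64, 1)
    return DataLoader(TensorDataset(x, y), batch_size=16)


output_folder = "results"
last_checkpoint = os.path.join(output_folder, "last.ckpt")

# First run: ModelCheckpoint(save_last=True) keeps ``last.ckpt`` up to date.
callback = pl.callbacks.ModelCheckpoint(dirpath=output_folder, save_last=True)
trainer = pl.Trainer(max_epochs=2, callbacks=[callback], logger=False)
trainer.fit(TinyModel(), make_loader())

# Resumed run: weights, optimizer state and epoch counter are restored from the
# checkpoint, so training continues at epoch 2 instead of starting over.  This
# mirrors ``--resume-from last`` in the patched CLI, which uses its own
# ``model_final_epoch.ckpt`` naming instead of ``last.ckpt``.
resume = last_checkpoint if os.path.isfile(last_checkpoint) else None
trainer = pl.Trainer(max_epochs=4, callbacks=[callback], logger=False)
trainer.fit(TinyModel(), make_loader(), ckpt_path=resume)

Resolving "last"/"best"/path names in the CLI and handing the engine only a ready-to-use path keeps the trainer free of any checkpoint-naming conventions.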