Commit b42b73d7 authored by Daniel CARRON

Resume training from checkpoints

parent 85da2f49
1 merge request: !4 Moved code to lightning
@@ -462,6 +462,7 @@ def run(
output_folder,
monitoring_interval,
batch_chunk_count,
checkpoint,
):
"""Fits a CNN model using supervised learning and save it to disk.
@@ -560,7 +561,7 @@ def run(
callbacks=[LoggingCallback(resource_monitor), checkpoint_callback],
)
_ = trainer.fit(model, data_loader, valid_loader)
_ = trainer.fit(model, data_loader, valid_loader, ckpt_path=checkpoint)
"""# write static information to a CSV file
static_logfile_name = os.path.join(output_folder, "constants.csv")
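
The change in this first file is the new ckpt_path argument forwarded to trainer.fit: when it points at a .ckpt file, PyTorch Lightning restores the model weights, optimizer state and epoch/step counters before continuing, and when it is None the fit starts from scratch as before. Below is a minimal, self-contained sketch of that mechanism, not this project's code; the model, data and paths are placeholders, and the import path may differ between Lightning versions.

# Minimal sketch (placeholder names, not this project's code) of resuming
# a PyTorch Lightning run by passing ckpt_path to Trainer.fit.
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


data = DataLoader(
    TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16
)

# First run: train briefly and let ModelCheckpoint write results/last.ckpt.
first = pl.Trainer(
    max_epochs=2,
    callbacks=[ModelCheckpoint(dirpath="results", save_last=True)],
)
first.fit(TinyModel(), data)

# Later run: ckpt_path restores weights, optimizer state and the epoch
# counter, so this run continues where the first one stopped instead of
# starting over; ckpt_path=None would start a fresh run.
resumed = pl.Trainer(
    max_epochs=4,
    callbacks=[ModelCheckpoint(dirpath="results", save_last=True)],
)
resumed.fit(TinyModel(), data, ckpt_path="results/last.ckpt")
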
@@ -254,13 +254,6 @@ def set_reproducible_cuda():
default=-1,
cls=ResourceOption,
)
@click.option(
"--weight",
"-w",
help="Path or URL to pretrained model file (.pth extension)",
required=False,
cls=ResourceOption,
)
@click.option(
"--normalization",
"-n",
@@ -287,6 +280,14 @@ def set_reproducible_cuda():
default=5.0,
cls=ResourceOption,
)
@click.option(
"--resume-from",
help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a ckpt file.",
type=str,
required=False,
default=None,
cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def train(
model,
@@ -303,9 +304,9 @@ def train(
device,
seed,
parallel,
weight,
normalization,
monitoring_interval,
resume_from,
**_,
):
"""Trains an CNN to perform tuberculosis detection.
@@ -483,6 +484,39 @@ def train(
arguments["epoch"] = 0
arguments["max_epoch"] = epochs
last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
best_checkpoint_path = os.path.join(
output_folder, "model_lowest_valid_loss.ckpt"
)
if resume_from == "last":
if os.path.isfile(last_checkpoint_path):
checkpoint = last_checkpoint_path
logger.info(f"Resuming training from {resume_from} checkpoint")
else:
raise FileNotFoundError(
f"Could not find checkpoint {last_checkpoint_path}"
)
elif resume_from == "best":
if os.path.isfile(best_checkpoint_path):
checkpoint = best_checkpoint_path
logger.info(f"Resuming training from {resume_from} checkpoint")
else:
raise FileNotFoundError(
f"Could not find checkpoint {best_checkpoint_path}"
)
elif resume_from is None:
checkpoint = None
else:
if os.path.isfile(resume_from):
checkpoint = resume_from
logger.info(f"Resuming training from checkpoint {resume_from}")
else:
raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
logger.info("Training for {} epochs".format(arguments["max_epoch"]))
logger.info("Continuing from epoch {}".format(arguments["epoch"]))
@@ -498,4 +532,5 @@ def train(
output_folder=output_folder,
monitoring_interval=monitoring_interval,
batch_chunk_count=batch_chunk_count,
checkpoint=checkpoint,
)