Commit bf25f806 authored by Daniel CARRON

Get epochs from checkpointer and move checkpointer selection code

parent a4e335c1
1 merge request: !4 Moved code to lightning
Pipeline #72399 failed
@@ -2,14 +2,14 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 import os
 import click
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
 from pytorch_lightning import seed_everything
+from ..utils.checkpointer import get_checkpoint
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -432,41 +432,17 @@ def train(
     arguments = {}
     arguments["max_epoch"] = epochs
     arguments["epoch"] = 0
-    last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
-    best_checkpoint_path = os.path.join(
-        output_folder, "model_lowest_valid_loss.ckpt"
-    )
-    if resume_from == "last":
-        if os.path.isfile(last_checkpoint_path):
-            checkpoint = last_checkpoint_path
-            logger.info(f"Resuming training from {resume_from} checkpoint")
-        else:
-            raise FileNotFoundError(
-                f"Could not find checkpoint {last_checkpoint_path}"
-            )
-    elif resume_from == "best":
-        if os.path.isfile(best_checkpoint_path):
-            checkpoint = last_checkpoint_path
-            logger.info(f"Resuming training from {resume_from} checkpoint")
-        else:
-            raise FileNotFoundError(
-                f"Could not find checkpoint {best_checkpoint_path}"
-            )
-    elif resume_from is None:
-        checkpoint = None
+    checkpoint_file = get_checkpoint(output_folder, resume_from)
-    else:
-        if os.path.isfile(resume_from):
-            checkpoint = resume_from
-            logger.info(f"Resuming training from checkpoint {resume_from}")
-        else:
-            raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
+    # We only load the checkpoint to get some information about its state. The actual loading of the model is done in trainer.fit()
+    if checkpoint_file is not None:
+        checkpoint = torch.load(checkpoint_file)
+        arguments["epoch"] = checkpoint["epoch"]
     logger.info("Training for {} epochs".format(arguments["max_epoch"]))
     logger.info("Continuing from epoch {}".format(arguments["epoch"]))
     run(
         model=model,
@@ -479,5 +455,5 @@ def train(
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
         batch_chunk_count=batch_chunk_count,
-        checkpoint=checkpoint,
+        checkpoint=checkpoint_file,
     )
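To make the change above concrete, here is a minimal sketch (not code from this commit) of the resume flow after the refactor: the checkpoint path is resolved once by the new helper, the file is opened with torch.load only to read back the epoch counter, and the model state itself is restored later inside trainer.fit(), as the comment in the diff notes. The peek_start_epoch name is invented for illustration, and the commented wiring assumes a pytorch_lightning version whose Trainer.fit accepts a ckpt_path argument.

import torch

def peek_start_epoch(checkpoint_file):
    # Lightning checkpoints are plain dictionaries written with torch.save;
    # besides model and optimizer state they carry an "epoch" counter.
    if checkpoint_file is None:
        return 0
    return torch.load(checkpoint_file)["epoch"]

# Hypothetical wiring, mirroring the diff above (names are assumptions):
#   checkpoint_file = get_checkpoint(output_folder, resume_from)
#   arguments["epoch"] = peek_start_epoch(checkpoint_file)
#   trainer.fit(model, datamodule, ckpt_path=checkpoint_file)

The new checkpointer utility module created by this commit, imported above as ..utils.checkpointer, follows in full.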
import logging
import os

logger = logging.getLogger(__name__)


def get_checkpoint(output_folder, resume_from):
    """Selects the checkpoint file to resume training from.

    ``resume_from`` may be ``"last"``, ``"best"``, an explicit path to a
    checkpoint file, or ``None`` (start from scratch).  A
    ``FileNotFoundError`` is raised when the requested checkpoint does not
    exist.
    """
    last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
    best_checkpoint_path = os.path.join(
        output_folder, "model_lowest_valid_loss.ckpt"
    )

    if resume_from == "last":
        if os.path.isfile(last_checkpoint_path):
            checkpoint_file = last_checkpoint_path
            logger.info(f"Resuming training from {resume_from} checkpoint")
        else:
            raise FileNotFoundError(
                f"Could not find checkpoint {last_checkpoint_path}"
            )
    elif resume_from == "best":
        if os.path.isfile(best_checkpoint_path):
            # Resume from the lowest-validation-loss checkpoint.
            checkpoint_file = best_checkpoint_path
            logger.info(f"Resuming training from {resume_from} checkpoint")
        else:
            raise FileNotFoundError(
                f"Could not find checkpoint {best_checkpoint_path}"
            )
    elif resume_from is None:
        checkpoint_file = None
    else:
        if os.path.isfile(resume_from):
            checkpoint_file = resume_from
            logger.info(f"Resuming training from checkpoint {resume_from}")
        else:
            raise FileNotFoundError(f"Could not find checkpoint {resume_from}")

    return checkpoint_file
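As a quick sanity check of the helper's contract, the following usage sketch (not part of the commit) exercises the different resume_from modes. It assumes get_checkpoint is importable from the new module and fakes the two checkpoint files that a finished training run would leave in output_folder.

import os
import tempfile

with tempfile.TemporaryDirectory() as output_folder:
    # Pretend a previous run wrote both checkpoints.
    for name in ("model_final_epoch.ckpt", "model_lowest_valid_loss.ckpt"):
        open(os.path.join(output_folder, name), "w").close()

    assert get_checkpoint(output_folder, "last").endswith("model_final_epoch.ckpt")
    assert get_checkpoint(output_folder, "best").endswith("model_lowest_valid_loss.ckpt")
    assert get_checkpoint(output_folder, None) is None

    # An explicit path is returned as-is when it exists ...
    existing = os.path.join(output_folder, "model_final_epoch.ckpt")
    assert get_checkpoint(output_folder, existing) == existing

    # ... and a missing checkpoint raises FileNotFoundError.
    try:
        get_checkpoint(output_folder, "/no/such/file.ckpt")
    except FileNotFoundError as error:
        print(error)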