Skip to content
Snippets Groups Projects
Commit 826d392e authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Resume from last checkpoint by default if exists in output_folder

parent 9f737974
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
import logging
import os
import typing
logger = logging.getLogger(__name__)
def get_checkpoint(output_folder, resume_from):
def get_checkpoint(output_folder: str, resume_from: typing.Literal["last", "best"] | str | None) -> str | None :
"""Gets a checkpoint file.
Can return the best or last checkpoint, or a checkpoint at a specific path.
Ensures the checkpoint exists, raising an error if it is not the case.
If resume_from is None, checks the output directory if a checkpoint already exists and returns it.
If no checkpoint is found, returns None.
Parameters
----------
output_folder : :py:class:`str`
output_folder:
Directory in which checkpoints are stored.
resume_from : :py:class:`str`
resume_from:
Which model to get. Can be one of "best", "last", or a path to a checkpoint.
If None, gets the last checkpoint if it exists, otherwise returns None
Returns
-------
checkpoint_file : :py:class:`str`
The requested model.
checkpoint_file:
Path to the requested checkpoint or None.
"""
last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
best_checkpoint_path = os.path.join(
......@@ -49,7 +54,11 @@ def get_checkpoint(output_folder, resume_from):
)
elif resume_from is None:
checkpoint_file = None
if os.path.isfile(last_checkpoint_path):
checkpoint_file = last_checkpoint_path
logger.info(f"Found existing checkpoint {last_checkpoint_path}. Loading.")
else:
return None
else:
if os.path.isfile(resume_from):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment