Skip to content
Snippets Groups Projects
Commit 826d392e authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Resume from last checkpoint by default if exists in output_folder

parent 9f737974
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
import logging import logging
import os import os
import typing
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_checkpoint(output_folder, resume_from): def get_checkpoint(output_folder: str, resume_from: typing.Literal["last", "best"] | str | None) -> str | None :
"""Gets a checkpoint file. """Gets a checkpoint file.
Can return the best or last checkpoint, or a checkpoint at a specific path. Can return the best or last checkpoint, or a checkpoint at a specific path.
Ensures the checkpoint exists, raising an error if it is not the case. Ensures the checkpoint exists, raising an error if it is not the case.
If resume_from is None, checks the output directory if a checkpoint already exists and returns it.
If no checkpoint is found, returns None.
Parameters Parameters
---------- ----------
output_folder : :py:class:`str` output_folder:
Directory in which checkpoints are stored. Directory in which checkpoints are stored.
resume_from : :py:class:`str` resume_from:
Which model to get. Can be one of "best", "last", or a path to a checkpoint. Which model to get. Can be one of "best", "last", or a path to a checkpoint.
If None, gets the last checkpoint if it exists, otherwise returns None
Returns Returns
------- -------
checkpoint_file : :py:class:`str` checkpoint_file:
The requested model. Path to the requested checkpoint or None.
""" """
last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt") last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
best_checkpoint_path = os.path.join( best_checkpoint_path = os.path.join(
...@@ -49,7 +54,11 @@ def get_checkpoint(output_folder, resume_from): ...@@ -49,7 +54,11 @@ def get_checkpoint(output_folder, resume_from):
) )
elif resume_from is None: elif resume_from is None:
checkpoint_file = None if os.path.isfile(last_checkpoint_path):
checkpoint_file = last_checkpoint_path
logger.info(f"Found existing checkpoint {last_checkpoint_path}. Loading.")
else:
return None
else: else:
if os.path.isfile(resume_from): if os.path.isfile(resume_from):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment