# checkpointer.py -- helpers for locating model checkpoint files.
import logging
import os

logger = logging.getLogger(__name__)


def get_checkpoint(output_folder, resume_from):
    """Gets a checkpoint file.

    Can return the best or last checkpoint, or a checkpoint at a specific path.
    Ensures the checkpoint exists, raising an error if it is not the case.

    Parameters
    ----------

    output_folder : :py:class:`str`
        Directory in which checkpoints are stored. Only used when
        ``resume_from`` is ``"best"`` or ``"last"``.

    resume_from : :py:class:`str` or ``None``
        Which model to get. Can be one of "best", "last", or a path to a
        checkpoint. If ``None``, no checkpoint is looked up.

    Returns
    -------

    checkpoint_file : :py:class:`str` or ``None``
        Path to the requested checkpoint, or ``None`` if ``resume_from`` is
        ``None``.

    Raises
    ------

    FileNotFoundError
        If the requested checkpoint is not an existing file.
    """
    if resume_from is None:
        # Nothing was requested, so there is nothing to validate.
        return None

    if resume_from == "last":
        checkpoint_file = os.path.join(output_folder, "model_final_epoch.ckpt")
        description = f"{resume_from} checkpoint"
    elif resume_from == "best":
        # The "best" keyword must resolve to the lowest-validation-loss file,
        # not to the final-epoch one.
        checkpoint_file = os.path.join(
            output_folder, "model_lowest_valid_loss.ckpt"
        )
        description = f"{resume_from} checkpoint"
    else:
        # Any other value is interpreted as an explicit checkpoint path.
        checkpoint_file = resume_from
        description = f"checkpoint {resume_from}"

    if not os.path.isfile(checkpoint_file):
        raise FileNotFoundError(f"Could not find checkpoint {checkpoint_file}")

    logger.info(f"Resuming training from {description}")
    return checkpoint_file