# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

def get_checkpoint(
    output_folder: str, resume_from: typing.Literal["last", "best"] | str | None
) -> str | None:
Daniel CARRON's avatar
Daniel CARRON committed
    """Gets a checkpoint file.

    Can return the best or last checkpoint, or a checkpoint at a specific path.
    Ensures the checkpoint exists, raising an error if it is not the case.

    If resume_from is None, checks the output directory if a checkpoint already exists and returns it.
    If no checkpoint is found, returns None.

Daniel CARRON's avatar
Daniel CARRON committed
    Parameters
    ----------

Daniel CARRON's avatar
Daniel CARRON committed
        Directory in which checkpoints are stored.

Daniel CARRON's avatar
Daniel CARRON committed
        Which model to get. Can be one of "best", "last", or a path to a checkpoint.
        If None, gets the last checkpoint if it exists, otherwise returns None
Daniel CARRON's avatar
Daniel CARRON committed

    Returns
    -------

    checkpoint_file:
        Path to the requested checkpoint or None.
Daniel CARRON's avatar
Daniel CARRON committed
    """
    last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
    best_checkpoint_path = os.path.join(
        output_folder, "model_lowest_valid_loss.ckpt"
    )

    if resume_from == "last":
        if os.path.isfile(last_checkpoint_path):
            checkpoint_file = last_checkpoint_path
            logger.info(f"Resuming training from {resume_from} checkpoint")
        else:
            raise FileNotFoundError(
                f"Could not find checkpoint {last_checkpoint_path}"
            )

    elif resume_from == "best":
        if os.path.isfile(best_checkpoint_path):
            checkpoint_file = last_checkpoint_path
            logger.info(f"Resuming training from {resume_from} checkpoint")
        else:
            raise FileNotFoundError(
                f"Could not find checkpoint {best_checkpoint_path}"
            )

    elif resume_from is None:
        if os.path.isfile(last_checkpoint_path):
            checkpoint_file = last_checkpoint_path
            logger.info(
                f"Found existing checkpoint {last_checkpoint_path}. Loading."
            )

    else:
        if os.path.isfile(resume_from):
            checkpoint_file = resume_from
            logger.info(f"Resuming training from checkpoint {resume_from}")
        else:
            raise FileNotFoundError(f"Could not find checkpoint {resume_from}")

    return checkpoint_file