Skip to content
Snippets Groups Projects
checkpointer.py 1.3 KiB
Newer Older
import logging
import os

logger = logging.getLogger(__name__)


def get_checkpoint(output_folder, resume_from):
    """Resolve the checkpoint file that training should resume from.

    Parameters
    ----------
    output_folder : str
        Directory searched for the named checkpoints: ``"last"`` maps to
        ``model_final_epoch.ckpt`` and ``"best"`` maps to
        ``model_lowest_valid_loss.ckpt``. Only used for those two names.
    resume_from : str or None
        ``"last"`` or ``"best"`` to pick one of the named checkpoints in
        *output_folder*, any other value is treated as an explicit path to
        a checkpoint file, and ``None`` means "do not resume".

    Returns
    -------
    str or None
        Path of the checkpoint file to load, or ``None`` when *resume_from*
        is ``None``.

    Raises
    ------
    FileNotFoundError
        If the selected checkpoint file does not exist.
    """
    if resume_from is None:
        # Nothing to resume; output_folder is not needed (and may be None).
        return None

    # Single table of named checkpoints so the path that is checked for
    # existence is always the path that is returned.
    named_checkpoints = {
        "last": "model_final_epoch.ckpt",
        "best": "model_lowest_valid_loss.ckpt",
    }

    if resume_from in named_checkpoints:
        checkpoint_file = os.path.join(
            output_folder, named_checkpoints[resume_from]
        )
        message = "Resuming training from %s checkpoint"
    else:
        # Anything else is taken to be an explicit checkpoint path.
        checkpoint_file = resume_from
        message = "Resuming training from checkpoint %s"

    if not os.path.isfile(checkpoint_file):
        raise FileNotFoundError(f"Could not find checkpoint {checkpoint_file}")

    logger.info(message, resume_from)
    return checkpoint_file