# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later import logging import os import typing logger = logging.getLogger(__name__) def get_checkpoint( output_folder: str, resume_from: typing.Literal["last", "best"] | str | None ) -> str | None: """Gets a checkpoint file. Can return the best or last checkpoint, or a checkpoint at a specific path. Ensures the checkpoint exists, raising an error if it is not the case. If resume_from is None, checks the output directory if a checkpoint already exists and returns it. If no checkpoint is found, returns None. Parameters ---------- output_folder: Directory in which checkpoints are stored. resume_from: Which model to get. Can be one of "best", "last", or a path to a checkpoint. If None, gets the last checkpoint if it exists, otherwise returns None Returns ------- checkpoint_file: Path to the requested checkpoint or None. """ last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt") best_checkpoint_path = os.path.join( output_folder, "model_lowest_valid_loss.ckpt" ) if resume_from == "last": if os.path.isfile(last_checkpoint_path): checkpoint_file = last_checkpoint_path logger.info(f"Resuming training from {resume_from} checkpoint") else: raise FileNotFoundError( f"Could not find checkpoint {last_checkpoint_path}" ) elif resume_from == "best": if os.path.isfile(best_checkpoint_path): checkpoint_file = last_checkpoint_path logger.info(f"Resuming training from {resume_from} checkpoint") else: raise FileNotFoundError( f"Could not find checkpoint {best_checkpoint_path}" ) elif resume_from is None: if os.path.isfile(last_checkpoint_path): checkpoint_file = last_checkpoint_path logger.info( f"Found existing checkpoint {last_checkpoint_path}. Loading." ) else: return None else: if os.path.isfile(resume_from): checkpoint_file = resume_from logger.info(f"Resuming training from checkpoint {resume_from}") else: raise FileNotFoundError(f"Could not find checkpoint {resume_from}") return checkpoint_file