import logging import os logger = logging.getLogger(__name__) def get_checkpoint(output_folder, resume_from): last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt") best_checkpoint_path = os.path.join( output_folder, "model_lowest_valid_loss.ckpt" ) if resume_from == "last": if os.path.isfile(last_checkpoint_path): checkpoint_file = last_checkpoint_path logger.info(f"Resuming training from {resume_from} checkpoint") else: raise FileNotFoundError( f"Could not find checkpoint {last_checkpoint_path}" ) elif resume_from == "best": if os.path.isfile(best_checkpoint_path): checkpoint_file = last_checkpoint_path logger.info(f"Resuming training from {resume_from} checkpoint") else: raise FileNotFoundError( f"Could not find checkpoint {best_checkpoint_path}" ) elif resume_from is None: checkpoint_file = None else: if os.path.isfile(resume_from): checkpoint_file = resume_from logger.info(f"Resuming training from checkpoint {resume_from}") else: raise FileNotFoundError(f"Could not find checkpoint {resume_from}") return checkpoint_file