Newer
Older
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import os
import typing
logger = logging.getLogger(__name__)
def get_checkpoint(
output_folder: str, resume_from: typing.Literal["last", "best"] | str | None
) -> str | None:
"""Gets a checkpoint file.
Can return the best or last checkpoint, or a checkpoint at a specific path.
Ensures the checkpoint exists, raising an error if it is not the case.
If resume_from is None, checks the output directory if a checkpoint already exists and returns it.
If no checkpoint is found, returns None.
output_folder:
resume_from:
Which model to get. Can be one of "best", "last", or a path to a checkpoint.
If None, gets the last checkpoint if it exists, otherwise returns None
checkpoint_file:
Path to the requested checkpoint or None.
last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
best_checkpoint_path = os.path.join(
output_folder, "model_lowest_valid_loss.ckpt"
)
if resume_from == "last":
if os.path.isfile(last_checkpoint_path):
checkpoint_file = last_checkpoint_path
logger.info(f"Resuming training from {resume_from} checkpoint")
else:
raise FileNotFoundError(
f"Could not find checkpoint {last_checkpoint_path}"
)
elif resume_from == "best":
if os.path.isfile(best_checkpoint_path):
checkpoint_file = last_checkpoint_path
logger.info(f"Resuming training from {resume_from} checkpoint")
else:
raise FileNotFoundError(
f"Could not find checkpoint {best_checkpoint_path}"
)
elif resume_from is None:
if os.path.isfile(last_checkpoint_path):
checkpoint_file = last_checkpoint_path
logger.info(
f"Found existing checkpoint {last_checkpoint_path}. Loading."
)
else:
return None
else:
if os.path.isfile(resume_from):
checkpoint_file = resume_from
logger.info(f"Resuming training from checkpoint {resume_from}")
else:
raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
return checkpoint_file