Skip to content
Snippets Groups Projects

Making use of LightningDataModule and simplification of data loading

Merged Daniel CARRON requested to merge add-datamodule into main
Compare and Show latest version
1 file
+ 15
6
Compare changes
  • Side-by-side
  • Inline
import logging
import os
import typing
logger = logging.getLogger(__name__)
def get_checkpoint(output_folder, resume_from):
def get_checkpoint(output_folder: str, resume_from: typing.Literal["last", "best"] | str | None) -> str | None :
"""Gets a checkpoint file.
Can return the best or last checkpoint, or a checkpoint at a specific path.
Ensures the checkpoint exists, raising an error if it is not the case.
If resume_from is None, checks the output directory if a checkpoint already exists and returns it.
If no checkpoint is found, returns None.
Parameters
----------
output_folder : :py:class:`str`
output_folder:
Directory in which checkpoints are stored.
resume_from : :py:class:`str`
resume_from:
Which model to get. Can be one of "best", "last", or a path to a checkpoint.
If None, gets the last checkpoint if it exists, otherwise returns None
Returns
-------
checkpoint_file : :py:class:`str`
The requested model.
checkpoint_file:
Path to the requested checkpoint or None.
"""
last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
best_checkpoint_path = os.path.join(
@@ -49,7 +54,11 @@ def get_checkpoint(output_folder, resume_from):
)
elif resume_from is None:
checkpoint_file = None
if os.path.isfile(last_checkpoint_path):
checkpoint_file = last_checkpoint_path
logger.info(f"Found existing checkpoint {last_checkpoint_path}. Loading.")
else:
return None
else:
if os.path.isfile(resume_from):
Loading