Skip to content
Snippets Groups Projects
Commit f3cc3c55 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Updated docstrings

parent b3b1462b
No related branches found
No related tags found
1 merge request: !4 "Moved code to lightning"
Pipeline #73262 passed
...@@ -10,6 +10,9 @@ from lightning.pytorch.callbacks import BasePredictionWriter ...@@ -10,6 +10,9 @@ from lightning.pytorch.callbacks import BasePredictionWriter
# This ensures CSVLogger logs training and evaluation metrics on the same line # This ensures CSVLogger logs training and evaluation metrics on the same line
# CSVLogger only accepts numerical values, not strings # CSVLogger only accepts numerical values, not strings
class LoggingCallback(Callback): class LoggingCallback(Callback):
"""Lightning callback to log various training metrics and device
information."""
def __init__(self, resource_monitor): def __init__(self, resource_monitor):
super().__init__() super().__init__()
self.training_loss = [] self.training_loss = []
...@@ -79,6 +82,8 @@ class LoggingCallback(Callback): ...@@ -79,6 +82,8 @@ class LoggingCallback(Callback):
class PredictionsWriter(BasePredictionWriter): class PredictionsWriter(BasePredictionWriter):
"""Lightning callback to write predictions to a file."""
def __init__(self, logfile_name, logfile_fields, write_interval): def __init__(self, logfile_name, logfile_fields, write_interval):
super().__init__(write_interval) super().__init__(write_interval)
self.logfile_name = logfile_name self.logfile_name = logfile_name
......
...@@ -14,35 +14,35 @@ logger = logging.getLogger(__name__) ...@@ -14,35 +14,35 @@ logger = logging.getLogger(__name__)
def run(model, data_loader, name, accelerator, output_folder, grad_cams=False): def run(model, data_loader, name, accelerator, output_folder, grad_cams=False):
"""Runs inference on input data, outputs HDF5 files with predictions. """Runs inference on input data, outputs csv files with predictions.
Parameters Parameters
--------- ---------
model : :py:class:`torch.nn.Module` model : :py:class:`torch.nn.Module`
neural network model (e.g. pasa) Neural network model (e.g. pasa).
data_loader : py:class:`torch.torch.utils.data.DataLoader` data_loader : py:class:`torch.torch.utils.data.DataLoader`
The pytorch Dataloader used to iterate over batches.
name : str name : str
the local name of this dataset (e.g. ``train``, or ``test``), to be The local name of this dataset (e.g. ``train``, or ``test``), to be
used when saving measures files. used when saving measures files.
accelerator : str accelerator : str
accelerator to use A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)
output_folder : str output_folder : str
folder where to store output prediction and model Directory in which the results will be saved.
summary
grad_cams : bool grad_cams : bool
if we export grad cams for every prediction (must be used along If we export grad cams for every prediction (must be used along
a batch size of 1 with the DensenetRS model) a batch size of 1 with the DensenetRS model).
Returns Returns
------- -------
all_predictions : list all_predictions : list
All the predictions associated with filename and groundtruth All the predictions associated with filename and ground truth.
""" """
output_folder = os.path.join(output_folder, name) output_folder = os.path.join(output_folder, name)
......
...@@ -170,10 +170,10 @@ def run( ...@@ -170,10 +170,10 @@ def run(
---------- ----------
model : :py:class:`torch.nn.Module` model : :py:class:`torch.nn.Module`
Network (e.g. driu, hed, unet) Neural network model (e.g. pasa).
data_loader : :py:class:`torch.utils.data.DataLoader` data_loader : :py:class:`torch.utils.data.DataLoader`
To be used to train the model The pytorch Dataloader used to iterate over batches.
valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader` valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
To be used to validate the model and enable automatic checkpointing. To be used to validate the model and enable automatic checkpointing.
...@@ -186,20 +186,20 @@ def run( ...@@ -186,20 +186,20 @@ def run(
the final training log. the final training log.
checkpoint_period : int checkpoint_period : int
save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do Save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
not save intermediary checkpoints not save intermediary checkpoints.
accelerator : str accelerator : str
accelerator to use A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)
arguments : dict arguments : dict
start and end epochs Start and end epochs:
output_folder : str output_folder : str
output path Directory in which the results will be saved.
monitoring_interval : int, float monitoring_interval : int, float
interval, in seconds (or fractions), through which we should monitor Interval, in seconds (or fractions), through which we should monitor
resources during training. resources during training.
batch_chunk_count: int batch_chunk_count: int
......
...@@ -64,7 +64,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -64,7 +64,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@click.option( @click.option(
"--accelerator", "--accelerator",
"-a", "-a",
help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available', help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
show_default=True, show_default=True,
required=True, required=True,
default="cpu", default="cpu",
......
...@@ -178,7 +178,7 @@ def set_reproducible_cuda(): ...@@ -178,7 +178,7 @@ def set_reproducible_cuda():
@click.option( @click.option(
"--accelerator", "--accelerator",
"-a", "-a",
help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available', help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
show_default=True, show_default=True,
required=True, required=True,
default="cpu", default="cpu",
...@@ -235,7 +235,7 @@ def set_reproducible_cuda(): ...@@ -235,7 +235,7 @@ def set_reproducible_cuda():
) )
@click.option( @click.option(
"--resume-from", "--resume-from",
help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a ckpt file.", help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a model checkpoint.",
type=str, type=str,
required=False, required=False,
default=None, default=None,
......
...@@ -5,6 +5,26 @@ logger = logging.getLogger(__name__) ...@@ -5,6 +5,26 @@ logger = logging.getLogger(__name__)
def get_checkpoint(output_folder, resume_from): def get_checkpoint(output_folder, resume_from):
"""Gets a checkpoint file.
Can return the best or last checkpoint, or a checkpoint at a specific path.
Ensures the checkpoint exists, raising an error if it is not the case.
Parameters
----------
output_folder : :py:class:`str`
Directory in which checkpoints are stored.
resume_from : :py:class:`str`
Which model to get. Can be one of "best", "last", or a path to a checkpoint.
Returns
-------
checkpoint_file : :py:class:`str`
The requested model.
"""
last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt") last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
best_checkpoint_path = os.path.join( best_checkpoint_path = os.path.join(
output_folder, "model_lowest_valid_loss.ckpt" output_folder, "model_lowest_valid_loss.ckpt"
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment