From f3cc3c556bb26401aedfd5cbe9a97c673fd0d887 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 10 May 2023 12:01:08 +0200
Subject: [PATCH] Updated docstrings

---
 src/ptbench/engine/callbacks.py   |  5 +++++
 src/ptbench/engine/predictor.py   | 18 +++++++++---------
 src/ptbench/engine/trainer.py     | 16 ++++++++--------
 src/ptbench/scripts/predict.py    |  2 +-
 src/ptbench/scripts/train.py      |  4 ++--
 src/ptbench/utils/checkpointer.py | 20 ++++++++++++++++++++
 6 files changed, 45 insertions(+), 20 deletions(-)

diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index d6e35e84..b266ae62 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -10,6 +10,9 @@ from lightning.pytorch.callbacks import BasePredictionWriter
 # This ensures CSVLogger logs training and evaluation metrics on the same line
 # CSVLogger only accepts numerical values, not strings
 class LoggingCallback(Callback):
+    """Lightning callback to log various training metrics and device
+    information."""
+
     def __init__(self, resource_monitor):
         super().__init__()
         self.training_loss = []
@@ -79,6 +82,8 @@ class LoggingCallback(Callback):
 
 
 class PredictionsWriter(BasePredictionWriter):
+    """Lightning callback to write predictions to a file."""
+
     def __init__(self, logfile_name, logfile_fields, write_interval):
         super().__init__(write_interval)
         self.logfile_name = logfile_name
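
For reference, a minimal sketch of how these two callbacks could be attached to
a lightning ``Trainer``, using only the constructor signatures shown above; the
``resource_monitor`` and ``fields`` objects are assumptions, and ``"epoch"`` is
one of the ``write_interval`` values accepted by ``BasePredictionWriter``:

    # Hedged usage sketch -- resource_monitor and fields are assumed to exist;
    # only the constructor signatures come from the classes above.
    import lightning.pytorch as pl

    from ptbench.engine.callbacks import LoggingCallback, PredictionsWriter

    trainer = pl.Trainer(
        callbacks=[
            LoggingCallback(resource_monitor),
            PredictionsWriter("predictions.csv", fields, write_interval="epoch"),
        ]
    )
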
diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py
index 33dc18be..dc037af0 100644
--- a/src/ptbench/engine/predictor.py
+++ b/src/ptbench/engine/predictor.py
@@ -14,35 +14,35 @@ logger = logging.getLogger(__name__)
 
 
 def run(model, data_loader, name, accelerator, output_folder, grad_cams=False):
-    """Runs inference on input data, outputs HDF5 files with predictions.
+    """Runs inference on input data, outputs csv files with predictions.
 
     Parameters
     ----------
     model : :py:class:`torch.nn.Module`
-        neural network model (e.g. pasa)
+        Neural network model (e.g. pasa).
 
     data_loader : :py:class:`torch.utils.data.DataLoader`
+        The PyTorch DataLoader used to iterate over batches.
 
     name : str
-        the local name of this dataset (e.g. ``train``, or ``test``), to be
+        The local name of this dataset (e.g. ``train``, or ``test``), to be
         used when saving measures files.
 
     accelerator : str
-        accelerator to use
+        A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (e.g. "gpu:0").
 
     output_folder : str
-        folder where to store output prediction and model
-        summary
+        Directory in which the results will be saved.
 
     grad_cams : bool
-        if we export grad cams for every prediction (must be used along
-        a batch size of 1 with the DensenetRS model)
+        Whether to export grad cams for every prediction (must be used with
+        a batch size of 1 and the DensenetRS model).
 
     Returns
     -------
 
     all_predictions : list
-        All the predictions associated with filename and groundtruth
+        All the predictions associated with filename and ground truth.
     """
     output_folder = os.path.join(output_folder, name)
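
To make the documented signature concrete, a hedged call sketch; building
``model`` and ``data_loader`` is assumed and not part of this patch:

    from ptbench.engine.predictor import run

    # Outputs land under <output_folder>/<name>, per the join above.
    all_predictions = run(
        model=model,              # torch.nn.Module, e.g. the pasa model
        data_loader=data_loader,  # torch.utils.data.DataLoader for this split
        name="test",              # split name, used to name output files
        accelerator="gpu:0",      # "cpu" or "gpu", optionally with a device
        output_folder="results",
        grad_cams=False,          # needs batch size 1 and the DensenetRS model
    )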
 
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 76aaaf76..a85a3da5 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -170,10 +170,10 @@ def run(
     ----------
 
     model : :py:class:`torch.nn.Module`
-        Network (e.g. driu, hed, unet)
+        Neural network model (e.g. pasa).
 
     data_loader : :py:class:`torch.utils.data.DataLoader`
-        To be used to train the model
+        The PyTorch DataLoader used to iterate over batches.
 
     valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
         To be used to validate the model and enable automatic checkpointing.
@@ -186,20 +186,20 @@ def run(
         the final training log.
 
     checkpoint_period : int
-        save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
-        not save intermediary checkpoints
+        Save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
+        not save intermediate checkpoints.
 
     accelerator : str
-        accelerator to use
+        A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (e.g. "gpu:0").
 
     arguments : dict
-        start and end epochs
+        Start and end epochs.
 
     output_folder : str
-        output path
+        Directory in which the results will be saved.
 
     monitoring_interval : int, float
-        interval, in seconds (or fractions), through which we should monitor
+        Interval, in seconds (or fractions of a second), at which we should monitor
         resources during training.
 
     batch_chunk_count: int
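
A hedged sketch of a training call using only the parameters documented in this
hunk; the real signature may take further arguments, and the key names inside
``arguments`` are an assumption based on the "start and end epochs" description:

    from ptbench.engine.trainer import run

    run(
        model=model,
        data_loader=train_loader,
        valid_loaders=[valid_loader],
        checkpoint_period=1,      # 0 disables intermediate checkpoints
        accelerator="gpu:0",
        arguments={"epoch": 0, "max_epoch": 100},  # key names assumed
        output_folder="results",
        monitoring_interval=5.0,  # resource sampling period, in seconds
        batch_chunk_count=1,
    )
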
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 689bca1b..3613e75a 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -64,7 +64,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--accelerator",
     "-a",
-    help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available',
+    help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (e.g. "gpu:0").',
     show_default=True,
     required=True,
     default="cpu",
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 50705299..12c5a287 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -178,7 +178,7 @@ def set_reproducible_cuda():
 @click.option(
     "--accelerator",
     "-a",
-    help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available',
+    help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (e.g. "gpu:0").',
     show_default=True,
     required=True,
     default="cpu",
@@ -235,7 +235,7 @@ def set_reproducible_cuda():
 )
 @click.option(
     "--resume-from",
-    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a ckpt file.",
+    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a  model checkpoint.",
     type=str,
     required=False,
     default=None,
diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py
index a30cba94..81516d28 100644
--- a/src/ptbench/utils/checkpointer.py
+++ b/src/ptbench/utils/checkpointer.py
@@ -5,6 +5,26 @@ logger = logging.getLogger(__name__)
 
 
 def get_checkpoint(output_folder, resume_from):
+    """Gets a checkpoint file.
+
+    Can return the best or last checkpoint, or a checkpoint at a specific path.
+    Ensures the checkpoint exists, raising an error if it does not.
+
+    Parameters
+    ----------
+
+    output_folder : str
+        Directory in which checkpoints are stored.
+
+    resume_from : str
+        Which checkpoint to get. Can be one of "best", "last", or a path to a checkpoint.
+
+    Returns
+    -------
+
+    checkpoint_file : str
+        Path to the requested checkpoint file.
+    """
     last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
     best_checkpoint_path = os.path.join(
         output_folder, "model_lowest_valid_loss.ckpt"
-- 
GitLab