From f3cc3c556bb26401aedfd5cbe9a97c673fd0d887 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 10 May 2023 12:01:08 +0200
Subject: [PATCH] Updated docstrings

---
 src/ptbench/engine/callbacks.py   |  5 +++++
 src/ptbench/engine/predictor.py   | 18 +++++++++---------
 src/ptbench/engine/trainer.py     | 16 ++++++++--------
 src/ptbench/scripts/predict.py    |  2 +-
 src/ptbench/scripts/train.py      |  4 ++--
 src/ptbench/utils/checkpointer.py | 20 ++++++++++++++++++++
 6 files changed, 45 insertions(+), 20 deletions(-)

diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index d6e35e84..b266ae62 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -10,6 +10,9 @@ from lightning.pytorch.callbacks import BasePredictionWriter
 # This ensures CSVLogger logs training and evaluation metrics on the same line
 # CSVLogger only accepts numerical values, not strings
 class LoggingCallback(Callback):
+    """Lightning callback to log various training metrics and device
+    information."""
+
     def __init__(self, resource_monitor):
         super().__init__()
         self.training_loss = []
@@ -79,6 +82,8 @@ class LoggingCallback(Callback):
 
 
 class PredictionsWriter(BasePredictionWriter):
+    """Lightning callback to write predictions to a file."""
+
     def __init__(self, logfile_name, logfile_fields, write_interval):
         super().__init__(write_interval)
         self.logfile_name = logfile_name
diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py
index 33dc18be..dc037af0 100644
--- a/src/ptbench/engine/predictor.py
+++ b/src/ptbench/engine/predictor.py
@@ -14,35 +14,35 @@ logger = logging.getLogger(__name__)
 
 def run(model, data_loader, name, accelerator, output_folder, grad_cams=False):
-    """Runs inference on input data, outputs HDF5 files with predictions.
+    """Runs inference on input data, outputs CSV files with predictions.
 
     Parameters
     ---------
 
     model : :py:class:`torch.nn.Module`
-        neural network model (e.g. pasa)
+        Neural network model (e.g. pasa).
 
     data_loader : py:class:`torch.torch.utils.data.DataLoader`
+        The PyTorch DataLoader used to iterate over batches.
 
     name : str
-        the local name of this dataset (e.g. ``train``, or ``test``), to be
+        The local name of this dataset (e.g. ``train``, or ``test``), to be
         used when saving measures files.
 
     accelerator : str
-        accelerator to use
+        A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0).
 
     output_folder : str
-        folder where to store output prediction and model
-        summary
+        Directory in which the results will be saved.
 
     grad_cams : bool
-        if we export grad cams for every prediction (must be used along
-        a batch size of 1 with the DensenetRS model)
+        If we export grad cams for every prediction (must be used with
+        a batch size of 1 and the DensenetRS model).
 
     Returns
     -------
 
     all_predictions : list
-        All the predictions associated with filename and groundtruth
+        All the predictions associated with filename and ground truth.
     """
 
     output_folder = os.path.join(output_folder, name)
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 76aaaf76..a85a3da5 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -170,10 +170,10 @@ def run(
     ----------
 
     model : :py:class:`torch.nn.Module`
-        Network (e.g. driu, hed, unet)
+        Neural network model (e.g. pasa).
 
     data_loader : :py:class:`torch.utils.data.DataLoader`
-        To be used to train the model
+        The PyTorch DataLoader used to iterate over batches.
 
     valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
         To be used to validate the model and enable automatic checkpointing.
@@ -186,20 +186,20 @@ def run(
         the final training log.
 
     checkpoint_period : int
-        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
-        not save intermediary checkpoints
+        Save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
+        not save intermediary checkpoints.
 
     accelerator : str
-        accelerator to use
+        A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0).
 
     arguments : dict
-        start and end epochs
+        Start and end epochs.
 
     output_folder : str
-        output path
+        Directory in which the results will be saved.
 
     monitoring_interval : int, float
-        interval, in seconds (or fractions), through which we should monitor
+        Interval, in seconds (or fractions), through which we should monitor
         resources during training.
 
     batch_chunk_count: int
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 689bca1b..3613e75a 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -64,7 +64,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--accelerator",
     "-a",
-    help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available',
+    help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
     show_default=True,
     required=True,
     default="cpu",
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 50705299..12c5a287 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -178,7 +178,7 @@ def set_reproducible_cuda():
 @click.option(
     "--accelerator",
     "-a",
-    help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available',
+    help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
     show_default=True,
     required=True,
     default="cpu",
@@ -235,7 +235,7 @@ def set_reproducible_cuda():
 )
 @click.option(
     "--resume-from",
-    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a ckpt file.",
+    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a model checkpoint.",
     type=str,
     required=False,
     default=None,
diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py
index a30cba94..81516d28 100644
--- a/src/ptbench/utils/checkpointer.py
+++ b/src/ptbench/utils/checkpointer.py
@@ -5,6 +5,26 @@ logger = logging.getLogger(__name__)
 
 
 def get_checkpoint(output_folder, resume_from):
+    """Gets a checkpoint file.
+
+    Can return the best or last checkpoint, or a checkpoint at a specific path.
+    Ensures the checkpoint exists, raising an error if it does not.
+
+    Parameters
+    ----------
+
+    output_folder : :py:class:`str`
+        Directory in which checkpoints are stored.
+
+    resume_from : :py:class:`str`
+        Which checkpoint to load. Can be one of "best", "last", or a path to a checkpoint.
+
+    Returns
+    -------
+
+    checkpoint_file : :py:class:`str`
+        Path to the requested checkpoint file.
+    """
     last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
     best_checkpoint_path = os.path.join(
         output_folder, "model_lowest_valid_loss.ckpt"
-- 
GitLab
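
For illustration only (not part of the patch): a minimal sketch of how the
get_checkpoint() helper documented above might be called. The function and
checkpoint filenames come from the hunk; the "results" output folder is a
hypothetical example value.

    from ptbench.utils.checkpointer import get_checkpoint

    output_folder = "results"  # hypothetical training output directory

    # "best" resolves to the checkpoint with the lowest validation loss
    # (model_lowest_valid_loss.ckpt inside output_folder).
    best_ckpt = get_checkpoint(output_folder, "best")

    # "last" resolves to the final-epoch checkpoint (model_final_epoch.ckpt).
    last_ckpt = get_checkpoint(output_folder, "last")

    # An explicit path may also be given; per the docstring, an error is
    # raised if the requested checkpoint does not exist.
    ckpt = get_checkpoint(output_folder, "results/model_final_epoch.ckpt")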