diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index d6e35e84c14f59d400bd52d83b3d7b03cf9fb645..b266ae6221cf9a925ff941f1c99bdfdd044fa23f 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -10,6 +10,9 @@ from lightning.pytorch.callbacks import BasePredictionWriter
 # This ensures CSVLogger logs training and evaluation metrics on the same line
 # CSVLogger only accepts numerical values, not strings
 class LoggingCallback(Callback):
+    """Lightning callback to log various training metrics and device
+    information."""
+
     def __init__(self, resource_monitor):
         super().__init__()
         self.training_loss = []
@@ -79,6 +82,8 @@ class LoggingCallback(Callback):
 
 
 class PredictionsWriter(BasePredictionWriter):
+    """Lightning callback to write predictions to a file."""
+
     def __init__(self, logfile_name, logfile_fields, write_interval):
         super().__init__(write_interval)
         self.logfile_name = logfile_name
diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py
index 33dc18bea49097bfd3935b6b08141149a39ec392..dc037af0679d3493a718a26dc393a61d00659a27 100644
--- a/src/ptbench/engine/predictor.py
+++ b/src/ptbench/engine/predictor.py
@@ -14,35 +14,35 @@ logger = logging.getLogger(__name__)
 
 
 def run(model, data_loader, name, accelerator, output_folder, grad_cams=False):
-    """Runs inference on input data, outputs HDF5 files with predictions.
+    """Runs inference on input data, outputs csv files with predictions.
 
     Parameters
     ---------
     model : :py:class:`torch.nn.Module`
-        neural network model (e.g. pasa)
+        Neural network model (e.g. pasa).
 
     data_loader : py:class:`torch.torch.utils.data.DataLoader`
+        The pytorch Dataloader used to iterate over batches.
 
     name : str
-        the local name of this dataset (e.g. ``train``, or ``test``), to be
+        The local name of this dataset (e.g. ``train``, or ``test``), to be
         used when saving measures files.
 
     accelerator : str
-        accelerator to use
+        A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)
 
     output_folder : str
-        folder where to store output prediction and model
-        summary
+        Directory in which the results will be saved.
 
     grad_cams : bool
-        if we export grad cams for every prediction (must be used along
-        a batch size of 1 with the DensenetRS model)
+        If we export grad cams for every prediction (must be used along
+        a batch size of 1 with the DensenetRS model).
 
     Returns
     -------
 
     all_predictions : list
-        All the predictions associated with filename and groundtruth
+        All the predictions associated with filename and ground truth.
     """
 
     output_folder = os.path.join(output_folder, name)
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 76aaaf7626fe729174c0bf9980141da393c87fa7..a85a3da566691922323fbeb8a56d3472def48389 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -170,10 +170,10 @@
     ----------
 
     model : :py:class:`torch.nn.Module`
-        Network (e.g. driu, hed, unet)
+        Neural network model (e.g. pasa).
 
     data_loader : :py:class:`torch.utils.data.DataLoader`
-        To be used to train the model
+        The pytorch Dataloader used to iterate over batches.
 
     valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
         To be used to validate the model and enable automatic checkpointing.
@@ -186,20 +186,20 @@
         the final training log.
 
     checkpoint_period : int
-        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
-        not save intermediary checkpoints
+        Save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
+        not save intermediary checkpoints.
 
     accelerator : str
-        accelerator to use
+        A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)
 
     arguments : dict
-        start and end epochs
+        Start and end epochs:
 
     output_folder : str
-        output path
+        Directory in which the results will be saved.
 
     monitoring_interval : int, float
-        interval, in seconds (or fractions), through which we should monitor
+        Interval, in seconds (or fractions), through which we should monitor
         resources during training.
 
     batch_chunk_count: int
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 689bca1bc7a201b9b5855716f6d19120e4630437..3613e75ad03c7f64ce9ce2c81e013d251fdf2858 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -64,7 +64,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--accelerator",
     "-a",
-    help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available',
+    help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
     show_default=True,
     required=True,
     default="cpu",
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 507052997760a8b8d114aa6613ace087e37efe7c..12c5a287f5682340ecb1275f4439ebd4056c657b 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -178,7 +178,7 @@ def set_reproducible_cuda():
 @click.option(
     "--accelerator",
     "-a",
-    help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available',
+    help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
     show_default=True,
     required=True,
     default="cpu",
@@ -235,7 +235,7 @@ def set_reproducible_cuda():
 )
 @click.option(
     "--resume-from",
-    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a ckpt file.",
+    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a model checkpoint.",
     type=str,
     required=False,
     default=None,
diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py
index a30cba9444504a2abb350f4f7d9ca513bde9dde3..81516d28110871690ab148305fb9f5925a728601 100644
--- a/src/ptbench/utils/checkpointer.py
+++ b/src/ptbench/utils/checkpointer.py
@@ -5,6 +5,26 @@ logger = logging.getLogger(__name__)
 
 
 def get_checkpoint(output_folder, resume_from):
+    """Gets a checkpoint file.
+
+    Can return the best or last checkpoint, or a checkpoint at a specific path.
+    Ensures the checkpoint exists, raising an error if it is not the case.
+
+    Parameters
+    ----------
+
+    output_folder : :py:class:`str`
+        Directory in which checkpoints are stored.
+
+    resume_from : :py:class:`str`
+        Which model to get. Can be one of "best", "last", or a path to a checkpoint.
+
+    Returns
+    -------
+
+    checkpoint_file : :py:class:`str`
+        The requested model.
+    """
     last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
     best_checkpoint_path = os.path.join(
         output_folder, "model_lowest_valid_loss.ckpt"
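
Note: as a rough, illustrative sketch of the behaviour the new ``get_checkpoint`` docstring describes (not the actual ptbench implementation), a resolver along these lines would map the documented ``--resume-from`` values ("best", "last", or an explicit path) to the checkpoint files named in the patch and fail early when the file is missing. The function name ``resolve_checkpoint`` and the usage paths are hypothetical.

import os


def resolve_checkpoint(output_folder, resume_from):
    # Illustrative sketch only: map "last"/"best"/an explicit path to a
    # checkpoint file, mirroring the semantics documented in the patch.
    last_path = os.path.join(output_folder, "model_final_epoch.ckpt")
    best_path = os.path.join(output_folder, "model_lowest_valid_loss.ckpt")

    if resume_from == "last":
        checkpoint_file = last_path
    elif resume_from == "best":
        checkpoint_file = best_path
    else:
        checkpoint_file = resume_from  # assumed to be a user-supplied path (or None)

    if checkpoint_file is not None and not os.path.isfile(checkpoint_file):
        raise FileNotFoundError(
            f"Could not find checkpoint to resume from: {checkpoint_file}"
        )
    return checkpoint_file


# e.g. resume from the lowest-validation-loss model stored under a
# hypothetical "results/pasa" output folder:
# ckpt = resolve_checkpoint("results/pasa", "best")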