From ec656ca45e3dd6ee54face7d1488ac8a19852035 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Fri, 26 Jan 2024 16:35:39 +0100
Subject: [PATCH] [doc] Reformatted docstrings to proper numpy-style

---
 src/mednet/config/data/hivtb/datamodule.py    |  8 +--
 .../config/data/montgomery/datamodule.py      |  7 +-
 .../config/data/nih_cxr14/datamodule.py       |  8 +--
 src/mednet/config/data/padchest/datamodule.py |  8 +--
 src/mednet/config/data/shenzhen/datamodule.py |  8 +--
 src/mednet/config/data/tbpoc/datamodule.py    |  4 +-
 src/mednet/config/data/tbx11k/datamodule.py   | 15 ++--
 .../data/tbx11k/make_splits_from_database.py  | 22 ++----
 src/mednet/data/augmentations.py              | 29 +-------
 src/mednet/data/datamodule.py                 | 38 ++--------
 src/mednet/data/image_utils.py                |  4 +-
 src/mednet/data/split.py                      | 22 ++----
 src/mednet/engine/callbacks.py                | 72 ++++++-------------
 src/mednet/engine/device.py                   | 10 +--
 src/mednet/engine/evaluator.py                | 48 ++++++-------
 src/mednet/engine/loggers.py                  |  1 -
 src/mednet/engine/predictor.py                | 17 +++--
 src/mednet/engine/saliency/completeness.py    | 14 ++--
 src/mednet/engine/saliency/evaluator.py       | 23 +++---
 src/mednet/engine/saliency/generator.py       |  2 +-
 .../engine/saliency/interpretability.py       | 29 +++-----
 src/mednet/engine/saliency/viewer.py          | 21 +++---
 src/mednet/engine/trainer.py                  | 37 +++-------
 src/mednet/models/alexnet.py                  | 12 +---
 src/mednet/models/densenet.py                 | 12 +---
 src/mednet/models/logistic_regression.py      |  6 +-
 src/mednet/models/loss_weights.py             | 14 ++--
 src/mednet/models/mlp.py                      |  7 +-
 src/mednet/models/normalizer.py               | 10 +--
 src/mednet/models/pasa.py                     | 14 ++--
 src/mednet/models/separate.py                 |  7 +-
 src/mednet/models/transforms.py               | 10 +--
 src/mednet/scripts/click.py                   | 11 +--
 src/mednet/scripts/database.py                |  4 +-
 src/mednet/scripts/train.py                   |  2 -
 src/mednet/scripts/train_analysis.py          | 10 +--
 src/mednet/scripts/utils.py                   |  2 -
 src/mednet/utils/checkpointer.py              | 19 ++---
 src/mednet/utils/resources.py                 | 58 +++++----------
 src/mednet/utils/summary.py                   | 19 +++--
 src/mednet/utils/tensorboard.py               |  3 +-
 41 files changed, 227 insertions(+), 440 deletions(-)

diff --git a/src/mednet/config/data/hivtb/datamodule.py b/src/mednet/config/data/hivtb/datamodule.py
index d5bf8103..2185b14d 100644
--- a/src/mednet/config/data/hivtb/datamodule.py
+++ b/src/mednet/config/data/hivtb/datamodule.py
@@ -49,10 +49,10 @@ class RawDataLoader(_BaseRawDataLoader):
             where to find the image to be loaded, and an integer, representing the
             sample label.
 
-
         Returns
         -------
-            The sample representation
+        Sample
+            The sample representation.
         """
         image = PIL.Image.open(os.path.join(self.datadir, sample[0])).convert(
             "L"
@@ -78,10 +78,10 @@ class RawDataLoader(_BaseRawDataLoader):
             where to find the image to be loaded, and an integer, representing the
             sample label.
 
-
         Returns
         -------
-            The integer label associated with the sample
+        int
+            The integer label associated with the sample.
         """
         return sample[1]
 
diff --git a/src/mednet/config/data/montgomery/datamodule.py b/src/mednet/config/data/montgomery/datamodule.py
index ec1b7c14..dbe43cc6 100644
--- a/src/mednet/config/data/montgomery/datamodule.py
+++ b/src/mednet/config/data/montgomery/datamodule.py
@@ -51,7 +51,8 @@ class RawDataLoader(_BaseRawDataLoader):
 
         Returns
         -------
-            The sample representation
+        Sample
+            The sample representation.
         """
         # N.B.: Montgomery images are encoded as grayscale PNGs, so no need to
         # convert them again with Image.convert("L").
@@ -77,10 +78,10 @@ class RawDataLoader(_BaseRawDataLoader):
             where to find the image to be loaded, and an integer, representing the
             sample label.
 
-
         Returns
         -------
-            The integer label associated with the sample
+        int
+            The integer label associated with the sample.
         """
         return sample[1]
 
diff --git a/src/mednet/config/data/nih_cxr14/datamodule.py b/src/mednet/config/data/nih_cxr14/datamodule.py
index b42f1d1c..4971814c 100644
--- a/src/mednet/config/data/nih_cxr14/datamodule.py
+++ b/src/mednet/config/data/nih_cxr14/datamodule.py
@@ -72,10 +72,10 @@ class RawDataLoader(_BaseRawDataLoader):
             where to find the image to be loaded, and an integer, representing the
             sample label.
 
-
         Returns
         -------
-            The sample representation
+        Sample
+            The sample representation.
         """
         file_path = sample[0]  # default
         if self.idiap_file_organisation:
@@ -112,10 +112,10 @@ class RawDataLoader(_BaseRawDataLoader):
             where to find the image to be loaded, and an integer, representing the
             sample label.
 
-
         Returns
         -------
-            The integer labels associated with the sample
+        list[int]
+            The integer labels associated with the sample.
         """
         return sample[1]
 
diff --git a/src/mednet/config/data/padchest/datamodule.py b/src/mednet/config/data/padchest/datamodule.py
index a065dece..3abf944d 100644
--- a/src/mednet/config/data/padchest/datamodule.py
+++ b/src/mednet/config/data/padchest/datamodule.py
@@ -50,10 +50,10 @@ class RawDataLoader(_BaseRawDataLoader):
             where to find the image to be loaded, and an integer, representing the
             sample label.
 
-
         Returns
         -------
-            The sample representation
+        Sample
+            The sample representation.
         """
         # N.B.: PadChest images are encoded as 16-bit grayscale images
         image = PIL.Image.open(os.path.join(self.datadir, sample[0]))
@@ -79,10 +79,10 @@ class RawDataLoader(_BaseRawDataLoader):
             where to find the image to be loaded, and an integer, representing the
             sample label.
 
-
         Returns
         -------
-            The integer labels associated with the sample
+        list[int]
+            The integer labels associated with the sample.
         """
         return sample[1]
 
diff --git a/src/mednet/config/data/shenzhen/datamodule.py b/src/mednet/config/data/shenzhen/datamodule.py
index 64091794..4a210cfe 100644
--- a/src/mednet/config/data/shenzhen/datamodule.py
+++ b/src/mednet/config/data/shenzhen/datamodule.py
@@ -50,10 +50,10 @@ class RawDataLoader(_BaseRawDataLoader):
             where to find the image to be loaded, and an integer, representing the
             sample label.
 
-
         Returns
         -------
-            The sample representation
+        Sample
+            The sample representation.
         """
         # N.B.: Image.convert("L") is required to normalize grayscale back to
         # normal (instead of inverted).
@@ -81,10 +81,10 @@ class RawDataLoader(_BaseRawDataLoader):
             where to find the image to be loaded, and an integer, representing the
             sample label.
 
-
         Returns
         -------
-            The integer label associated with the sample
+        int
+            The integer label associated with the sample.
         """
         return sample[1]
 
diff --git a/src/mednet/config/data/tbpoc/datamodule.py b/src/mednet/config/data/tbpoc/datamodule.py
index ffe59568..a7870e13 100644
--- a/src/mednet/config/data/tbpoc/datamodule.py
+++ b/src/mednet/config/data/tbpoc/datamodule.py
@@ -44,10 +44,10 @@ class RawDataLoader(_BaseRawDataLoader):
             where to find the image to be loaded, and an integer, representing the
             sample label.
 
-
         Returns
         -------
-            The sample representation
+        Sample
+            The sample representation.
         """
         # images from TBPOC are encoded as grayscale JPEGs, no need to
         # call convert("L") here.
diff --git a/src/mednet/config/data/tbx11k/datamodule.py b/src/mednet/config/data/tbx11k/datamodule.py
index ae71e46b..5baa60ff 100644
--- a/src/mednet/config/data/tbx11k/datamodule.py
+++ b/src/mednet/config/data/tbx11k/datamodule.py
@@ -52,6 +52,7 @@ class BoundingBox:
 
         Returns
         -------
+        int
             The area in square-pixels.
         """
         return self.width * self.height
@@ -75,14 +76,14 @@ class BoundingBox:
         * 2 pixels of index : i1 = 1 and i2 = 3, the segment from pixel i1 to
           i2 contains 3 pixels ie l = i2 - i1 + 1
 
-
         Parameters
         ----------
         other
-            The other bounding box to check intersections for
+            The other bounding box to check intersections for.
 
         Returns
         -------
+        int
             The area intersection between this and the other bounding-box in
             square pixels.
         """
@@ -158,10 +159,10 @@ class RawDataLoader(_BaseRawDataLoader):
             sample label, and possible radiological findings represented by
             bounding boxes.
 
-
         Returns
         -------
-            The sample representation
+        Sample
+            The sample representation.
         """
         image = PIL.Image.open(os.path.join(self.datadir, sample[0]))
         tensor = to_tensor(image)
@@ -188,10 +189,10 @@ class RawDataLoader(_BaseRawDataLoader):
             sample label, and possible radiological findings represented by
             bounding boxes.
 
-
         Returns
         -------
-            The integer label associated with the sample
+        int
+            The integer label associated with the sample.
         """
         return sample[1]
 
@@ -206,9 +207,9 @@ class RawDataLoader(_BaseRawDataLoader):
             sample label, and possible radiological findings represented by
             bounding boxes.
 
-
         Returns
         -------
+        BoundingBoxes
             Bounding box annotations, if any available with the sample.
         """
         if len(sample) > 2:
diff --git a/src/mednet/config/data/tbx11k/make_splits_from_database.py b/src/mednet/config/data/tbx11k/make_splits_from_database.py
index c432a297..90238b07 100644
--- a/src/mednet/config/data/tbx11k/make_splits_from_database.py
+++ b/src/mednet/config/data/tbx11k/make_splits_from_database.py
@@ -202,18 +202,14 @@ def create_v1_default_split(d: dict, seed: int, validation_size: float) -> dict:
     1. The original validation set becomes the test set.
     2. The original training set is split into new training and validation
        sets.  The selection of samples is stratified (respects class
-       proportions in Özgür's way - see comments)
-
+       proportions in Özgür's way - see comments).
 
     Parameters
     ----------
-
     d
-        The original dataset that will be split
-
+        The original dataset that will be split.
     seed
-        The seed to use at the relevant RNG
-
+        The seed to use at the relevant RNG.
     validation_size
         The proportion of data when we split the training set to make a
         train and validation sets.
@@ -295,19 +291,15 @@ def create_folds(
 
     Parameters
     ----------
-
     d
-        The original split to consider
-
+        The original split to consider.
     n
-        The number of folds to produce
-
+        The number of folds to produce.
 
     Returns
     -------
-
-    folds
-        All the ``n`` folds
+    list[dict]
+        All the ``n`` folds.
     """
 
     X = d["train"] + d["validation"] + d["test"]
diff --git a/src/mednet/data/augmentations.py b/src/mednet/data/augmentations.py
index 91eaa595..b76171c9 100644
--- a/src/mednet/data/augmentations.py
+++ b/src/mednet/data/augmentations.py
@@ -42,37 +42,28 @@ def _elastic_deformation_on_image(
     :py:func:`scipy.ndimage.map_coordinates`).  It is very inefficient since it
     requires data to be moved off the current running device and then back.
 
-
     Parameters
     ----------
-
     img
         The input image to apply elastic deformation to.  This image should
         always have this shape: ``[C, H, W]``. It should always represent a
         tensor on the CPU.
-
     alpha
-        A multiplier for the gaussian filter outputs
-
+        A multiplier for the gaussian filter outputs.
     sigma
         Standard deviation for Gaussian kernel.
-
     spline_order
         The order of the spline interpolation, default is 1. The order has to
         be in the range 0-5.
-
     mode
         The mode parameter determines how the input array is extended beyond
         its boundaries.
-
     p
         Probability that this transformation will be applied.  Meaningful when
         using it as a data augmentation technique.
 
-
     Returns
     -------
-
     tensor
         The image with elastic deformation applied, as a tensor on the CPU.
     """
@@ -150,37 +141,28 @@ def _elastic_deformation_on_batch(
     :py:func:`scipy.ndimage.map_coordinates`).  It is very inefficient since it
     requires data to be moved off the current running device and then back.
 
-
     Parameters
     ----------
-
     img
         The input image to apply elastic deformation to.  This image should
         always have this shape: ``[C, H, W]``. It should always represent a
         tensor on the CPU.
-
     alpha
-        A multiplier for the gaussian filter outputs
-
+        A multiplier for the gaussian filter outputs.
     sigma
         Standard deviation for Gaussian kernel.
-
     spline_order
         The order of the spline interpolation, default is 1. The order has to
         be in the range 0-5.
-
     mode
         The mode parameter determines how the input array is extended beyond
         its boundaries.
-
     p
         Probability that this transformation will be applied.  Meaningful when
         using it as a data augmentation technique.
 
-
     Returns
     -------
-
     tensor
         A batch of images with elastic deformation applied, as a tensor on the CPU.
     """
@@ -220,28 +202,21 @@ class ElasticDeformation:
 
     Source: https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0
 
-
     Parameters
     ----------
-
     alpha
         A multiplier for the gaussian filter outputs.
-
     sigma
         Standard deviation for Gaussian kernel.
-
     spline_order
         The order of the spline interpolation, default is 1. The order has to
         be in the range 0-5.
-
     mode
         The mode parameter determines how the input array is extended beyond
         its boundaries.
-
     p
         Probability that this transformation will be applied.  Meaningful when
         using it as a data augmentation technique.
-
     parallel
         Use multiprocessing for processing batches of data: if set to -1
         (default), disables multiprocessing.  If set to -2, then enable
diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py
index 5c8e3058..758f0ef0 100644
--- a/src/mednet/data/datamodule.py
+++ b/src/mednet/data/datamodule.py
@@ -37,12 +37,12 @@ def _sample_size_bytes(s: Sample) -> int:
     Parameters
     ----------
     s
-        The sample to be analyzed
-
+        The sample to be analyzed.
 
     Returns
     -------
-        The size in bytes occupied by this sample
+    int
+        The size in bytes occupied by this sample.
     """
 
     def _tensor_size_bytes(t: torch.Tensor) -> int:
@@ -68,16 +68,13 @@ class _DelayedLoadingDataset(Dataset):
     This list mimics a pytorch Dataset, except that raw data loading is done
     on-the-fly, as the samples are requested through the bracket operator.
 
-
     Parameters
     ----------
     raw_dataset
         An iterable containing the raw dataset samples representing one of the
         database split datasets.
-
     loader
         An object instance that can load samples and labels from storage.
-
     transforms
         A set of transforms that should be applied on-the-fly for this dataset,
         to fit the output of the raw-data-loader to the model of interest.
@@ -129,19 +126,17 @@ def _apply_loader_and_transforms(
     Parameters
     ----------
     info
-        The sample information, as loaded from its raw dataset dictionary
-
+        The sample information, as loaded from its raw dataset dictionary.
     load
-        The raw-data loader function to use for loading the sample
-
+        The raw-data loader function to use for loading the sample.
     model_transform
         A callable that will transform the loaded tensor into something
         suitable for the model it will train.  Typically, this will be a
         composed transform.
 
-
     Returns
     -------
+    Sample
         The loaded and transformed sample.
     """
     sample = load(info)
@@ -155,22 +150,18 @@ class _CachedDataset(Dataset):
     instead of delaying that to the indexing.  Beyond raw-data-loading,
     ``transforms`` given upon construction contribute to the cached samples.
 
-
     Parameters
     ----------
     raw_dataset
         An iterable containing the raw dataset samples representing one of the
         database split datasets.
-
     loader
         An object instance that can load samples and labels from storage.
-
     parallel
         Use multiprocessing for data loading: if set to -1 (default), disables
         multiprocessing data loading.  Set to 0 to enable as many data loading
         instances as processing cores available in the system.  Set to >= 1
         to enable that many multiprocessing instances for data loading.
-
     transforms
         A set of transforms that should be applied to the cached samples for
         this dataset, to fit the output of the raw-data-loader to the model of
@@ -320,24 +311,20 @@ def _make_balanced_random_sampler(
     We then instantiate a pytorch sampler using the inverse probabilities (the
     more samples in a class, the less likely it becomes to be sampled.
 
-
     Parameters
     ----------
     dataset
         An instance of torch Dataset.
         :py:class:`torch.utils.data.ConcatDataset` are supported.
-
     target
         The name of a metadata key pointing to an integer property that allows
         balancing the dataset.
 
-
     Returns
     -------
         A sampler, to be used in a dataloader equipped with the same dataset
         used to calculate the relative sample weights.
 
-
     Raises
     ------
     RuntimeError
@@ -418,7 +405,6 @@ class ConcatDataModule(lightning.LightningDataModule):
     prediction and testing conditions.  Parallelisation is handled by a simple
     input flag.
 
-
     Parameters
     ----------
     splits
@@ -443,7 +429,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         Entries named ``monitor-...`` will be considered extra datasets that do
         not influence any early stop criteria during training, and are just
         monitored beyond the ``validation`` dataset.
-
     cache_samples
         If set, then issue raw data loading during ``prepare_data()``, and
         serves samples from CPU memory.  Otherwise, loads samples from disk on
@@ -451,12 +436,10 @@ class ConcatDataModule(lightning.LightningDataModule):
         for CPU memory.  Sufficient CPU memory must be available before you set
         this attribute to ``True``.  It is typically useful for relatively small
         datasets.
-
     balance_sampler_by_class
         If set, then modifies the random sampler used during training and
         validation to balance sample picking probability, making sample
         across classes **and** datasets equitable.
-
     batch_size
         Number of samples in every **training** batch (this parameter affects
         memory requirements for the network).  If the number of samples in the
@@ -468,7 +451,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         the last batch will be smaller than the first, unless
         ``drop_incomplete_batch`` is set to ``true``, in which case this batch
         is not used.
-
     batch_chunk_count
         Number of chunks in every batch (this parameter affects memory
         requirements for the network). The number of samples loaded for every
@@ -481,13 +463,11 @@ class ConcatDataModule(lightning.LightningDataModule):
         whole batch to be processed at once. Otherwise the batch is broken into
         batch-chunk-count pieces, and gradients are accumulated to complete
         each batch.
-
     drop_incomplete_batch
         If set, then may drop the last batch in an epoch in case it is
         incomplete.  If you set this option, you should also consider
         increasing the total number of training epochs, as the total number
         of training steps may be reduced.
-
     parallel
         Use multiprocessing for data loading: if set to -1 (default), disables
         multiprocessing data loading.  Set to 0 to enable as many data loading
@@ -670,7 +650,6 @@ class ConcatDataModule(lightning.LightningDataModule):
             the last batch will be smaller than the first, unless
             ``drop_incomplete_batch`` is set to ``true``, in which case this batch
             is not used.
-
         batch_chunk_count
             Number of chunks in every batch (this parameter affects memory
             requirements for the network). The number of samples loaded for every
@@ -770,7 +749,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         If you have set ``cache_samples``, samples are loaded at this stage and
         cached in memory.
 
-
         Parameters
         ----------
         stage
@@ -808,7 +786,6 @@ class ConcatDataModule(lightning.LightningDataModule):
         If you have set ``cache_samples``, samples are loaded and this may
         effectivley release all the associated memory.
 
-
         Parameters
         ----------
         stage
@@ -901,7 +878,6 @@ class CachingDataModule(ConcatDataModule):
     Apart from construction, the behaviour of this data module is very similar
     to its simpler counterpart, serving training, validation and test sets.
 
-
     Parameters
     ----------
     database_split
@@ -923,10 +899,8 @@ class CachingDataModule(ConcatDataModule):
         Entries named ``monitor-...`` will be considered extra datasets that do
         not influence any early stop criteria during training, and are just
         monitored beyond the ``validation`` dataset.
-
     raw_data_loader
         An object instance that can load samples and labels from storage.
-
     **kwargs
         List of named parameters matching those of
         :py:class:`ConcatDataModule`, other than ``splits``.
diff --git a/src/mednet/data/image_utils.py b/src/mednet/data/image_utils.py
index 3dce9cd2..c0e0ff6a 100644
--- a/src/mednet/data/image_utils.py
+++ b/src/mednet/data/image_utils.py
@@ -15,14 +15,14 @@ def remove_black_borders(
     Parameters
     ----------
         img
-            A PIL image
+            A PIL image.
         threshold
             Threshold value from which borders are considered black.
             Defaults to 0.
 
     Returns
     -------
-        A PIL image with black borders removed
+        A PIL image with black borders removed.
     """
 
     img_array = numpy.asarray(img)
diff --git a/src/mednet/data/split.py b/src/mednet/data/split.py
index ddae2a1c..634c1b65 100644
--- a/src/mednet/data/split.py
+++ b/src/mednet/data/split.py
@@ -59,10 +59,8 @@ class JSONDatabaseSplit(DatabaseSplit):
     actual JSON file descriptors are loaded on demand using
     a py:func:`functools.cached_property`.
 
-
     Parameters
     ----------
-
     path
         Absolute path to a JSON formatted file containing the database split to be
         recognized by this object.
@@ -80,12 +78,10 @@ class JSONDatabaseSplit(DatabaseSplit):
         The first call to this (cached) property will trigger full JSON file
         loading from disk.  Subsequent calls will be cached.
 
-
         Returns
         -------
-
-        datasets : dict
-            A dictionary mapping dataset names to lists of JSON objects
+        DatabaseSplit
+            A dictionary mapping dataset names to lists of JSON objects.
         """
 
         if str(self._path).endswith(".bz2"):
@@ -134,10 +130,8 @@ class CSVDatabaseSplit(DatabaseSplit):
     Objects of this class behave like a dictionary in which keys are dataset
     names in the split, and values represent samples data and meta-data.
 
-
     Parameters
     ----------
-
     directory
         Absolute path to a directory containing the database split organized
         as a set of CSV files, one per dataset.
@@ -160,10 +154,8 @@ class CSVDatabaseSplit(DatabaseSplit):
         The first call to this (cached) property will trigger all CSV file
         loading from disk.  Subsequent calls will be cached.
 
-
         Returns
         -------
-
         datasets : dict
             A dictionary mapping dataset names to lists of JSON objects
         """
@@ -211,29 +203,23 @@ def check_database_split_loading(
     This function will return the number of errors when loading samples, and will
     log more detailed information to the logging stream.
 
-
     Parameters
     ----------
-
     database_split
         A mapping that contains the database split.  Each key represents the
         name of a dataset in the split.  Each value is a (potentially complex)
         object that represents a single sample.
-
     loader
         A loader object that knows how to handle full-samples or just labels.
-
     limit
         Maximum number of samples to check (in each split/dataset
         combination) in this dataset.  If set to zero, then check
         everything.
 
-
     Returns
     -------
-
-    errors
-        Number of errors found
+    int
+        Number of errors found.
     """
     logger.info(
         "Checking if all samples in all datasets of this split can be loaded..."
diff --git a/src/mednet/engine/callbacks.py b/src/mednet/engine/callbacks.py
index cba08495..90912a8c 100644
--- a/src/mednet/engine/callbacks.py
+++ b/src/mednet/engine/callbacks.py
@@ -21,14 +21,12 @@ class LoggingCallback(lightning.pytorch.Callback):
     Rationale:
 
     1. Losses are logged at the end of every batch, accumulated and handled by
-       the lightning framework
+       the lightning framework.
     2. Everything else is done at the end of a training or validation epoch and
        mostly concerns runtime metrics such as memory and cpu/gpu utilisation.
 
-
     Parameters
     ----------
-
     train_resource_monitor
         A monitor that watches resource usage (CPU/GPU) in a separate process
         and totally asynchronously with the code execution.
@@ -67,15 +65,12 @@ class LoggingCallback(lightning.pytorch.Callback):
 
         This method is executed whenever you *start* training a module.
 
-
         Parameters
         ---------
-
         trainer
-            The Lightning trainer object
-
+            The Lightning trainer object.
         pl_module
-            The lightning module that is being trained
+            The lightning module that is being trained.
         """
         self._start_training_time = time.time()
 
@@ -96,15 +91,12 @@ class LoggingCallback(lightning.pytorch.Callback):
            This is executed **while** you are training.  Be very succint or
            face the consequences of slow training!
 
-
         Parameters
         ---------
-
         trainer
-            The Lightning trainer object
-
+            The Lightning trainer object.
         pl_module
-            The lightning module that is being trained
+            The lightning module that is being trained.
         """
         # summarizes resource usage since the last checkpoint
         # clears internal buffers and starts accumulating again.
@@ -123,15 +115,12 @@ class LoggingCallback(lightning.pytorch.Callback):
         epochs happen as often as possible.  You want to make this code
         relatively fast to avoid significative runtime slow-downs.
 
-
         Parameters
         ----------
-
         trainer
-            The Lightning trainer object
-
+            The Lightning trainer object.
         pl_module
-            The lightning module that is being trained
+            The lightning module that is being trained.
         """
 
         # evaluates this training epoch total time, and log it
@@ -189,24 +178,18 @@ class LoggingCallback(lightning.pytorch.Callback):
            This is executed **while** you are training.  Be very succint or
            face the consequences of slow training!
 
-
         Parameters
         ----------
-
         trainer
-            The Lightning trainer object
-
+            The Lightning trainer object.
         pl_module
-            The lightning module that is being trained
-
+            The lightning module that is being trained.
         outputs
-            The outputs of the module's ``training_step``
-
+            The outputs of the module's ``training_step``.
         batch
-            The data that the training step received
-
+            The data that the training step received.
         batch_idx
-            The relative number of the batch
+            The relative number of the batch.
         """
         pl_module.log(
             "loss/train",
@@ -234,15 +217,12 @@ class LoggingCallback(lightning.pytorch.Callback):
            This is executed **while** you are training.  Be very succint or
            face the consequences of slow training!
 
-
         Parameters
         ---------
-
         trainer
-            The Lightning trainer object
-
+            The Lightning trainer object.
         pl_module
-            The lightning module that is being trained
+            The lightning module that is being trained.
         """
         # required because the validation epoch is started **within** the
         # training epoch START/END.
@@ -265,15 +245,12 @@ class LoggingCallback(lightning.pytorch.Callback):
         epochs happen as often as possible.  You want to make this code
         relatively fast to avoid significative runtime slow-downs.
 
-
         Parameters
         ----------
-
         trainer
-            The Lightning trainer object
-
+            The Lightning trainer object.
         pl_module
-            The lightning module that is being trained
+            The lightning module that is being trained.
         """
 
         # summarizes resource usage since the last checkpoint
@@ -322,25 +299,18 @@ class LoggingCallback(lightning.pytorch.Callback):
            This is executed **while** you are training.  Be very succint or
            face the consequences of slow training!
 
-
         Parameters
         ----------
-
         trainer
-            The Lightning trainer object
-
+            The Lightning trainer object.
         pl_module
-            The lightning module that is being trained
-
+            The lightning module that is being trained.
         outputs
-            The outputs of the module's ``training_step``
-
+            The outputs of the module's ``training_step``.
         batch
-            The data that the training step received
-
+            The data that the training step received.
         batch_idx
-            The relative number of the batch
-
+            The relative number of the batch.
         dataloader_idx
             Index of the dataloader used during validation.  Use this to figure
             out which dataset was used for this validation epoch.
diff --git a/src/mednet/engine/device.py b/src/mednet/engine/device.py
index d11aacff..8d5574e7 100644
--- a/src/mednet/engine/device.py
+++ b/src/mednet/engine/device.py
@@ -38,10 +38,8 @@ class DeviceManager:
     Instances of this class also manage the environment variable
     ``$CUDA_VISIBLE_DEVICES`` if necessary.
 
-
     Parameters
     ----------
-
     name
         The name of the device to use, in the form of a string defined by
         ``[\\S+][:\\d[,\\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``).  In
@@ -122,10 +120,8 @@ class DeviceManager:
            device.  This may impact Nvidia GPU logging in the case multiple
            GPU cards are used.
 
-
         Returns
         -------
-
         device
             The **first** torch device (if a list of ids is set).
         """
@@ -148,12 +144,10 @@ class DeviceManager:
 
         Returns
         -------
-
         accelerator
-            The lightning accelerator to use
-
+            The lightning accelerator to use.
         devices
-            The lightning devices to use
+            The lightning devices to use.
         """
 
         devices: int | list[int] | str = self.device_ids
diff --git a/src/mednet/engine/evaluator.py b/src/mednet/engine/evaluator.py
index 1add3edb..51f5ed55 100644
--- a/src/mednet/engine/evaluator.py
+++ b/src/mednet/engine/evaluator.py
@@ -32,9 +32,9 @@ def eer_threshold(predictions: Iterable[BinaryPrediction]) -> float:
         An iterable of multiple
         :py:data:`.models.typing.BinaryPrediction`'s.
 
-
     Returns
     -------
+    float
         The EER threshold value.
     """
     from scipy.interpolate import interp1d
@@ -51,19 +51,20 @@ def eer_threshold(predictions: Iterable[BinaryPrediction]) -> float:
 
 def _get_centered_maxf1(
     f1_scores: numpy.typing.NDArray, thresholds: numpy.typing.NDArray
-):
+) -> tuple[float, float]:
     """Return the centered max F1 score threshold when multiple thresholds give
     the same max F1 score.
 
     Parameters
     ----------
     f1_scores
-        1D array of f1 scores
+        1D array of f1 scores.
     thresholds
-        1D array of thresholds
+        1D array of thresholds.
 
     Returns
     -------
+    tuple(float, float)
         A tuple with the maximum F1-score and the "centered" threshold.
     """
     maxf1 = f1_scores.max()
@@ -88,9 +89,9 @@ def maxf1_threshold(predictions: Iterable[BinaryPrediction]) -> float:
         An iterable of multiple
         :py:data:`.models.typing.BinaryPrediction`'s.
 
-
     Returns
     -------
+    float
         The threshold value leading to the maximum F1-score on the provided set
         of predictions.
     """
@@ -122,17 +123,17 @@ def _score_plot(
     Parameters
     ----------
     labels
-        True labels (negatives and positives) for each entry in ``scores``
+        True labels (negatives and positives) for each entry in ``scores``.
     scores
-        Likelihoods provided by the classification model, for each sample
+        Likelihoods provided by the classification model, for each sample.
     title
         Title of the plot.
     threshold
-        Shows where the threshold is in the figure
-
+        Shows where the threshold is in the figure.
 
     Returns
     -------
+    matplotlib.figure.Figure
         A single (matplotlib) plot containing the score distribution, ready to
         be saved to disk or displayed.
     """
@@ -199,20 +200,23 @@ def run_binary(
     name
         The name of subset to load.
     predictions
-        A list of predictions to consider for measurement
+        A list of predictions to consider for measurement.
     threshold_a_priori
         A threshold to use, evaluated *a priori*, if must report single values.
         If this value is not provided, an *a posteriori* threshold is calculated
         on the input scores.  This is a biased estimator.
 
-
     Returns
     -------
+    tuple[
+    dict[str, typing.Any],
+    dict[str, matplotlib.figure.Figure],
+    dict[str, typing.Any]]
         A tuple containing the following entries:
 
         * summary: A dictionary containing the performance summary on the
-          specified threshold
-        * figures: A dictionary of generated standalone figures
+          specified threshold.
+        * figures: A dictionary of generated standalone figures.
         * curves: A dictionary containing curves that can potentially be combined
           with other prediction lists to make aggregate plots.
     """
@@ -282,19 +286,18 @@ def aggregate_summaries(
     This function can properly tabulate the various summaries produced for all
     the splits in a prediction database.
 
-
     Parameters
     ----------
     data
-        An iterable over all summary data collected
+        An iterable over all summary data collected.
     fmt
         One of the formats supported by `python-tabulate
         <https://pypi.org/project/tabulate/>`_.
 
-
     Returns
     -------
-        A string containing the tabulated information
+    str
+        A string containing the tabulated information.
     """
 
     headers = list(data[0].keys())
@@ -311,18 +314,17 @@ def aggregate_roc(
     This function produces a single ROC plot for multiple curves generated per
     split.
 
-
     Parameters
     ----------
     data
         A dictionary mapping split names to ROC curve data produced by
         :py:func:sklearn.metrics.roc_curve`.
-
     title
         The title of the plot.
 
     Returns
     -------
+    matplotlib.figure.Figure
         A figure, containing the aggregated ROC plot.
     """
     fig, ax = plt.subplots(1, 1)
@@ -396,13 +398,12 @@ def _precision_recall_canvas() -> (
     contains F1-ISO lines and is preset to a 0-1 square region.  Once the
     context is finished, ``fig.tight_layout()`` is called.
 
-
     Yields
     ------
     figure
-        The figure that should be finally returned to the user
+        The figure that should be finally returned to the user.
     axes
-        An axis set where to precision-recall plots should be added to
+        An axis set where to precision-recall plots should be added to.
     """
 
     fig, axes1 = plt.subplots(1)
@@ -467,7 +468,6 @@ def aggregate_pr(
     generated per split. The plot will be annotated with F1-score iso-lines (in
     which the F1-score maintains the same value).
 
-
     Parameters
     ----------
     data
@@ -476,9 +476,9 @@ def aggregate_pr(
     title
         The title of the plot.
 
-
     Returns
     -------
+    matplotlib.figure.Figure
         A figure, containing the aggregated PR plot.
     """
 
diff --git a/src/mednet/engine/loggers.py b/src/mednet/engine/loggers.py
index 94c2ca3b..d597294b 100644
--- a/src/mednet/engine/loggers.py
+++ b/src/mednet/engine/loggers.py
@@ -15,7 +15,6 @@ class CustomTensorboardLogger(TensorBoardLogger):
     This implementation puts all logs inside the same directory, instead of a
     separate "version_n" directories, which is the default lightning behaviour.
 
-
     Parameters
     ----------
     save_dir
diff --git a/src/mednet/engine/predictor.py b/src/mednet/engine/predictor.py
index 3125ef87..e4d625fb 100644
--- a/src/mednet/engine/predictor.py
+++ b/src/mednet/engine/predictor.py
@@ -44,22 +44,29 @@ def run(
         validation.  This representation can be converted into a pytorch device
         or a lightning accelerator setup.
 
-
     Returns
     -------
+    (
+    list[BinaryPrediction]
+    | list[MultiClassPrediction]
+    | list[list[BinaryPrediction]]
+    | list[list[MultiClassPrediction]]
+    | BinaryPredictionSplit
+    | MultiClassPredictionSplit
+    | None
+    )
         Depending on the return type of the datamodule's
         ``predict_dataloader()`` method:
 
         * if :py:class:`torch.utils.data.DataLoader`, then returns a
-          :py:class:`list` of predictions
+          :py:class:`list` of predictions.
         * if :py:class:`list` of :py:class:`torch.utils.data.DataLoader`, then
           returns a list of lists of predictions, each list corresponding to
           the iteration over one of the dataloaders.
         * if :py:class:`dict` of :py:class:`str` to
           :py:class:`torch.utils.data.DataLoader`, then returns a dictionary
-          mapping names to lists of predictions
-        * if ``None``, then returns ``None``
-
+          mapping names to lists of predictions.
+        * if ``None``, then returns ``None``.
 
     Raises
     ------
diff --git a/src/mednet/engine/saliency/completeness.py b/src/mednet/engine/saliency/completeness.py
index 95a2c7fc..655cdedc 100644
--- a/src/mednet/engine/saliency/completeness.py
+++ b/src/mednet/engine/saliency/completeness.py
@@ -49,7 +49,6 @@ def _calculate_road_scores(
     different removal (hardcoded) percentiles, for a single input image, a
     given visualization method, and a target class.
 
-
     Parameters
     ----------
     model
@@ -59,15 +58,14 @@ def _calculate_road_scores(
         we only support batches with a single image.
     output_num
         Target output neuron to take into consideration when evaluating the
-        saliency maps and calculating ROAD scores
+        saliency maps and calculating ROAD scores.
     saliency_map_callable
-        A callable saliency-map generator from grad-cam
+        A callable saliency-map generator from grad-cam.
     percentiles
         A sequence of percentiles (percent x100) integer values indicating the
         proportion of pixels to perturb in the original image to calculate both
         MoRF and LeRF scores.
 
-
     Returns
     -------
         A 3-tuple containing floating point numbers representing the
@@ -132,7 +130,7 @@ def _process_sample(
     device
         The device to process samples on.
     saliency_map_callable
-        A callable saliency-map generator from grad-cam
+        A callable saliency-map generator from grad-cam.
     target_class
         Class to target for saliency estimation. Can be set to
         "all" or "highest". "highest" is default, which means
@@ -140,7 +138,7 @@ def _process_sample(
         activation will be generated.
     positive only
         If set, and the model chosen has a single output (binary), then
-        saliency maps will only be generated for samples of the positive class
+        saliency maps will only be generated for samples of the positive class.
     percentiles
         A sequence of percentiles (percent x100) integer values indicating the
         proportion of pixels to perturb in the original image to calculate both
@@ -222,7 +220,6 @@ def run(
     percentile of the least relevant pixels), and combined ROAD evaluations per
     sample for a particular saliency mapping algorithm.
 
-
     Parameters
     ---------
     model
@@ -254,10 +251,9 @@ def run(
         as processing cores available in the system.  Set to >= 1 to enable
         that many multiprocessing instances for data processing.
 
-
     Returns
     -------
-
+    dict[str, list[typing.Any]]
         A dictionary where keys are dataset names in the provide datamodule,
         and values are lists containing sample information alongside metrics
         calculated:
diff --git a/src/mednet/engine/saliency/evaluator.py b/src/mednet/engine/saliency/evaluator.py
index b2ae1d53..3f2afbba 100644
--- a/src/mednet/engine/saliency/evaluator.py
+++ b/src/mednet/engine/saliency/evaluator.py
@@ -23,7 +23,6 @@ def _reconcile_metrics(
     sample, for the selected dataset.  Only samples for which a completness and
     interpretability scores are availble are returned in the reconciled list.
 
-
     Parameters
     ----------
     completeness
@@ -33,9 +32,9 @@ def _reconcile_metrics(
         A list containing various tables with the sample name and
         interpretability (Pro. Energy) scores.
 
-
     Returns
     -------
+    list[tuple[str, int, float, float, float]]
         A list containing a table with the sample name, target label,
         completeness score (Average ROAD across different ablation thresholds),
         interpretability score (Proportional Energy), and the ROAD-Weighted
@@ -87,18 +86,18 @@ def _make_histogram(
     Parameters
     ----------
     name
-        Name of the variable to be histogrammed (will appear in the figure)
+        Name of the variable to be histogrammed (will appear in the figure).
     values
-        Values to be histogrammed
+        Values to be histogrammed.
     xlim
         A tuple representing the X-axis maximum and minimum to plot. If not
         set, then use the bin boundaries.
     title
-        A title to set on the histogram
-
+        A title to set on the histogram.
 
     Returns
     -------
+    matplotlib.figure.Figure
         A matplotlib figure containing the histogram.
     """
 
@@ -154,9 +153,9 @@ def summary_table(
         One of the formats supported by `python-tabulate
         <https://pypi.org/project/tabulate/>`_.
 
-
     Returns
     -------
+    str
         A string containing the tabulated information.
     """
 
@@ -197,17 +196,17 @@ def _extract_statistics(
         produced by completeness and interpretability analysis as returned by
         :py:func:`_reconcile_metrics`.
     name
-        The name of the variable being analysed
+        The name of the variable being analysed.
     index
         The index of the tuple contained in ``data`` that should be extracted.
     dataset
-        The name of the dataset being analysed
+        The name of the dataset being analysed.
     xlim
-        Limits for histogram plotting
-
+        Limits for histogram plotting.
 
     Returns
     -------
+    dict[str, typing.Any]
         A dictionary containing the following elements:
 
         * ``values``: A list of values corresponding to the index on the data
@@ -261,9 +260,9 @@ def run(
         A dictionary mapping dataset names to tables with the sample name and
         interpretability (among which Prop. Energy) scores.
 
-
     Returns
     -------
+    dict[str, typing.Any]
         A dictionary with most important statistical values for the main
         completeness (AOPC-Combined), interpretability (Prop. Energy), and a
         combination of both (ROAD-Weighted Prop. Energy) scores.
diff --git a/src/mednet/engine/saliency/generator.py b/src/mednet/engine/saliency/generator.py
index bae23050..2096adc1 100644
--- a/src/mednet/engine/saliency/generator.py
+++ b/src/mednet/engine/saliency/generator.py
@@ -143,7 +143,7 @@ def run(
         a multi-class output model.
     output_folder
         Where to save all the saliency maps (this path should exist before
-        this function is called)
+        this function is called).
     """
     from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
 
diff --git a/src/mednet/engine/saliency/interpretability.py b/src/mednet/engine/saliency/interpretability.py
index 65bd3614..b18cfe95 100644
--- a/src/mednet/engine/saliency/interpretability.py
+++ b/src/mednet/engine/saliency/interpretability.py
@@ -44,7 +44,6 @@ def _ordered_connected_components(
     4. We histogram the labels and return one binary mask for each label,
        sorted by decreasing size.
 
-
     Parameters
     ----------
     saliency_map
@@ -55,9 +54,9 @@ def _ordered_connected_components(
         map.  A value of 0.2 will zero all values in the saliency map that are
         bellow 20% of the maximum value observed in the said map.
 
-
     Returns
     -------
+    list[BinaryMask]
         A list of boolean masks, one for each connected component, ordered by
         decreasing size.  This list may be empty if the input ``saliency_map``
         is all zeroes.
@@ -89,9 +88,9 @@ def _extract_bounding_box(
     mask
         The connected component mask from whom extract the bounding box.
 
-
     Returns
     -------
+    BoundingBox
         A bounding box.
     """
     x, y, x2, y2 = torchvision.ops.masks_to_boxes(torch.tensor(mask)[None, :])[
@@ -110,7 +109,6 @@ def _compute_max_iou_and_ioda(
     for each gt box separately and the gt box with the highest
     intersecting part will be used for the calculation.
 
-
     Parameters
     ----------
     detected_box
@@ -119,9 +117,9 @@ def _compute_max_iou_and_ioda(
         Ground-truth bounding boxes in the format ``(x, y, width,
         height)``.
 
-
     Returns
     -------
+    tuple[float, float]
         The max iou and ioda values.
     """
     detected_area = detected_box.area()
@@ -162,7 +160,6 @@ def _get_largest_bounding_boxes(
     well as on the saliency map itself.  The number of objects found is also
     affected by those parameters.
 
-
     Parameters
     ----------
     saliency_map
@@ -176,9 +173,9 @@ def _get_largest_bounding_boxes(
         map.  A value of 0.2 will zero all values in the saliency map that are
         bellow 20% of the maximum value observed in the said map.
 
-
     Returns
     -------
+    list[BoundingBox]
         The N largest connected components as bounding boxes in a saliency map.
     """
 
@@ -202,7 +199,6 @@ def _compute_simultaneous_iou_and_ioda(
     will be compared to them simultaneously (and not to each gt box
     separately).
 
-
     Parameters
     ----------
     detected_box
@@ -211,9 +207,9 @@ def _compute_simultaneous_iou_and_ioda(
             Collection of bounding boxes of the ground-truth drawn as
             ``True`` values.
 
-
     Returns
     -------
+    tuple[float, float]
         The iou and ioda for the provided boxes.
     """
 
@@ -244,9 +240,9 @@ def _compute_iou_ioda_from_largest_bbox(
         A real-valued saliency-map that conveys regions used for
         classification in the original sample.
 
-
     Returns
     -------
+    tuple[float, float]
         A tuple containing the iou and ioda for the largest bounding box.
     """
 
@@ -269,7 +265,6 @@ def _compute_avg_saliency_focus(
     ground-truth bounding boxes and normalize it by the total area covered by
     all ground-truth bounding boxes.
 
-
     Parameters
     ----------
     saliency_map
@@ -279,9 +274,9 @@ def _compute_avg_saliency_focus(
         Ground-truth mask containing the bounding boxes of the ground-truth
         drawn as ``True`` values.
 
-
     Returns
     -------
+    float
         A single floating-point number representing the Average saliency focus.
     """
 
@@ -308,9 +303,9 @@ def _compute_proportional_energy(
         Ground-truth mask containing the bounding boxes of the ground-truth
         drawn as ``True`` values.
 
-
     Returns
     -------
+    float
         A single floating-point number representing the proportional energy.
     """
 
@@ -330,20 +325,18 @@ def _compute_binary_mask(
 
     The binary_mask will be ON/True where the gt boxes are located.
 
-
     Parameters
     ----------
     gt_bboxes
         Ground-truth bounding boxes in the format ``(x, y, width,
         height)``.
-
     saliency_map
         A real-valued saliency-map that conveys regions used for
         classification in the original sample.
 
-
     Returns
     -------
+    BinaryMask
         A numpy array of the same size as saliency_map with
         the value False everywhere except at the positions inside
         the bounding boxes, which will be True.
@@ -372,9 +365,9 @@ def _process_sample(
         A real-valued saliency-map that conveys regions used for
         classification in the original sample.
 
-
     Returns
     -------
+    tuple[float, float]
         A tuple containing the following values:
 
         * Proportional energy
@@ -408,9 +401,9 @@ def run(
     datamodule
         The lightning datamodule to iterate on.
 
-
     Returns
     -------
+    dict[str, list[typing.Any]]
         A dictionary where keys are dataset names in the provided datamodule,
         and values are lists containing sample information alongside metrics
         calculated:
diff --git a/src/mednet/engine/saliency/viewer.py b/src/mednet/engine/saliency/viewer.py
index 866f5073..184a586d 100644
--- a/src/mednet/engine/saliency/viewer.py
+++ b/src/mednet/engine/saliency/viewer.py
@@ -60,22 +60,21 @@ def _overlay_saliency_map(
     https://github.com/jacobgil/pytorch-grad-cam, but uses matplotlib instead
     of opencv.
 
-
     Parameters
     ----------
     image
-        The input image that will be overlayed with the saliency map
+        The input image that will be overlayed with the saliency map.
     saliencies
-        The saliency map that will be overlaid on the (raw) image
+        The saliency map that will be overlaid on the (raw) image.
     colormap
-        The name of the (matplotlib) colormap to be used
+        The name of the (matplotlib) colormap to be used.
     image_weight
         The final result is ``image_weight * image + (1-image_weight) *
         saliency_map``.
 
-
     Returns
     -------
+    PIL.Image.Image
         A modified version of the input ``image`` with the overlaid saliency
         map.
     """
@@ -121,7 +120,7 @@ def _overlay_bounding_box(
     image
         The input image that will be overlayed with the saliency map
     bbox
-        The bounding box to draw on the input image
+        The bounding box to draw on the input image.
     color
         The color to use for drawing the bounding box. Any of the colours in
         :any:`PIL.ImageColor.colormap` are accepted.
@@ -129,9 +128,9 @@ def _overlay_bounding_box(
         The width of the bounding box, in pixels.  A larger value creates a
         bounding box that is thicker, towards the outside of the boxed area.
 
-
     Returns
     -------
+    PIL.Image.Image
         A modified version of the input ``image`` with the ground-truth drawn
         on the top.
     """
@@ -157,16 +156,16 @@ def _process_sample(
     ----------
     raw_data
         The raw data representing the input sample that will be overlayed with
-        saliency maps and annotations
+        saliency maps and annotations.
     saliencies
         The saliency map recovered from the model, that will be imprinted on
-        the raw_data
+        the raw_data.
     ground_truth
-        Ground-truth annotations that may be imprinted on the final image
-
+        Ground-truth annotations that may be imprinted on the final image.
 
     Returns
     -------
+    PIL.Image.Image
         An image with the original raw data overlayed with the different
         elements as selected by the user.
     """
diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py
index 24c5dc1c..3c3deeee 100644
--- a/src/mednet/engine/trainer.py
+++ b/src/mednet/engine/trainer.py
@@ -34,21 +34,15 @@ def save_model_summary(
 
     Parameters
     ----------
-
     output_folder
-        output path
-
+        Directory in which to save the summary.
     model
-        Network (e.g. driu, hed, unet)
-
+        Instance of the model for which to save the summary.
 
     Returns
     -------
-    summary
-        The model summary in a text format
-
-    total_parameters
-        The number of parameters of the model
+    tuple[lightning.pytorch.callbacks.ModelSummary, int]
+        A tuple with the model summary in a text format and number of parameters of the model.
     """
     summary_path = output_folder / "model-summary.txt"
     logger.info(f"Saving model summary at {summary_path}...")
@@ -74,16 +68,13 @@ def static_information_to_csv(
 
     Parameters
     ----------
-
     static_logfile_name
         The static file name which is a join between the output folder and
-        "constants.csv"
-
+        "constants.csv".
     device_type
-        The type of device we are using
-
+        The type of device we are using.
     model_size
-        The size of the model we will be training
+        The size of the model we will be training.
     """
     if static_logfile_name.exists():
         backup = static_logfile_name.parent / (static_logfile_name.name + "~")
@@ -129,16 +120,12 @@ def run(
     This method supports periodic checkpointing and the output of a
     CSV-formatted log with the evolution of some figures during training.
 
-
     Parameters
     ----------
-
     model
         Neural network model (e.g. pasa).
-
     datamodule
-        The lightning datamodule to use for training **and** validation
-
+        The lightning datamodule to use for training **and** validation.
     validation_period
         Number of epochs after which validation happens.  By default, we run
         validation after every training epoch (period=1).  You can change this
@@ -148,31 +135,25 @@ def run(
         triggers the overriding of latest checkpoint), and that this process is
         independent of validation runs, evaluation of the 'best' model obtained
         so far based on those will be influenced by this setting.
-
     device_manager
         An internal device representation, to be used for training and
         validation.  This representation can be converted into a pytorch device
         or a lightning accelerator setup.
-
     max_epochs
         The maximum number of epochs to train for.
-
     output_folder
         Folder in which the results will be saved.
-
     monitoring_interval
         Interval, in seconds (or fractions), through which we should monitor
         resources during training.
-
     batch_chunk_count
         If this number is different than 1, then each batch will be divided in
         this number of chunks.  Gradients will be accumulated to perform each
         mini-batch.   This is particularly interesting when one has limited RAM
         on the GPU, but would like to keep training with larger batches.  One
         exchanges for longer processing times in this case.
-
     checkpoint
-        Path to an optional checkpoint file to load
+        Path to an optional checkpoint file to load.
     """
 
     os.makedirs(output_folder, exist_ok=True)
diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py
index 53fbf49a..9c990d84 100644
--- a/src/mednet/models/alexnet.py
+++ b/src/mednet/models/alexnet.py
@@ -35,7 +35,6 @@ class Alexnet(pl.LightningModule):
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-
     validation_loss
         The loss to be used for validation (may be different from the training
         loss).  If extra-validation sets are provided, the same loss will be
@@ -45,21 +44,16 @@ class Alexnet(pl.LightningModule):
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-
     optimizer_type
-        The type of optimizer to use for training
-
+        The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
-
     augmentation_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
-
     pretrained
         If set to True, loads pretrained model weights during initialization,
         else trains a new model.
-
     num_classes
         Number of outputs (classes) for this model.
     """
@@ -125,7 +119,7 @@ class Alexnet(pl.LightningModule):
         Parameters
         ----------
         checkpoint
-            The checkpoint to save
+            The checkpoint to save.
         """
         checkpoint["normalizer"] = self.normalizer
 
@@ -138,7 +132,7 @@ class Alexnet(pl.LightningModule):
         Parameters
         ----------
         checkpoint
-            The loaded checkpoint
+            The loaded checkpoint.
         """
         logger.info("Restoring normalizer from checkpoint.")
         self.normalizer = checkpoint["normalizer"]
diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py
index e2bfc310..19bf75b3 100644
--- a/src/mednet/models/densenet.py
+++ b/src/mednet/models/densenet.py
@@ -33,7 +33,6 @@ class Densenet(pl.LightningModule):
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-
     validation_loss
         The loss to be used for validation (may be different from the training
         loss).  If extra-validation sets are provided, the same loss will be
@@ -43,21 +42,16 @@ class Densenet(pl.LightningModule):
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-
     optimizer_type
-        The type of optimizer to use for training
-
+        The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
-
     augmentation_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
-
     pretrained
         If set to True, loads pretrained model weights during initialization,
         else trains a new model.
-
     num_classes
         Number of outputs (classes) for this model.
     """
@@ -125,7 +119,7 @@ class Densenet(pl.LightningModule):
         Parameters
         ----------
         checkpoint
-            The checkpoint to save
+            The checkpoint to save.
         """
         checkpoint["normalizer"] = self.normalizer
 
@@ -138,7 +132,7 @@ class Densenet(pl.LightningModule):
         Parameters
         ----------
         checkpoint
-            The loaded checkpoint
+            The loaded checkpoint.
         """
         logger.info("Restoring normalizer from checkpoint.")
         self.normalizer = checkpoint["normalizer"]
diff --git a/src/mednet/models/logistic_regression.py b/src/mednet/models/logistic_regression.py
index 6a88d967..fd3281ec 100644
--- a/src/mednet/models/logistic_regression.py
+++ b/src/mednet/models/logistic_regression.py
@@ -23,7 +23,6 @@ class LogisticRegression(pl.LightningModule):
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-
     validation_loss
         The loss to be used for validation (may be different from the training
         loss).  If extra-validation sets are provided, the same loss will be
@@ -33,13 +32,10 @@ class LogisticRegression(pl.LightningModule):
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-
     optimizer_type
-        The type of optimizer to use for training
-
+        The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
-
     input_size
         The number of inputs this classifer shall process.
     """
diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py
index c6af3d59..c0da89ff 100644
--- a/src/mednet/models/loss_weights.py
+++ b/src/mednet/models/loss_weights.py
@@ -23,20 +23,16 @@ def _get_label_weights(
 
     It returns a vector with weights (inverse counts) for each label.
 
-
     Parameters
     ----------
-
     dataloader
         A DataLoader from which to compute the positive weights.  Entries must
         be a dictionary which must contain a ``label`` key.
 
-
     Returns
     -------
-
-    positive_weights
-        The positive weight of each class in the dataset given as input
+    torch.Tensor
+        The positive weight of each class in the dataset given as input.
     """
 
     targets = torch.tensor(
@@ -78,12 +74,10 @@ def make_balanced_bcewithlogitsloss(
     The loss is weighted using the ratio between positives and total examples
     available.
 
-
     Returns
     -------
-
-    loss
-        An instance of the weighted loss
+    torch.nn.BCEWithLogitsLoss
+        An instance of the weighted loss.
     """
 
     weights = _get_label_weights(dataloader)
diff --git a/src/mednet/models/mlp.py b/src/mednet/models/mlp.py
index ac59ad6f..831b7385 100644
--- a/src/mednet/models/mlp.py
+++ b/src/mednet/models/mlp.py
@@ -22,7 +22,6 @@ class MultiLayerPerceptron(pl.LightningModule):
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-
     validation_loss
         The loss to be used for validation (may be different from the training
         loss).  If extra-validation sets are provided, the same loss will be
@@ -32,16 +31,12 @@ class MultiLayerPerceptron(pl.LightningModule):
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-
     optimizer_type
-        The type of optimizer to use for training
-
+        The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
-
     input_size
         The number of inputs this classifer shall process.
-
     hidden_size
         The number of neurons on the single hidden layer.
     """
diff --git a/src/mednet/models/normalizer.py b/src/mednet/models/normalizer.py
index 576f21cc..6bdce4fc 100644
--- a/src/mednet/models/normalizer.py
+++ b/src/mednet/models/normalizer.py
@@ -19,17 +19,14 @@ def make_z_normalizer(
     deviation by image channel.  It will work for both monochromatic, and color
     inputs with 2, 3 or more color planes.
 
-
     Parameters
     ----------
-
     dataloader:
-        A torch Dataloader from which to compute the mean and std
-
+        A torch Dataloader from which to compute the mean and std.
 
     Returns
     -------
-        An initialized normalizer
+        An initialized normalizer.
     """
 
     # Peek the number of channels of batches in the data loader
@@ -63,10 +60,9 @@ def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
     The weights are wrapped in a torch module.  This normalizer only works for
     **RGB (color) images**.
 
-
     Returns
     -------
-        An initialized normalizer
+        An initialized normalizer.
     """
 
     return torchvision.transforms.Normalize(
diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py
index 7741944a..0125b0e0 100644
--- a/src/mednet/models/pasa.py
+++ b/src/mednet/models/pasa.py
@@ -29,7 +29,6 @@ class Pasa(pl.LightningModule):
     This network has a linear output.  You should use losses with ``WithLogit``
     instead of cross-entropy versions when training.
 
-
     Parameters
     ----------
     train_loss
@@ -39,7 +38,6 @@ class Pasa(pl.LightningModule):
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-
     validation_loss
         The loss to be used for validation (may be different from the training
         loss).  If extra-validation sets are provided, the same loss will be
@@ -49,17 +47,13 @@ class Pasa(pl.LightningModule):
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-
     optimizer_type
-        The type of optimizer to use for training
-
+        The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
-
     augmentation_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
-
     num_classes
         Number of outputs (classes) for this model.
     """
@@ -211,7 +205,7 @@ class Pasa(pl.LightningModule):
         Parameters
         ----------
         checkpoint
-            The checkpoint to save
+            The checkpoint to save.
         """
         checkpoint["normalizer"] = self.normalizer
 
@@ -224,7 +218,7 @@ class Pasa(pl.LightningModule):
         Parameters
         ----------
         checkpoint
-            The loaded checkpoint
+            The loaded checkpoint.
         """
         logger.info("Restoring normalizer from checkpoint.")
         self.normalizer = checkpoint["normalizer"]
@@ -235,7 +229,7 @@ class Pasa(pl.LightningModule):
         Parameters
         ----------
         dataloader
-            A torch Dataloader from which to compute the mean and std
+            A torch Dataloader from which to compute the mean and std.
         """
         from .normalizer import make_z_normalizer
 
diff --git a/src/mednet/models/separate.py b/src/mednet/models/separate.py
index 8479c2a7..244d386b 100644
--- a/src/mednet/models/separate.py
+++ b/src/mednet/models/separate.py
@@ -22,9 +22,9 @@ def _as_predictions(
     samples
         A sequence of samples as returned by :py:func:`separate`.
 
-
     Returns
     -------
+    list[BinaryPrediction | MultiClassPrediction]
         A list of typed predictions that can be saved to disk.
     """
     return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples]
@@ -42,15 +42,14 @@ def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
       dimension, via :py:func:`torch.flatten`)
     * ``typing.Mapping[K, V[]]`` -> ``[dict[K, V_1], dict[K, V_2], ...]``
 
-
     Parameters
     ----------
     batch
-        A batch, as output by torch model forwarding
-
+        A batch, as output by torch model forwarding.
 
     Returns
     -------
+    list[BinaryPrediction | MultiClassPrediction]
         A list of predictions that contains the predictions and associated metadata
         for each processed sample.
     """
diff --git a/src/mednet/models/transforms.py b/src/mednet/models/transforms.py
index 8f5286fe..f67d1f66 100644
--- a/src/mednet/models/transforms.py
+++ b/src/mednet/models/transforms.py
@@ -16,18 +16,15 @@ def grayscale_to_rgb(img: torch.Tensor) -> torch.Tensor:
     (number of color channels = 1), then replicate it to obtain 3 color channels
     (a new copy is returned in this case).
 
-
     Parameters
     ----------
-
     img
         The tensor to be transformed.  Expected to be in the form: ``[...,
         [1,3], H, W]`` (i.e. arbitrary number of leading dimensions).
 
     Returns
     -------
-
-    img
+    torch.Tensor
         Transformed tensor with 3 identical color channels.
     """
     if img.ndim < 3:
@@ -66,18 +63,15 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
 
     A new tensor is returned in this case.
 
-
     Parameters
     ----------
-
     img
         The tensor to be transformed.  Expected to be in the form: ``[...,
         [1,3], H, W]`` (i.e. arbitrary number of leading dimensions).
 
     Returns
     -------
-
-    img
+    torch
         Transformed tensor with a single (grayscale) color channel.
     """
     if img.ndim < 3:
diff --git a/src/mednet/scripts/click.py b/src/mednet/scripts/click.py
index 07dfe697..84606e60 100644
--- a/src/mednet/scripts/click.py
+++ b/src/mednet/scripts/click.py
@@ -17,11 +17,12 @@ class ConfigCommand(_BaseConfigCommand):
     ) -> None:
         """Formats the command epilog during --help.
 
-        Arguments:
-
-            _: The current parsing context
-
-            formatter: The formatter to use for printing text
+        Parameters
+        ----------
+            _
+                The current parsing context.
+            formatter
+                The formatter to use for printing text.
         """
 
         if self.epilog:
diff --git a/src/mednet/scripts/database.py b/src/mednet/scripts/database.py
index 683cad2e..8bd43ff2 100644
--- a/src/mednet/scripts/database.py
+++ b/src/mednet/scripts/database.py
@@ -15,12 +15,12 @@ def _get_raw_databases() -> dict[str, dict[str, str]]:
 
     Returns
     -------
-    d
+    dict[str, dict[str, str]]
         Dictionary where keys are database names, and values are dictionaries
         containing two string keys:
 
         * ``module``: the full Pythonic module name (e.g.
-        ``mednet.data.montgomery``)
+        ``mednet.data.montgomery``).
         * ``datadir``: points to the user-configured data directory for the
         current dataset, if set, or ``None`` otherwise.
     """
diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py
index b09bca13..de8801af 100644
--- a/src/mednet/scripts/train.py
+++ b/src/mednet/scripts/train.py
@@ -21,14 +21,12 @@ def reusable_options(f):
     This decorator equips the target function ``f`` with all (reusable)
     ``train`` script options.
 
-
     Parameters
     ----------
     f
         The target function to equip with options.  This function must have
         parameters that accept such options.
 
-
     Returns
     -------
         The decorated version of function ``f``
diff --git a/src/mednet/scripts/train_analysis.py b/src/mednet/scripts/train_analysis.py
index 98c77548..88a6bcb2 100644
--- a/src/mednet/scripts/train_analysis.py
+++ b/src/mednet/scripts/train_analysis.py
@@ -40,25 +40,21 @@ def create_figures(
     It is assumed that some metric names are of the form <metric>/<subset>.
     All subsets for a metric will be displayed on the same figure.
 
-
     Parameters
     ----------
-
-    data:
+    data
         A dictionary where keys represent all scalar names, and values
         correspond to a tuple that contains an array with epoch numbers (when
         values were taken), and the monitored values themselves.  These lists
         are pre-sorted by epoch number.
-    groups:
+    groups
         A list of scalar globs present in the existing tensorboard data that
         we are interested in for plotting.  Values with multiple matches are
         drawn on the same plot.  Values that do not exist are ignored.
 
-
     Returns
     -------
-
-    figures:
+    list
         List of matplotlib figures, one per metric.
     """
     import fnmatch
diff --git a/src/mednet/scripts/utils.py b/src/mednet/scripts/utils.py
index 511fc884..b3b0d310 100644
--- a/src/mednet/scripts/utils.py
+++ b/src/mednet/scripts/utils.py
@@ -21,10 +21,8 @@ def save_sh_command(path: pathlib.Path) -> None:
     records further information on the date and time the script was run and the
     version of the package.
 
-
     Parameters
     ----------
-
     path
         Path to a file where the commands to reproduce the current run will be
         recorded.  Parent directories will be created if they do not exist. An
diff --git a/src/mednet/utils/checkpointer.py b/src/mednet/utils/checkpointer.py
index 543f7c70..bedff23a 100644
--- a/src/mednet/utils/checkpointer.py
+++ b/src/mednet/utils/checkpointer.py
@@ -35,21 +35,18 @@ def _get_checkpoint_from_alias(
     If only one file is present matching the alias characteristics, then it is
     returned.
 
-
     Parameters
     ----------
     path
-        Folder in which may contain checkpoint
+        Folder in which may contain checkpoint.
     alias
         Can be one of "best" or "periodic".
 
-
     Returns
     -------
         Path to the requested checkpoint, or ``None``, if no checkpoint file
         matching specifications is found on the provided path.
 
-
     Raises
     ------
     FileNotFoundError
@@ -91,7 +88,7 @@ def _get_checkpoint_from_alias(
 
 def get_checkpoint_to_resume_training(
     path: pathlib.Path,
-):
+) -> pathlib.Path:
     """Returns the best checkpoint file path to resume training from.
 
     Parameters
@@ -100,11 +97,10 @@ def get_checkpoint_to_resume_training(
         The base directory containing either the "periodic" checkpoint to start
         the training session from.
 
-
     Returns
     -------
-        Path to a checkpoint file that exists on disk
-
+    pathlib.Path
+        Path to a checkpoint file that exists on disk.
 
     Raises
     ------
@@ -117,7 +113,7 @@ def get_checkpoint_to_resume_training(
 
 def get_checkpoint_to_run_inference(
     path: pathlib.Path,
-):
+) -> pathlib.Path:
     """Returns the best checkpoint file path to run inference with.
 
     Parameters
@@ -126,11 +122,10 @@ def get_checkpoint_to_run_inference(
         The base directory containing either the "best", "last" or "periodic"
         checkpoint to start the training session from.
 
-
     Returns
     -------
-        Path to a checkpoint file that exists on disk
-
+    pathlib.Path
+        Path to a checkpoint file that exists on disk.
 
     Raises
     ------
diff --git a/src/mednet/utils/resources.py b/src/mednet/utils/resources.py
index 01995cac..ae28a34d 100644
--- a/src/mednet/utils/resources.py
+++ b/src/mednet/utils/resources.py
@@ -37,18 +37,14 @@ def run_nvidia_smi(
     For a comprehensive list of options and help, execute ``nvidia-smi
     --help-query-gpu`` on a host with a GPU
 
-
     Parameters
     ----------
-
     query
-        A list of query strings as defined by ``nvidia-smi --help-query-gpu``
-
+        A list of query strings as defined by ``nvidia-smi --help-query-gpu``.
 
     Returns
     -------
-
-    data
+    dict[str, str | float] | None
         A dictionary containing the queried parameters (``rename`` versions).
         If ``nvidia-smi`` is not available, returns ``None``.  Percentage
         information is left alone, memory information is transformed to
@@ -151,7 +147,7 @@ def cuda_constants() -> dict[str, str | int | float] | None:
 
     Returns
     -------
-
+    dict[str, str | int | float] | None
         If ``nvidia-smi`` is not available, returns ``None``, otherwise, we
         return a dictionary containing the following ``nvidia-smi`` query
         information, in this order:
@@ -210,7 +206,7 @@ def cuda_log() -> dict[str, float] | None:
 
     Returns
     -------
-
+    dict[str, float] | None
         If ``nvidia-smi`` is not available, returns ``None``, otherwise, we
         return a dictionary containing the following ``nvidia-smi`` query
         information, in this order:
@@ -274,8 +270,8 @@ def cpu_constants() -> dict[str, int | float]:
 
     Returns
     -------
-
-        A dictionary containing these entries:
+    dict[str, int | float]
+        An ordered dictionary (organized as 2-tuples) containing these entries:
 
         0. ``cpu_memory_total`` (:py:class:`float`): total memory available,
            in gigabytes
@@ -288,13 +284,12 @@ def cpu_constants() -> dict[str, int | float]:
 
 
 class CPULogger:
-    """Logs CPU information using :py:mod:`psutil`
+    """Logs CPU information using :py:mod:`psutil`.
 
     Parameters
     ----------
-
     pid
-        Process identifier of the main process (parent process) to observe
+        Process identifier of the main process (parent process) to observe.
     """
 
     def __init__(self, pid: int | None = None):
@@ -308,8 +303,7 @@ class CPULogger:
 
         Returns
         -------
-
-        data
+        dict[str, int | float]
             An ordered dictionary containing these entries:
 
             0. ``cpu_memory_used`` (:py:class:`float`): total memory used from
@@ -376,16 +370,13 @@ class _InformationGatherer:
 
     Parameters
     ----------
-
     device_type
         String representation of one of the supported pytorch device types
         triggering the correct readout of resource usage.
-
     main_pid
         The main process identifier to monitor
-
     logger
-        A logger to be used for logging messages
+        A logger to be used for logging messages.
     """
 
     def __init__(
@@ -467,29 +458,22 @@ def _monitor_worker(
 
     Parameters
     ==========
-
     interval
         Number of seconds to wait between each measurement (maybe a floating
-        point number as accepted by :py:func:`time.sleep`)
-
+        point number as accepted by :py:func:`time.sleep`).
     device_type
         String representation of one of the supported pytorch device types
         triggering the correct readout of resource usage.
-
     main_pid
-        The main process identifier to monitor
-
+        The main process identifier to monitor.
     stop
-        Event that indicates if we should continue running or stop
-
+        Event that indicates if we should continue running or stop.
     summary_event
-        Event that indicates if we should produce a summary
-
+        Event that indicates if we should produce a summary.
     queue
-        A queue, to send monitoring information back to the spawner
-
+        A queue, to send monitoring information back to the spawner.
     logging_level
-        The logging level to use for logging from launched processes
+        The logging level to use for logging from launched processes.
     """
     logger = multiprocessing.log_to_stderr(level=logging_level)
     ra = _InformationGatherer(device_type, main_pid, logger)
@@ -517,20 +501,16 @@ class ResourceMonitor:
 
     Parameters
     ----------
-
     interval
         Number of seconds to wait between each measurement (maybe a floating
-        point number as accepted by :py:func:`time.sleep`)
-
+        point number as accepted by :py:func:`time.sleep`).
     device_type
         String representation of one of the supported pytorch device types
         triggering the correct readout of resource usage.
-
     main_pid
-        The main process identifier to monitor
-
+        The main process identifier to monitor.
     logging_level
-        The logging level to use for logging from launched processes
+        The logging level to use for logging from launched processes.
     """
 
     def __init__(
diff --git a/src/mednet/utils/summary.py b/src/mednet/utils/summary.py
index 2f7d468c..2d3824c0 100644
--- a/src/mednet/utils/summary.py
+++ b/src/mednet/utils/summary.py
@@ -6,11 +6,13 @@
 
 from functools import reduce
 
+import torch.nn.Module
+
 from torch.nn.modules.module import _addindent
 
 
 # ignore this space!
-def _repr(model):
+def _repr(model: torch.nn.Module) -> tuple[str, int]:
     # We treat the extra repr like the sub-module, one item per line
     extra_lines = []
     extra_repr = model.extra_repr()
@@ -43,22 +45,17 @@ def _repr(model):
     return main_str, total_params
 
 
-def summary(model):
+def summary(model: torch.nn.Module) -> tuple[str, int]:
     """Counts the number of parameters in each model layer.
 
     Parameters
     ----------
-
-    model : :py:class:`torch.nn.Module`
-        model to summarize
+    model
+        Model to summarize.
 
     Returns
     -------
-
-    repr : str
-        a multiline string representation of the network
-
-    nparam : int
-        number of parameters
+    tuple[int, str]
+        A tuple containing a multiline string representation of the network and the number of parameters.
     """
     return _repr(model)
diff --git a/src/mednet/utils/tensorboard.py b/src/mednet/utils/tensorboard.py
index e41b2c07..1ede9e70 100644
--- a/src/mednet/utils/tensorboard.py
+++ b/src/mednet/utils/tensorboard.py
@@ -18,15 +18,14 @@ def scalars_to_dict(
     run, and will return a dictionary with all collected scalars, ready for
     plotting.
 
-
     Parameters
     ----------
     logdir
         Directory containing the event files.
 
-
     Returns
     -------
+    dict[str, tuple[list[int], list[float]]]
         A dictionary where keys represent all scalar names, and values
         correspond to a tuple that contains an array with epoch numbers (when
         values were taken), when the monitored values themselves.  The lists
-- 
GitLab