diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py
index 49ee76dbd002a9fbd13e99d052eef223c17a19ac..d1e1b0a3ae8d9e3e32a7ec19a49e21f01bb694d9 100644
--- a/src/ptbench/configs/models/pasa.py
+++ b/src/ptbench/configs/models/pasa.py
@@ -11,32 +11,16 @@ Screening and Visualization".
 Reference: [PASA-2019]_
 """
 
-from torch import empty
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import Adam
 
-from ...models.pasa import PASA
-
-# optimizer
-optimizer = Adam
-optimizer_configs = {"lr": 8e-5}
-
-# criterion
-criterion = BCEWithLogitsLoss(pos_weight=empty(1))
-criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
-
 from ...data.transforms import ElasticDeformation
-
-augmentation_transforms = [ElasticDeformation(p=0.8)]
-
-# from torchvision.transforms.v2 import ElasticTransform, InterpolationMode
-# augmentation_transforms = [ElasticTransform(alpha=1000.0, sigma=30.0, interpolation=InterpolationMode.NEAREST)]
-
-# model
-model = PASA(
-    criterion,
-    criterion_valid,
-    optimizer,
-    optimizer_configs,
-    augmentation_transforms=augmentation_transforms,
+from ...models.pasa import Pasa
+
+model = Pasa(
+    train_loss=BCEWithLogitsLoss(),
+    validation_loss=BCEWithLogitsLoss(),
+    optimizer_type=Adam,
+    optimizer_arguments=dict(lr=8e-5),
+    augmentation_transforms=[ElasticDeformation(p=0.8)],
 )
diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index 55898b6759e4e471607bbe87cff0de3fb074724c..e8643b46ceef967e8c94f7425d9cecd9bd21a0b3 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -5,11 +5,13 @@
 import logging
 
 import lightning.pytorch as pl
-import torch
 import torch.nn as nn
+import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
+from ..data.typing import DataLoader
+
 logger = logging.getLogger(__name__)
 
 
@@ -30,7 +32,7 @@ class Alexnet(pl.LightningModule):
     ):
         super().__init__()
 
-        self.name = "AlexNet"
+        self.name = "alexnet"
 
         self.augmentation_transforms = torchvision.transforms.Compose(
             augmentation_transforms
@@ -49,7 +51,7 @@ class Alexnet(pl.LightningModule):
         if not pretrained:
             weights = None
         else:
-            logger.info("Loading pretrained model weights")
+            logger.info(f"Loading pretrained {self.name} model weights")
             weights = models.AlexNet_Weights.DEFAULT
 
         self.model_ft = models.alexnet(weights=weights)
@@ -81,47 +83,60 @@ class Alexnet(pl.LightningModule):
             from .normalizer import make_imagenet_normalizer
 
             logger.warning(
-                "ImageNet pre-trained densenet model - NOT "
-                "computing z-norm factors from training data. "
-                "Using preset factors from torchvision."
+                f"ImageNet pre-trained {self.name} model - NOT "
+                f"computing z-norm factors from train dataloader. "
+                f"Using preset factors from torchvision."
             )
             self.normalizer = make_imagenet_normalizer()
         else:
             from .normalizer import make_z_normalizer
 
             logger.info(
-                "Uninitialised densenet model - "
-                "computing z-norm factors from training data."
+                f"Uninitialised {self.name} model - "
+                f"computing z-norm factors from train dataloader."
             )
             self.normalizer = make_z_normalizer(dataloader)
 
-    def set_bce_loss_weights(self, datamodule):
-        """Reweights loss weights if BCEWithLogitsLoss is used.
+    def balance_losses_by_class(
+        self, train_dataloader: DataLoader, valid_dataloader: DataLoader
+    ):
+        """Reweights loss weights if possible.
 
         Parameters
         ----------
 
-        datamodule:
-            A datamodule implementing train_dataloader() and val_dataloader()
+        train_dataloader
+            The data loader to use for training
+
+        valid_dataloader
+            The data loader to use for validation
+
+
+        Raises
+        ------
+
+        RuntimeError
+            If train or validation losses are not of type
+            :py:class:`torch.nn.BCEWithLogitsLoss`.
         """
-        from ..data.dataset import _get_positive_weights
+        from .loss_weights import get_label_weights
 
         if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
             logger.info("Reweighting BCEWithLogitsLoss training criterion.")
-            train_positive_weights = _get_positive_weights(
-                datamodule.train_dataloader()
-            )
-            self.criterion = torch.nn.BCEWithLogitsLoss(
-                pos_weight=train_positive_weights
+            weights = get_label_weights(train_dataloader)
+            self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
             )
 
         if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
             logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
-            validation_positive_weights = _get_positive_weights(
-                datamodule.val_dataloader()["validation"]
-            )
-            self.criterion_valid = torch.nn.BCEWithLogitsLoss(
-                pos_weight=validation_positive_weights
+            weights = get_label_weights(valid_dataloader)
+            self.criterion_valid = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
             )
 
     def training_step(self, batch, batch_idx):
@@ -172,11 +187,6 @@ class Alexnet(pl.LightningModule):
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
 
-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(outputs, list):
-            outputs = outputs[-1]
-
         return names[0], torch.flatten(probabilities), torch.flatten(labels)
 
     def configure_optimizers(self):
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index ae866e720ba4df0be292bd2ee3f23878a714818a..25d1d8ff344e15bc1c6729d7ebc7a324ce461e2a 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -6,17 +6,24 @@ import logging
 
 import lightning.pytorch as pl
 import torch
-import torch.nn as nn
+import torch.nn
+import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
+from ..data.typing import DataLoader
+
 logger = logging.getLogger(__name__)
 
 
 class Densenet(pl.LightningModule):
-    """Densenet module.
+    """Densenet-121 module.
+
+    Parameters
+    ----------
 
-    Note: only usable with a normalized dataset
+    criterion
+        The loss function (criterion) to use during training.
     """
 
     def __init__(
@@ -30,7 +37,7 @@ class Densenet(pl.LightningModule):
     ):
         super().__init__()
 
-        self.name = "Densenet"
+        self.name = "densenet-121"
 
         self.augmentation_transforms = torchvision.transforms.Compose(
             augmentation_transforms
@@ -42,21 +49,20 @@ class Densenet(pl.LightningModule):
         self.optimizer = optimizer
         self.optimizer_configs = optimizer_configs
 
-        self.normalizer = None
         self.pretrained = pretrained
 
         # Load pretrained model
         if not pretrained:
             weights = None
         else:
-            logger.info("Loading pretrained model weights")
+            logger.info(f"Loading pretrained {self.name} model weights")
             weights = models.DenseNet121_Weights.DEFAULT
 
         self.model_ft = models.densenet121(weights=weights)
 
         # Adapt output features
-        self.model_ft.classifier = nn.Sequential(
-            nn.Linear(1024, 256), nn.Linear(256, 1)
+        self.model_ft.classifier = torch.nn.Sequential(
+            torch.nn.Linear(1024, 256), torch.nn.Linear(256, 1)
         )
 
     def forward(self, x):
@@ -82,47 +88,62 @@ class Densenet(pl.LightningModule):
             from .normalizer import make_imagenet_normalizer
 
             logger.warning(
-                "ImageNet pre-trained densenet model - NOT "
-                "computing z-norm factors from training data. "
-                "Using preset factors from torchvision."
+                f"ImageNet pre-trained {self.name} model - NOT "
+                f"computing z-norm factors from train dataloader. "
+                f"Using preset factors from torchvision."
             )
             self.normalizer = make_imagenet_normalizer()
         else:
             from .normalizer import make_z_normalizer
 
             logger.info(
-                "Uninitialised densenet model - "
-                "computing z-norm factors from training data."
+                f"Uninitialised {self.name} model - "
+                f"computing z-norm factors from train dataloader."
             )
             self.normalizer = make_z_normalizer(dataloader)
 
-    def set_bce_loss_weights(self, datamodule):
-        """Reweights loss weights if BCEWithLogitsLoss is used.
+    def balance_losses_by_class(
+        self,
+        train_dataloader: DataLoader,
+        valid_dataloader: DataLoader,
+    ):
+        """Reweights loss weights if possible.
 
         Parameters
         ----------
 
-        datamodule:
-            A datamodule implementing train_dataloader() and val_dataloader()
+        train_dataloader
+            The data loader to use for training
+
+        valid_dataloader
+            The data loader to use for validation
+
+
+        Raises
+        ------
+
+        RuntimeError
+            If train or validation losses are not of type
+            :py:class:`torch.nn.BCEWithLogitsLoss`.
         """
-        from ..data.dataset import _get_positive_weights
+        from .loss_weights import get_label_weights
 
         if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
             logger.info("Reweighting BCEWithLogitsLoss training criterion.")
-            train_positive_weights = _get_positive_weights(
-                datamodule.train_dataloader()
-            )
-            self.criterion = torch.nn.BCEWithLogitsLoss(
-                pos_weight=train_positive_weights
+            weights = get_label_weights(train_dataloader)
+            self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
             )
 
         if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
             logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
-            validation_positive_weights = _get_positive_weights(
-                datamodule.val_dataloader()["validation"]
-            )
-            self.criterion_valid = torch.nn.BCEWithLogitsLoss(
-                pos_weight=validation_positive_weights
+            weights = get_label_weights(valid_dataloader)
+            self.criterion_valid = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
             )
 
     def training_step(self, batch, batch_idx):
@@ -173,11 +194,6 @@ class Densenet(pl.LightningModule):
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
 
-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(outputs, list):
-            outputs = outputs[-1]
-
         return names[0], torch.flatten(probabilities), torch.flatten(labels)
 
     def configure_optimizers(self):
diff --git a/src/ptbench/models/loss_weights.py b/src/ptbench/models/loss_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..6889b2539fb1071e228c4487540ad9272aad808c
--- /dev/null
+++ b/src/ptbench/models/loss_weights.py
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import logging
+
+import torch
+import torch.utils.data
+
+logger = logging.getLogger(__name__)
+
+
+def get_label_weights(
+    dataloader: torch.utils.data.DataLoader,
+) -> torch.Tensor:
+    """Computes the weights of each class of a DataLoader.
+
+    This function inputs a pytorch DataLoader and computes the ratio between
+    number of negative and positive samples (scalar).  The weight can be used
+to adjust minimisation criteria in cases where there is a huge data imbalance.
+
+If targets are one-hot encoded (multi-label), a weight is computed per class.
+
+    It returns a vector with weights (inverse counts) for each label.
+
+
+    Parameters
+    ----------
+
+    dataloader
+        A DataLoader from which to compute the positive weights.  Entries must
+        be a dictionary which must contain a ``label`` key.
+
+
+    Returns
+    -------
+
+    positive_weights
+        the positive weight of each class in the dataset given as input
+    """
+
+    targets = torch.tensor(
+        [sample for batch in dataloader for sample in batch[1]["label"]]
+    )
+
+    # Binary labels
+    if len(list(targets.shape)) == 1:
+        class_sample_count = [
+            float((targets == t).sum().item())
+            for t in torch.unique(targets, sorted=True)
+        ]
+
+        # Divide negatives by positives
+        positive_weights = torch.tensor(
+            [class_sample_count[0] / class_sample_count[1]]
+        ).reshape(-1)
+
+    # Multiclass labels
+    else:
+        class_sample_count = torch.sum(targets, dim=0)
+        negative_class_sample_count = (
+            torch.full((targets.size()[1],), float(targets.size()[0]))
+            - class_sample_count
+        )
+
+        positive_weights = negative_class_sample_count / (
+            class_sample_count + negative_class_sample_count
+        )
+
+    return positive_weights
diff --git a/src/ptbench/models/normalizer.py b/src/ptbench/models/normalizer.py
index 2cc4b956f17ee1e690031f42e891ef726b548e2a..ce68f4b558b2812d17a8f54954e197a086233a52 100644
--- a/src/ptbench/models/normalizer.py
+++ b/src/ptbench/models/normalizer.py
@@ -8,6 +8,7 @@ import torch
 import torch.nn
 import torch.utils.data
 import torchvision.transforms
+import tqdm
 
 
 def make_z_normalizer(
@@ -42,7 +43,7 @@ def make_z_normalizer(
     num_images = 0
 
     # Evaluates mean and standard deviation
-    for batch in dataloader:
+    for batch in tqdm.tqdm(dataloader, unit="batch"):
         data = batch[0]
         data = data.view(data.size(0), data.size(1), -1)
 
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index e257a4fc54190c778fdaa6c8d36522fa713fbd76..34e5c67fbb01684a51fe1e14ddc963ac0dadb6fb 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -3,94 +3,139 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import logging
+import typing
 
 import lightning.pytorch as pl
 import torch
-import torch.nn as nn
+import torch.nn
 import torch.nn.functional as F
+import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
 
+from ..data.typing import TransformSequence
+
 logger = logging.getLogger(__name__)
 
 
-class PASA(pl.LightningModule):
-    """PASA module.
+class Pasa(pl.LightningModule):
+    """Implementation of CNN by Pasa.
+
+    Simple CNN for classification based on paper by [PASA-2019]_.
+
+
+    Parameters
+    ----------
+
+    train_loss
+        The loss to be used during the training.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    validation_loss
+        The loss to be used for validation (may be different from the training
+        loss).  If extra-validation sets are provided, the same loss will be
+        used throughout.
 
-    Based on paper by [PASA-2019]_.
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    optimizer_type
+        The type of optimizer to use for training
+
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
     """
 
     def __init__(
         self,
-        criterion,
-        criterion_valid,
-        optimizer,
-        optimizer_configs,
-        augmentation_transforms,
+        train_loss: torch.nn.Module,
+        validation_loss: torch.nn.Module | None,
+        optimizer_type: type[torch.optim.Optimizer],
+        optimizer_arguments: dict[str, typing.Any],
+        augmentation_transforms: TransformSequence = [],
     ):
         super().__init__()
 
         self.name = "pasa"
 
-        self.augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms
+        self._train_loss = train_loss
+        self._validation_loss = (
+            validation_loss if validation_loss is not None else train_loss
         )
+        self._optimizer_type = optimizer_type
+        self._optimizer_arguments = optimizer_arguments
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
-        self.optimizer = optimizer
-        self.optimizer_configs = optimizer_configs
-
-        self.normalizer = None
+        self._augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )
 
         # First convolution block
-        self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
-        self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
-        self.fc3 = nn.Conv2d(1, 16, (1, 1), (4, 4))
+        self.fc1 = torch.nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
+        self.fc2 = torch.nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
+        self.fc3 = torch.nn.Conv2d(1, 16, (1, 1), (4, 4))
 
-        self.batchNorm2d_4 = nn.BatchNorm2d(4)
-        self.batchNorm2d_16 = nn.BatchNorm2d(16)
-        self.batchNorm2d_16_2 = nn.BatchNorm2d(16)
+        self.batchNorm2d_4 = torch.nn.BatchNorm2d(4)
+        self.batchNorm2d_16 = torch.nn.BatchNorm2d(16)
+        self.batchNorm2d_16_2 = torch.nn.BatchNorm2d(16)
 
         # Second convolution block
-        self.fc4 = nn.Conv2d(16, 24, (3, 3), (1, 1), (1, 1))
-        self.fc5 = nn.Conv2d(24, 32, (3, 3), (1, 1), (1, 1))
-        self.fc6 = nn.Conv2d(16, 32, (1, 1), (1, 1))  # Original stride (2, 2)
+        self.fc4 = torch.nn.Conv2d(16, 24, (3, 3), (1, 1), (1, 1))
+        self.fc5 = torch.nn.Conv2d(24, 32, (3, 3), (1, 1), (1, 1))
+        self.fc6 = torch.nn.Conv2d(
+            16, 32, (1, 1), (1, 1)
+        )  # Original stride (2, 2)
 
-        self.batchNorm2d_24 = nn.BatchNorm2d(24)
-        self.batchNorm2d_32 = nn.BatchNorm2d(32)
-        self.batchNorm2d_32_2 = nn.BatchNorm2d(32)
+        self.batchNorm2d_24 = torch.nn.BatchNorm2d(24)
+        self.batchNorm2d_32 = torch.nn.BatchNorm2d(32)
+        self.batchNorm2d_32_2 = torch.nn.BatchNorm2d(32)
 
         # Third convolution block
-        self.fc7 = nn.Conv2d(32, 40, (3, 3), (1, 1), (1, 1))
-        self.fc8 = nn.Conv2d(40, 48, (3, 3), (1, 1), (1, 1))
-        self.fc9 = nn.Conv2d(32, 48, (1, 1), (1, 1))  # Original stride (2, 2)
+        self.fc7 = torch.nn.Conv2d(32, 40, (3, 3), (1, 1), (1, 1))
+        self.fc8 = torch.nn.Conv2d(40, 48, (3, 3), (1, 1), (1, 1))
+        self.fc9 = torch.nn.Conv2d(
+            32, 48, (1, 1), (1, 1)
+        )  # Original stride (2, 2)
 
-        self.batchNorm2d_40 = nn.BatchNorm2d(40)
-        self.batchNorm2d_48 = nn.BatchNorm2d(48)
-        self.batchNorm2d_48_2 = nn.BatchNorm2d(48)
+        self.batchNorm2d_40 = torch.nn.BatchNorm2d(40)
+        self.batchNorm2d_48 = torch.nn.BatchNorm2d(48)
+        self.batchNorm2d_48_2 = torch.nn.BatchNorm2d(48)
 
         # Fourth convolution block
-        self.fc10 = nn.Conv2d(48, 56, (3, 3), (1, 1), (1, 1))
-        self.fc11 = nn.Conv2d(56, 64, (3, 3), (1, 1), (1, 1))
-        self.fc12 = nn.Conv2d(48, 64, (1, 1), (1, 1))  # Original stride (2, 2)
+        self.fc10 = torch.nn.Conv2d(48, 56, (3, 3), (1, 1), (1, 1))
+        self.fc11 = torch.nn.Conv2d(56, 64, (3, 3), (1, 1), (1, 1))
+        self.fc12 = torch.nn.Conv2d(
+            48, 64, (1, 1), (1, 1)
+        )  # Original stride (2, 2)
 
-        self.batchNorm2d_56 = nn.BatchNorm2d(56)
-        self.batchNorm2d_64 = nn.BatchNorm2d(64)
-        self.batchNorm2d_64_2 = nn.BatchNorm2d(64)
+        self.batchNorm2d_56 = torch.nn.BatchNorm2d(56)
+        self.batchNorm2d_64 = torch.nn.BatchNorm2d(64)
+        self.batchNorm2d_64_2 = torch.nn.BatchNorm2d(64)
 
         # Fifth convolution block
-        self.fc13 = nn.Conv2d(64, 72, (3, 3), (1, 1), (1, 1))
-        self.fc14 = nn.Conv2d(72, 80, (3, 3), (1, 1), (1, 1))
-        self.fc15 = nn.Conv2d(64, 80, (1, 1), (1, 1))  # Original stride (2, 2)
+        self.fc13 = torch.nn.Conv2d(64, 72, (3, 3), (1, 1), (1, 1))
+        self.fc14 = torch.nn.Conv2d(72, 80, (3, 3), (1, 1), (1, 1))
+        self.fc15 = torch.nn.Conv2d(
+            64, 80, (1, 1), (1, 1)
+        )  # Original stride (2, 2)
 
-        self.batchNorm2d_72 = nn.BatchNorm2d(72)
-        self.batchNorm2d_80 = nn.BatchNorm2d(80)
-        self.batchNorm2d_80_2 = nn.BatchNorm2d(80)
+        self.batchNorm2d_72 = torch.nn.BatchNorm2d(72)
+        self.batchNorm2d_80 = torch.nn.BatchNorm2d(80)
+        self.batchNorm2d_80_2 = torch.nn.BatchNorm2d(80)
 
-        self.pool2d = nn.MaxPool2d((3, 3), (2, 2))  # Pool after conv. block
-        self.dense = nn.Linear(80, 1)  # Fully connected layer
+        self.pool2d = torch.nn.MaxPool2d(
+            (3, 3), (2, 2)
+        )  # Pool after conv. block
+        self.dense = torch.nn.Linear(80, 1)  # Fully connected layer
 
     def forward(self, x):
         x = self.normalizer(x)  # type: ignore
@@ -141,51 +186,22 @@ class PASA(pl.LightningModule):
         return x
 
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
-        """Initializes the normalizer for the current model.
+        """Initializes the input normalizer for the current model.
 
         Parameters
         ----------
 
-        dataloader: :py:class:`torch.utils.data.DataLoader`
+        dataloader
             A torch Dataloader from which to compute the mean and std
         """
         from .normalizer import make_z_normalizer
 
         logger.info(
-            "Uninitialised densenet model - "
-            "computing z-norm factors from training data."
+            f"Uninitialised {self.name} model - "
+            f"computing z-norm factors from train dataloader."
         )
         self.normalizer = make_z_normalizer(dataloader)
 
-    def set_bce_loss_weights(self, datamodule):
-        """Reweights loss weights if BCEWithLogitsLoss is used.
-
-        Parameters
-        ----------
-
-        datamodule:
-            A datamodule implementing train_dataloader() and val_dataloader()
-        """
-        from ..data.dataset import _get_positive_weights
-
-        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training criterion.")
-            train_positive_weights = _get_positive_weights(
-                datamodule.train_dataloader()
-            )
-            self.criterion = torch.nn.BCEWithLogitsLoss(
-                pos_weight=train_positive_weights
-            )
-
-        if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
-            validation_positive_weights = _get_positive_weights(
-                datamodule.val_dataloader()["validation"]
-            )
-            self.criterion_valid = torch.nn.BCEWithLogitsLoss(
-                pos_weight=validation_positive_weights
-            )
-
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
@@ -197,15 +213,13 @@ class PASA(pl.LightningModule):
 
         # Forward pass on the network
         augmented_images = [
-            self.augmentation_transforms(img).to(self.device) for img in images
+            self._augmentation_transforms(img).to(self.device) for img in images
         ]
         # Combine list of augmented images back into a tensor
         augmented_images = torch.cat(augmented_images, 0).view(images.shape)
         outputs = self(augmented_images)
 
-        training_loss = self.criterion(outputs, labels.double())
-
-        return {"loss": training_loss}
+        return self._train_loss(outputs, labels.float())
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[0]
@@ -219,12 +233,7 @@ class PASA(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
 
-        validation_loss = self.criterion_valid(outputs, labels.double())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-        else:
-            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
+        return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         images = batch[0]
@@ -234,40 +243,13 @@ class PASA(pl.LightningModule):
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
 
-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(outputs, list):
-            outputs = outputs[-1]
-
-        results = (
+        return (
             names[0],
             torch.flatten(probabilities),
             torch.flatten(labels),
         )
 
-        return results
-        # {
-        # f"dataloader_{dataloader_idx}_predictions": (
-        #    names[0],
-        #    torch.flatten(probabilities),
-        #    torch.flatten(labels),
-        # )
-        # }
-
-    # def on_predict_epoch_end(self):
-
-    #    retval = defaultdict(list)
-
-    #    for dataloader_name, predictions in self.predictions_cache.items():
-    #        for prediction in predictions:
-    #            retval[dataloader_name]["name"].append(prediction[0])
-    #            retval[dataloader_name]["prediction"].append(prediction[1])
-    #            retval[dataloader_name]["label"].append(prediction[2])
-
-    # Need to cache predictions in the predict step, then reorder by key
-    # Clear prediction dict
-    # raise NotImplementedError
-
     def configure_optimizers(self):
-        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
-        return optimizer
+        return self._optimizer_type(
+            self.parameters(), **self._optimizer_arguments
+        )
diff --git a/src/ptbench/utils/save_sh_command.py b/src/ptbench/utils/save_sh_command.py
deleted file mode 100644
index e0a7d379c00caddb7ade2513668a1392e98b21f2..0000000000000000000000000000000000000000
--- a/src/ptbench/utils/save_sh_command.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import glob
-import logging
-import os
-import sys
-import time
-
-import pkg_resources
-
-logger = logging.getLogger(__name__)
-
-
-def save_sh_command(output_dir):
-    """Records command-line to reproduce this experiment.
-
-    This function can record the current command-line used to call the script
-    being run.  It creates an executable ``bash`` script setting up the current
-    working directory and activating a conda environment, if needed.  It
-    records further information on the date and time the script was run and the
-    version of the package.
-
-
-    Parameters
-    ----------
-
-    output_folder : str
-        Path leading to the directory where the commands to reproduce the current
-        run will be recorded. A subdirectory will be created each time this function
-        is called to match lightning's versioning convention for loggers.
-    """
-
-    cmd_config_dir = os.path.join(output_dir, "cmd_line_configs")
-    cmd_config_versions = glob.glob(os.path.join(cmd_config_dir, "version_*"))
-    if len(cmd_config_versions) > 0:
-        latest_cmd_config_version = max(
-            [
-                int(config.split("version_")[-1])
-                for config in cmd_config_versions
-            ]
-        )
-        current_cmd_config_version = str(latest_cmd_config_version + 1)
-    else:
-        current_cmd_config_version = "0"
-
-    destfile = os.path.join(
-        cmd_config_dir,
-        f"version_{current_cmd_config_version}",
-        "cmd_line_config.txt",
-    )
-
-    if os.path.exists(destfile):
-        logger.info(f"Not overwriting existing file '{destfile}'")
-        return
-
-    logger.info(f"Writing command-line for reproduction at '{destfile}'...")
-    os.makedirs(os.path.dirname(destfile), exist_ok=True)
-
-    with open(destfile, "w") as f:
-        f.write("#!/usr/bin/env sh\n")
-        f.write(f"# date: {time.asctime()}\n")
-        version = pkg_resources.require("ptbench")[0].version
-        f.write(f"# version: {version} (deepdraw)\n")
-        f.write(f"# platform: {sys.platform}\n")
-        f.write("\n")
-        args = []
-        for k in sys.argv:
-            if " " in k:
-                args.append(f'"{k}"')
-            else:
-                args.append(k)
-        if os.environ.get("CONDA_DEFAULT_ENV") is not None:
-            f.write(f"#conda activate {os.environ['CONDA_DEFAULT_ENV']}\n")
-        f.write(f"#cd {os.path.realpath(os.curdir)}\n")
-        f.write(" ".join(args) + "\n")
-    os.chmod(destfile, 0o755)