From 5249d72415efa00aab70d1e318b950233b4e331d Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 2 Jul 2024 09:32:23 +0200
Subject: [PATCH] [*/models] Unify modelling by applying DRY; Fix a nasty
 loss-balancing bug; Make loss-configuration private to base model class;
 Automate loss-configuration; Fix typing across model submodules; Move `name`
 into parent type
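
The most important changes are:

- The logistic-regression and MLP classifiers now derive from
  `ClassificationModel`, and the U-Net model from `SegmentationModel`,
  instead of re-implementing the training, validation, prediction and
  optimizer-configuration steps on top of a raw LightningModule (DRY).

- The base `Model` class used to store the *same* `loss_arguments`
  dictionary for the training and the validation losses, so balancing
  one loss silently changed the arguments of the other.  Both
  attributes are now deep copies of the input dictionary.

- `configure_losses()` becomes the private `_configure_losses()`; it
  is called automatically from the constructor and called again at the
  end of `balance_losses()`, so the trainer no longer needs to invoke
  it explicitly.

- Loss types are consistently annotated as `type[torch.nn.Module]`,
  and scheduler types as
  `type[torch.optim.lr_scheduler.LRScheduler] | None`.

- The model `name` is now passed to (and stored by) the parent
  constructor, instead of being assigned on each subclass after
  `super().__init__()`.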

---
 .../libs/classification/models/alexnet.py     |  26 +++--
 .../models/classification_model.py            |   4 +
 .../libs/classification/models/cnn3d.py       |   1 +
 .../libs/classification/models/densenet.py    |  26 +++--
 .../models/logistic_regression.py             | 102 +++++-------------
 src/mednet/libs/classification/models/mlp.py  | 101 +++++------------
 src/mednet/libs/classification/models/pasa.py |  26 +++--
 src/mednet/libs/common/engine/trainer.py      |   2 -
 src/mednet/libs/common/models/model.py        |  53 +++++----
 src/mednet/libs/segmentation/models/driu.py   |  27 ++---
 .../libs/segmentation/models/driu_bn.py       |  27 ++---
 .../libs/segmentation/models/driu_od.py       |  26 ++---
 .../libs/segmentation/models/driu_pix.py      |  27 ++---
 src/mednet/libs/segmentation/models/hed.py    |  28 ++---
 src/mednet/libs/segmentation/models/lwnet.py  |  26 +++--
 src/mednet/libs/segmentation/models/m2unet.py |  26 ++---
 .../segmentation/models/segmentation_model.py |   4 +
 src/mednet/libs/segmentation/models/unet.py   |  51 +++------
 18 files changed, 240 insertions(+), 343 deletions(-)
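
Reviewer note: the loss-balancing bug boils down to dictionary
aliasing.  A minimal, self-contained sketch of the problem and of the
fix (variable names are illustrative, not taken from the code base):

    import copy

    loss_arguments = {"reduction": "mean"}

    # Before: both attributes referenced the *same* dictionary, so
    # balancing the training loss also rewrote the validation-loss
    # arguments behind the caller's back.
    train_args = loss_arguments
    validation_args = loss_arguments
    train_args["pos_weight"] = 2.0
    assert "pos_weight" in validation_args  # validation contaminated

    # After: deep copies keep the two configurations independent.
    loss_arguments = {"reduction": "mean"}
    train_args = copy.deepcopy(loss_arguments)
    validation_args = copy.deepcopy(loss_arguments)
    train_args["pos_weight"] = 2.0
    assert "pos_weight" not in validation_args

With this patch applied, losses are instantiated directly by the model
constructor, and no explicit `configure_losses()` call is required
anymore.  A hypothetical construction example (argument values are for
illustration only):

    import torch
    from mednet.libs.classification.models.pasa import Pasa

    model = Pasa(
        loss_type=torch.nn.BCEWithLogitsLoss,
        loss_arguments={},  # deep-copied for train and validation
        optimizer_type=torch.optim.Adam,
        optimizer_arguments={"lr": 1e-4},  # illustrative value
    )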

diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py
index b400a728..320e850c 100644
--- a/src/mednet/libs/classification/models/alexnet.py
+++ b/src/mednet/libs/classification/models/alexnet.py
@@ -56,11 +56,11 @@ class Alexnet(ClassificationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -68,20 +68,18 @@ class Alexnet(ClassificationModel):
         num_classes: int = 1,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="alexnet",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "alexnet"
-        self.num_classes = num_classes
-
         self.pretrained = pretrained
 
         # Load pretrained model
diff --git a/src/mednet/libs/classification/models/classification_model.py b/src/mednet/libs/classification/models/classification_model.py
index e2d4b3be..51f6fa18 100644
--- a/src/mednet/libs/classification/models/classification_model.py
+++ b/src/mednet/libs/classification/models/classification_model.py
@@ -20,6 +20,8 @@ class ClassificationModel(Model):
 
     Parameters
     ----------
+    name
+        Common name to give to models of this type.
     loss_type
         The loss to be used for training and evaluation.
 
@@ -49,6 +51,7 @@ class ClassificationModel(Model):
 
     def __init__(
         self,
+        name: str,
         loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
@@ -60,6 +63,7 @@ class ClassificationModel(Model):
         num_classes: int = 1,
     ):
         super().__init__(
+            name,
             loss_type,
             loss_arguments,
             optimizer_type,
diff --git a/src/mednet/libs/classification/models/cnn3d.py b/src/mednet/libs/classification/models/cnn3d.py
index 9a77092d..64bf4893 100644
--- a/src/mednet/libs/classification/models/cnn3d.py
+++ b/src/mednet/libs/classification/models/cnn3d.py
@@ -65,6 +65,7 @@ class Conv3DNet(ClassificationModel):
         num_classes: int = 1,
     ):
         super().__init__(
+            name="cnn3d",
             loss_type=loss_type,
             loss_arguments=loss_arguments,
             optimizer_type=optimizer_type,
diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py
index e58446d1..e2e92793 100644
--- a/src/mednet/libs/classification/models/densenet.py
+++ b/src/mednet/libs/classification/models/densenet.py
@@ -56,11 +56,11 @@ class Densenet(ClassificationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -69,20 +69,18 @@ class Densenet(ClassificationModel):
         num_classes: int = 1,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="densenet-121",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "densenet-121"
-        self.num_classes = num_classes
-
         self.pretrained = pretrained
 
         # Load pretrained model
diff --git a/src/mednet/libs/classification/models/logistic_regression.py b/src/mednet/libs/classification/models/logistic_regression.py
index f9fe1847..d95bce06 100644
--- a/src/mednet/libs/classification/models/logistic_regression.py
+++ b/src/mednet/libs/classification/models/logistic_regression.py
@@ -4,109 +4,59 @@
 
 import typing
 
-import lightning.pytorch as pl
 import torch
-import torch.nn as nn
+import torch.nn
 
+from .classification_model import ClassificationModel
 
-class LogisticRegression(pl.LightningModule):
+
+class LogisticRegression(ClassificationModel):
     """Logistic regression classifier with a single output.
 
     Parameters
     ----------
-    train_loss
-        The loss to be used during the training.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    num_classes
+        Number of outputs (classes) for this model.
     input_size
         The number of inputs this classifier shall process.
     """
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {"lr": 1e-2},
+        num_classes: int = 1,
         input_size: int = 14,
     ):
-        super().__init__()
-
-        self._train_loss = train_loss.to(self.device)
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        ).to(self.device)
-        self._optimizer_type = optimizer_type
-        self._optimizer_arguments = optimizer_arguments
-
-        self.name = "logistic-regression"
+        super().__init__(
+            name="logistic-regression",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=None,
+            scheduler_arguments={},
+            model_transforms=[],
+            augmentation_transforms=[],
+            num_classes=num_classes,
+        )
 
-        self.linear = nn.Linear(input_size, 1)
+        self.linear = torch.nn.Linear(input_size, self.num_classes)
 
     def forward(self, x):
         return self.linear(self.normalizer(x))
-
-    def training_step(self, batch, batch_idx):
-        _input = batch[1]
-        labels = batch[2]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        output = self(_input)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self._train_loss = self._train_loss.to(self.device)
-        training_loss = self._train_loss(output, labels.float())
-
-        return {"loss": training_loss}
-
-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        _input = batch[1]
-        labels = batch[2]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # data forwarding on the existing network
-        output = self(_input)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self._validation_loss = self._validation_loss.to(self.device)
-        validation_loss = self._validation_loss(output, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-
-        return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        return torch.sigmoid(outputs)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(
-            self.parameters(),
-            **self._optimizer_arguments,
-        )
diff --git a/src/mednet/libs/classification/models/mlp.py b/src/mednet/libs/classification/models/mlp.py
index bd928410..b35ca103 100644
--- a/src/mednet/libs/classification/models/mlp.py
+++ b/src/mednet/libs/classification/models/mlp.py
@@ -4,35 +4,32 @@
 
 import typing
 
-import lightning.pytorch as pl
 import torch
+import torch.nn
 
+from .classification_model import ClassificationModel
 
-class MultiLayerPerceptron(pl.LightningModule):
+
+class MultiLayerPerceptron(ClassificationModel):
     """MLP with a variable number of inputs and hidden neurons (single layer).
 
     Parameters
     ----------
-    train_loss
-        The loss to be used during the training.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    num_classes
+        Number of outputs (classes) for this model.
     input_size
         The number of inputs this classifier shall process.
     hidden_size
@@ -41,76 +38,30 @@ class MultiLayerPerceptron(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {"lr": 1e-2},
+        num_classes: int = 1,
         input_size: int = 14,
         hidden_size: int = 10,
     ):
-        super().__init__()
-
-        self._train_loss = train_loss.to(self.device)
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        ).to(self.device)
-        self._optimizer_type = optimizer_type
-        self._optimizer_arguments = optimizer_arguments
-
-        self.name = "mlp"
+        super().__init__(
+            name="mlp",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=None,
+            scheduler_arguments={},
+            model_transforms=[],
+            augmentation_transforms=[],
+            num_classes=num_classes,
+        )
 
         self.fc1 = torch.nn.Linear(input_size, hidden_size)
         self.relu = torch.nn.ReLU()
-        self.fc2 = torch.nn.Linear(hidden_size, 1)
+        self.fc2 = torch.nn.Linear(hidden_size, self.num_classes)
 
     def forward(self, x):
         return self.fc2(self.relu(self.fc1(self.normalizer(x))))
-
-    def training_step(self, batch, batch_idx):
-        _input = batch[1]
-        labels = batch[2]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        output = self(_input)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self._train_loss = self._train_loss.to(self.device)
-        training_loss = self._train_loss(output, labels.float())
-
-        return {"loss": training_loss}
-
-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        _input = batch[1]
-        labels = batch[2]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # data forwarding on the existing network
-        output = self(_input)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self._validation_loss = self._validation_loss.to(self.device)
-        validation_loss = self._validation_loss(output, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-
-        return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        return torch.sigmoid(outputs)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(
-            self.parameters(),
-            **self._optimizer_arguments,
-        )
diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py
index 8eea2c38..6f0c7c75 100644
--- a/src/mednet/libs/classification/models/pasa.py
+++ b/src/mednet/libs/classification/models/pasa.py
@@ -56,31 +56,29 @@ class Pasa(ClassificationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="pasa",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "pasa"
-        self.num_classes = num_classes
-
         # First convolution block
         self.fc1 = torch.nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
         self.fc2 = torch.nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
diff --git a/src/mednet/libs/common/engine/trainer.py b/src/mednet/libs/common/engine/trainer.py
index 5d3f2a0d..63d916bc 100644
--- a/src/mednet/libs/common/engine/trainer.py
+++ b/src/mednet/libs/common/engine/trainer.py
@@ -73,8 +73,6 @@ def run(
 
     output_folder.mkdir(parents=True, exist_ok=True)
 
-    model.configure_losses()
-
     from .loggers import CustomTensorboardLogger
 
     log_dir = "logs"
diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py
index 669ba6c1..cad52097 100644
--- a/src/mednet/libs/common/models/model.py
+++ b/src/mednet/libs/common/models/model.py
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import copy
 import logging
 import typing
 
@@ -25,6 +26,8 @@ class Model(pl.LightningModule):
 
     Parameters
     ----------
+    name
+        Common name to give to models of this type.
     loss_type
         The loss to be used for training and evaluation.
 
@@ -54,6 +57,7 @@ class Model(pl.LightningModule):
 
     def __init__(
         self,
+        name: str,
         loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
@@ -66,25 +70,21 @@ class Model(pl.LightningModule):
     ):
         super().__init__()
 
-        self.name = "model"
+        self.name = name
         self.num_classes = num_classes
-
         self.model_transforms = model_transforms
-
         self._loss_type = loss_type
-
-        self._train_loss_arguments = loss_arguments
-
-        self._validation_loss_arguments = loss_arguments
-
+        self._train_loss_arguments = copy.deepcopy(loss_arguments)
+        self._validation_loss_arguments = copy.deepcopy(loss_arguments)
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments
-
         self._scheduler_type = scheduler_type
         self._scheduler_arguments = scheduler_arguments
-
         self.augmentation_transforms = augmentation_transforms
 
+        # initializes losses from input arguments
+        self._configure_losses()
+
     @property
     def augmentation_transforms(self) -> torchvision.transforms.Compose:
         return self._augmentation_transforms
@@ -149,8 +149,14 @@ class Model(pl.LightningModule):
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         raise NotImplementedError
 
-    def configure_losses(self):
+    def _configure_losses(self):
+        """Create loss objects for train and validation."""
+
+        logger.info(f"Configuring train loss ({self._train_loss_arguments})...")
         self._train_loss = self._loss_type(**self._train_loss_arguments)
+        logger.info(
+            f"Configuring validation loss ({self._validation_loss_arguments})..."
+        )
         self._validation_loss = self._loss_type(**self._validation_loss_arguments)
 
     def configure_optimizers(self):
@@ -195,10 +201,8 @@ class Model(pl.LightningModule):
             ]
         )
 
-        if self._train_loss is not None:
-            self._train_loss.to(*args, **kwargs)
-        if self._validation_loss is not None:
-            self._validation_loss.to(*args, **kwargs)
+        self._train_loss.to(*args, **kwargs)
+        self._validation_loss.to(*args, **kwargs)
 
         return self
 
@@ -213,23 +217,36 @@ class Model(pl.LightningModule):
 
         try:
             getattr(self._loss_type(), "pos_weight")
+
         except AttributeError:
             logger.warning(
                 f"Loss {self._loss_type} does not possess a 'pos_weight' attribute and will not be balanced."
             )
+
         else:
-            logger.info(f"Balancing training loss {self._loss_type}.")
             train_weights = get_positive_weights(datamodule.train_dataloader())
             self._train_loss_arguments["pos_weight"] = train_weights
+            logger.info(
+                f"Balanced training loss {self._loss_type}: "
+                f"`pos_weight={train_weights.item():.3f}`."
+            )
 
-            logger.info(f"Balancing validation loss {self._loss_type}.")
             if "validation" in datamodule.val_dataloader().keys():
                 validation_weights = get_positive_weights(
                     datamodule.val_dataloader()["validation"]
                 )
             else:
                 logger.warning(
-                    "Datamodule does not contain a validation dataloader. The training dataloader will be used instead."
+                    "Datamodule does not contain a validation dataloader. "
+                    "The training dataloader will be used instead."
                 )
                 validation_weights = get_positive_weights(datamodule.train_dataloader())
+
             self._validation_loss_arguments["pos_weight"] = validation_weights
+            logger.info(
+                f"Balanced validation loss {self._loss_type}: "
+                f"`pos_weight={validation_weights.item():.3f}`."
+            )
+
+        # re-instantiates losses for train and validation
+        self._configure_losses()
diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index 295713cb..40d09709 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -8,6 +8,7 @@ import typing
 
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 
 from .backbones.vgg import vgg16_for_segmentation
@@ -43,7 +44,7 @@ class DRIUHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """
 
-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_conv_1_2_16,
@@ -106,11 +107,11 @@ class DRIU(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -118,17 +119,17 @@ class DRIU(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="driu",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "driu"
 
         self.pretrained = pretrained
 
diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py
index bacbf291..9e32925e 100644
--- a/src/mednet/libs/segmentation/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/models/driu_bn.py
@@ -8,6 +8,7 @@ import typing
 
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 
 from .backbones.vgg import vgg16_for_segmentation
@@ -46,7 +47,7 @@ class DRIUBNHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """
 
-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_conv_1_2_16,
@@ -109,11 +110,11 @@ class DRIUBN(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -121,17 +122,17 @@ class DRIUBN(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="driu-bn",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "driu-bn"
 
         self.pretrained = pretrained
 
diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py
index af4727c5..2336869e 100644
--- a/src/mednet/libs/segmentation/models/driu_od.py
+++ b/src/mednet/libs/segmentation/models/driu_od.py
@@ -28,7 +28,7 @@ class DRIUODHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """
 
-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_upsample2,
@@ -91,11 +91,11 @@ class DRIUOD(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -103,17 +103,17 @@ class DRIUOD(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="driu-od",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "driu-od"
 
         self.pretrained = pretrained
 
diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py
index e9800984..42436fcc 100644
--- a/src/mednet/libs/segmentation/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/models/driu_pix.py
@@ -8,6 +8,7 @@ import typing
 
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 
 from .backbones.vgg import vgg16_for_segmentation
@@ -28,7 +29,7 @@ class DRIUPIXHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """
 
-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_conv_1_2_16,
@@ -95,11 +96,11 @@ class DRIUPix(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -107,17 +108,17 @@ class DRIUPix(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="driu-pix",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "driu-pix"
 
         self.pretrained = pretrained
 
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
index 12a7c591..cd99707d 100644
--- a/src/mednet/libs/segmentation/models/hed.py
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -7,6 +7,7 @@ import typing
 
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 
 from .backbones.vgg import vgg16_for_segmentation
@@ -40,7 +41,7 @@ class HEDHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """
 
-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_conv_1_2_16,
@@ -109,11 +110,11 @@ class HED(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = MultiSoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = MultiSoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -121,19 +122,18 @@ class HED(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="hed",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "hed"
-
         self.pretrained = pretrained
 
         self.backbone = vgg16_for_segmentation(
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index 9442c2c9..5cd491b0 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -310,31 +310,29 @@ class LittleWNet(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = MultiWeightedBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = MultiWeightedBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="lwnet",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "lwnet"
-        self.num_classes = num_classes
-
         self.unet1 = LittleUNet(
             in_c=3,
             n_classes=self.num_classes,
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
index fcca7d60..87359565 100644
--- a/src/mednet/libs/segmentation/models/m2unet.py
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -7,6 +7,7 @@ import typing
 
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
 from torchvision.models.mobilenetv2 import InvertedResidual
@@ -157,11 +158,11 @@ class M2UNET(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -169,19 +170,18 @@ class M2UNET(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="m2unet",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "m2unet"
-
         self.pretrained = pretrained
 
         self.backbone = mobilenet_v2_for_segmentation(
diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py
index 8ffb75ac..dae5cde6 100644
--- a/src/mednet/libs/segmentation/models/segmentation_model.py
+++ b/src/mednet/libs/segmentation/models/segmentation_model.py
@@ -21,6 +21,8 @@ class SegmentationModel(Model):
 
     Parameters
     ----------
+    name
+        Common name to give to models of this type.
     loss_type
         The loss to be used for training and evaluation.
 
@@ -50,6 +52,7 @@ class SegmentationModel(Model):
 
     def __init__(
         self,
+        name: str,
         loss_type: type[torch.nn.Module] = MultiWeightedBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
@@ -61,6 +64,7 @@ class SegmentationModel(Model):
         num_classes: int = 1,
     ):
         super().__init__(
+            name,
             loss_type,
             loss_arguments,
             optimizer_type,
diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py
index cb1e34cf..f4d64f01 100644
--- a/src/mednet/libs/segmentation/models/unet.py
+++ b/src/mednet/libs/segmentation/models/unet.py
@@ -7,6 +7,7 @@ import logging
 import typing
 
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 
 from .backbones.vgg import vgg16_for_segmentation
@@ -28,7 +29,7 @@ class UNetHead(torch.nn.Module):
         If True, upsample using PixelShuffleICNR.
     """
 
-    def __init__(self, in_channels_list: list[int] = None, pixel_shuffle=False):
+    def __init__(self, in_channels_list: list[int], pixel_shuffle=False):
         super().__init__()
         # number of channels
         c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list
@@ -98,11 +99,11 @@ class Unet(SegmentationModel):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -110,19 +111,18 @@ class Unet(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="unet",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "unet"
-
         self.pretrained = pretrained
 
         self.backbone = vgg16_for_segmentation(
@@ -160,26 +160,3 @@ class Unet(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
-- 
GitLab