diff --git a/src/mednet/libs/segmentation/config/models/lwnet.py b/src/mednet/libs/segmentation/config/models/lwnet.py
index 09af9f842c731bcaa23e284eb292d90134664883..b0428f4498f6da65ac6549e39c1fec21a7bbf265 100644
--- a/src/mednet/libs/segmentation/config/models/lwnet.py
+++ b/src/mednet/libs/segmentation/config/models/lwnet.py
@@ -18,8 +18,7 @@ min_lr = 1e-08  # valley
 cycle = 50  # epochs for a complete scheduling cycle
 
 model = LittleWNet(
-    train_loss=MultiWeightedBCELogitsLoss(),
-    validation_loss=MultiWeightedBCELogitsLoss(),
+    loss_type=MultiWeightedBCELogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=max_lr),
     augmentation_transforms=[],
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index d8e5b2cbd23119eb3d02cd5d2ac9e9342ac3ee88..ac23d617138e7f8f9d04849691b82832930ffd7b 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -17,11 +17,10 @@ Reference: [GALDRAN-2020]_
 
 import typing
 
-import lightning.pytorch as pl
 import torch
 import torch.nn
-import torchvision.transforms
 from mednet.libs.common.data.typing import TransformSequence
+from mednet.libs.common.models.model import Model
 from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
 from torchvision.transforms.v2 import CenterCrop
 
@@ -230,27 +229,20 @@ class LittleUNet(torch.nn.Module):
         return self.final(x)
 
 
-class LittleWNet(pl.LightningModule):
+class LittleWNet(Model):
     """Little W-Net model, concatenating two Little U-Net models.
 
     Parameters
     ----------
-    train_loss
-        The loss to be used during the training.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -266,32 +258,28 @@ class LittleWNet(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss=MultiWeightedBCELogitsLoss(),
-        validation_loss=MultiWeightedBCELogitsLoss(),
+        loss_type: type[torch.nn.Module] = MultiWeightedBCELogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
         crop_size: int = 544,
     ):
-        super().__init__()
+        super().__init__(
+            loss_type,
+            loss_arguments,
+            optimizer_type,
+            optimizer_arguments,
+            augmentation_transforms,
+            num_classes,
+        )
 
         self.name = "lwnet"
         self.num_classes = num_classes
 
         self.model_transforms = [CenterCrop(size=(crop_size, crop_size))]
 
-        self._train_loss = train_loss
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        )
-        self._optimizer_type = optimizer_type
-        self._optimizer_arguments = optimizer_arguments
-
-        self._augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms
-        )
-
         self.unet1 = LittleUNet(
             in_c=3,
             n_classes=self.num_classes,