diff --git a/src/mednet/libs/classification/config/models/alexnet.py b/src/mednet/libs/classification/config/models/alexnet.py
index 702f9ed0de79fbdd18b54ca8ea507fbda6cadf63..02478ed228d324ae976436dd8ab62e4f81559929 100644
--- a/src/mednet/libs/classification/config/models/alexnet.py
+++ b/src/mednet/libs/classification/config/models/alexnet.py
@@ -18,6 +18,8 @@ model = Alexnet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=False,
     model_transforms=[
         SquareCenterPad(),
diff --git a/src/mednet/libs/classification/config/models/alexnet_pretrained.py b/src/mednet/libs/classification/config/models/alexnet_pretrained.py
index 8416db28deba66f6df16668297670e1232a6c540..21891149277fc688fde2b0c0fac60de4c5e3a096 100644
--- a/src/mednet/libs/classification/config/models/alexnet_pretrained.py
+++ b/src/mednet/libs/classification/config/models/alexnet_pretrained.py
@@ -20,6 +20,8 @@ model = Alexnet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=True,
     model_transforms=[
         SquareCenterPad(),
diff --git a/src/mednet/libs/classification/config/models/densenet.py b/src/mednet/libs/classification/config/models/densenet.py
index 65f7e90a0f06e41cf3efe4b6e0de035fed64fc1a..9ef4ff95437d53e074e72b7107ffebf737f50f68 100644
--- a/src/mednet/libs/classification/config/models/densenet.py
+++ b/src/mednet/libs/classification/config/models/densenet.py
@@ -16,6 +16,8 @@ model = Densenet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=False,
     dropout=0.1,
     model_transforms=[],
diff --git a/src/mednet/libs/classification/config/models/densenet_pretrained.py b/src/mednet/libs/classification/config/models/densenet_pretrained.py
index 8d7103aeeb92275e828a341faf3346ee68e69bf4..0f01a23f023d6b2e1ba04194f34bfc417fb91100 100644
--- a/src/mednet/libs/classification/config/models/densenet_pretrained.py
+++ b/src/mednet/libs/classification/config/models/densenet_pretrained.py
@@ -20,6 +20,8 @@ model = Densenet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=True,
     dropout=0.1,
     model_transforms=[
diff --git a/src/mednet/libs/classification/config/models/densenet_rs.py b/src/mednet/libs/classification/config/models/densenet_rs.py
index 9d3008755da6f9ce02be29cdf91adfdbb28cdfc2..48847e0971c70e28a5478136cadcae3ed210cc0e 100644
--- a/src/mednet/libs/classification/config/models/densenet_rs.py
+++ b/src/mednet/libs/classification/config/models/densenet_rs.py
@@ -17,6 +17,8 @@ model = Densenet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=False,
     dropout=0.1,
     num_classes=14,  # number of classes in NIH CXR-14
diff --git a/src/mednet/libs/classification/config/models/pasa.py b/src/mednet/libs/classification/config/models/pasa.py
index 637df9a48a954d2c913d369cd495d4b349311fe2..e8245f5b42963ef63a959b053ebdb385d378906b 100644
--- a/src/mednet/libs/classification/config/models/pasa.py
+++ b/src/mednet/libs/classification/config/models/pasa.py
@@ -20,6 +20,8 @@ model = Pasa(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=8e-5),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         Grayscale(),
         SquareCenterPad(),
diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py
index f930b625e96361182b8b9240e2fd2c82194328f5..b400a728d2219f19c8148bff1f39d1de88117833 100644
--- a/src/mednet/libs/classification/models/alexnet.py
+++ b/src/mednet/libs/classification/models/alexnet.py
@@ -37,6 +37,10 @@ class Alexnet(ClassificationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -56,6 +60,8 @@ class Alexnet(ClassificationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         pretrained: bool = False,
@@ -66,6 +72,8 @@ class Alexnet(ClassificationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/classification/models/classification_model.py b/src/mednet/libs/classification/models/classification_model.py
index e9654b480b67c7dc8b2ec249ed1921822383d33b..b125dec9d66f9881b7d43e4d0d6036197efc950a 100644
--- a/src/mednet/libs/classification/models/classification_model.py
+++ b/src/mednet/libs/classification/models/classification_model.py
@@ -33,6 +33,10 @@ class ClassificationModel(Model):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -49,6 +53,8 @@ class ClassificationModel(Model):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -58,6 +64,8 @@ class ClassificationModel(Model):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py
index c40cc957c5ca8f284b6c3f4832bb1b7d54052bda..e58446d1e04794fa4ad329050176385f21a3e798 100644
--- a/src/mednet/libs/classification/models/densenet.py
+++ b/src/mednet/libs/classification/models/densenet.py
@@ -35,6 +35,10 @@ class Densenet(ClassificationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -56,6 +60,8 @@ class Densenet(ClassificationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         pretrained: bool = False,
@@ -67,6 +73,8 @@ class Densenet(ClassificationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py
index b17af1724e97c18e6b58b1c25f9db3b3e8c167cb..8eea2c38e3bff899ac9686b6aba695df04c687b9 100644
--- a/src/mednet/libs/classification/models/pasa.py
+++ b/src/mednet/libs/classification/models/pasa.py
@@ -40,6 +40,10 @@ class Pasa(ClassificationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -56,6 +60,8 @@ class Pasa(ClassificationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -65,6 +71,8 @@ class Pasa(ClassificationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/common/engine/callbacks.py b/src/mednet/libs/common/engine/callbacks.py
index b030c8820f6a7ec748733ab9ba692f35f48f7d8f..4637b3af89ed53cf9810d019784b669bc0bc8d82 100644
--- a/src/mednet/libs/common/engine/callbacks.py
+++ b/src/mednet/libs/common/engine/callbacks.py
@@ -141,7 +141,9 @@ class LoggingCallback(lightning.pytorch.Callback):
         epoch_time = time.time() - self._start_training_epoch_time
 
         self._to_log["epoch-duration-seconds/train"] = epoch_time
-        self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"]  # type: ignore
+        self._to_log["learning-rate"] = pl_module.optimizers().param_groups[  # type: ignore
+            0
+        ]["lr"]
 
         overall_cycle_time = time.time() - self._start_training_epoch_time
         self._to_log["cycle-time-seconds/train"] = overall_cycle_time
diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py
index cc9d5d84dee0a571d7809e4dfa1832f429c68de3..b59301e2b061785ae1c3d728588a767db213f06a 100644
--- a/src/mednet/libs/common/models/model.py
+++ b/src/mednet/libs/common/models/model.py
@@ -8,6 +8,7 @@ import typing
 import lightning.pytorch as pl
 import torch
 import torch.nn
+import torch.optim.lr_scheduler
 import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
@@ -37,6 +38,10 @@ class Model(pl.LightningModule):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -53,6 +58,8 @@ class Model(pl.LightningModule):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -75,6 +82,9 @@ class Model(pl.LightningModule):
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments
 
+        self._scheduler_type = scheduler_type
+        self._scheduler_arguments = scheduler_arguments
+
         self.augmentation_transforms = augmentation_transforms
 
     @property
@@ -146,11 +156,20 @@ class Model(pl.LightningModule):
         self._validation_loss = self._loss_type(**self._validation_loss_arguments)
 
     def configure_optimizers(self):
-        return self._optimizer_type(
+        optimizer = self._optimizer_type(
             self.parameters(),
             **self._optimizer_arguments,
         )
 
+        if self._scheduler_type is None:
+            return optimizer
+
+        scheduler = self._scheduler_type(
+            optimizer,
+            **self._scheduler_arguments,
+        )
+        return [optimizer], [scheduler]
+
     def to(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Self:
         """Move model, augmentations and losses to specified device.
 
diff --git a/src/mednet/libs/segmentation/config/models/driu.py b/src/mednet/libs/segmentation/config/models/driu.py
index 4f1be608a32e0463ef357f5f2c62dc6c1d8938e1..43bb6414750aece5bd7f5be6eb1a36602401af6a 100644
--- a/src/mednet/libs/segmentation/config/models/driu.py
+++ b/src/mednet/libs/segmentation/config/models/driu.py
@@ -41,6 +41,8 @@ model = DRIU(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/config/models/driu_bn.py b/src/mednet/libs/segmentation/config/models/driu_bn.py
index fc435cb94b3a4c9fd15a3636722bbd70f4f23e17..7c83b1eb5cb138f75fe0fb0e23660181fd3bd8cd 100644
--- a/src/mednet/libs/segmentation/config/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/config/models/driu_bn.py
@@ -41,6 +41,8 @@ model = DRIUBN(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/config/models/driu_od.py b/src/mednet/libs/segmentation/config/models/driu_od.py
index 0931882671709adcd2d04a2c5cbd80a01656c369..6a80dfa8187b3690d6ae35168a39490f16212472 100644
--- a/src/mednet/libs/segmentation/config/models/driu_od.py
+++ b/src/mednet/libs/segmentation/config/models/driu_od.py
@@ -41,6 +41,8 @@ model = DRIUOD(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/config/models/driu_pix.py b/src/mednet/libs/segmentation/config/models/driu_pix.py
index 72b2feeac53658bca9847eee149746051bf5d4c9..0562ceb17285e1a847c675ff2f6a7ac7a751eb00 100644
--- a/src/mednet/libs/segmentation/config/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/config/models/driu_pix.py
@@ -41,6 +41,8 @@ model = DRIUPix(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/config/models/hed.py b/src/mednet/libs/segmentation/config/models/hed.py
index b3ce43fe6dc299fe37a157e0c1d899d60d3c9882..528e817e97ccd6028676e63f0326b8328803db92 100644
--- a/src/mednet/libs/segmentation/config/models/hed.py
+++ b/src/mednet/libs/segmentation/config/models/hed.py
@@ -32,6 +32,8 @@ model = HED(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/config/models/lwnet.py b/src/mednet/libs/segmentation/config/models/lwnet.py
index 339fa22a1f2bde9b6439730af0b6f1c7a2cd192c..ecfef8905da42c968a8364a5ca953dadc3ed4180 100644
--- a/src/mednet/libs/segmentation/config/models/lwnet.py
+++ b/src/mednet/libs/segmentation/config/models/lwnet.py
@@ -13,6 +13,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
 from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
 from mednet.libs.segmentation.models.lwnet import LittleWNet
 from torch.optim import Adam
+from torch.optim.lr_scheduler import CosineAnnealingLR
 
 max_lr = 0.01  # start
 min_lr = 1e-08  # valley
@@ -24,6 +25,8 @@ model = LittleWNet(
     loss_type=MultiWeightedBCELogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=max_lr),
+    scheduler_type=CosineAnnealingLR,
+    scheduler_arguments=dict(T_max=cycle, eta_min=min_lr),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/config/models/m2unet.py b/src/mednet/libs/segmentation/config/models/m2unet.py
index 39389461d807da45af7eefd9cdbf3a081eb38db1..ccfae59b323b2e9c0042248c24c4029b6eb597d2 100644
--- a/src/mednet/libs/segmentation/config/models/m2unet.py
+++ b/src/mednet/libs/segmentation/config/models/m2unet.py
@@ -45,6 +45,8 @@ model = M2UNET(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/config/models/unet.py b/src/mednet/libs/segmentation/config/models/unet.py
index 0dc6c417e4af6e6906d74a41abda640d00b4fa74..31d79ec28f221ab48c5a55361d5f8eac29f66a64 100644
--- a/src/mednet/libs/segmentation/config/models/unet.py
+++ b/src/mednet/libs/segmentation/config/models/unet.py
@@ -43,6 +43,8 @@ model = Unet(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index e12aac71517ff42a6caf6b2ce052da2be8f32db0..295713cb5ec4c22ae4ab391a88d21af41cf6d592 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -88,6 +88,10 @@ class DRIU(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -106,6 +110,8 @@ class DRIU(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -116,6 +122,8 @@ class DRIU(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py
index 555d4ef1ffc03e6d88ba81c60326df0cfc745a74..bacbf2913f02b54f5b00b31e35600416e6a59e9b 100644
--- a/src/mednet/libs/segmentation/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/models/driu_bn.py
@@ -91,6 +91,10 @@ class DRIUBN(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -109,6 +113,8 @@ class DRIUBN(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -119,6 +125,8 @@ class DRIUBN(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py
index 3b0cff036c843835a54e566206c045569e4bcd9c..af4727c556b112020739a37337f724afe3ed1b24 100644
--- a/src/mednet/libs/segmentation/models/driu_od.py
+++ b/src/mednet/libs/segmentation/models/driu_od.py
@@ -73,6 +73,10 @@ class DRIUOD(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -91,6 +95,8 @@ class DRIUOD(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -101,6 +107,8 @@ class DRIUOD(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py
index 75adce0c91da0fd53ff9386b3ffa08bc187fa72f..e9800984a884106e36a2aa96238bad48b784dc49 100644
--- a/src/mednet/libs/segmentation/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/models/driu_pix.py
@@ -77,6 +77,10 @@ class DRIUPix(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -95,6 +99,8 @@ class DRIUPix(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -105,6 +111,8 @@ class DRIUPix(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
index 97a663c51cc3fa649f70f1c0902fad712babb916..12a7c591aec5d75fb971b9265a27ade8cf51806f 100644
--- a/src/mednet/libs/segmentation/models/hed.py
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -91,6 +91,10 @@ class HED(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -109,6 +113,8 @@ class HED(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -119,6 +125,8 @@ class HED(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index adb17512bcd3953efd65c6dc42b382d0b99f12ea..3a552c4901cef48eed003eeb4cf127a9455666eb 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -294,6 +294,10 @@ class LittleWNet(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -310,6 +314,8 @@ class LittleWNet(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -319,6 +325,8 @@ class LittleWNet(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
@@ -346,5 +354,4 @@ class LittleWNet(SegmentationModel):
         xn = self.normalizer(x)
         x1 = self.unet1(xn)
         x2 = self.unet2(torch.cat([xn, x1], dim=1))
-
         return x1, x2
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
index cf7ce5c436b8af5796f119c06940d8098c3d238b..fcca7d6072a88ad6bc6bccd4932eef2a1b0a153f 100644
--- a/src/mednet/libs/segmentation/models/m2unet.py
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -139,6 +139,10 @@ class M2UNET(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -157,6 +161,8 @@ class M2UNET(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -167,6 +173,8 @@ class M2UNET(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py
index 8d8aa3ca42757bc415fc9784d309119f6f4599c4..834af45f087e4cc27d9089fc67efe51a6041e5a5 100644
--- a/src/mednet/libs/segmentation/models/segmentation_model.py
+++ b/src/mednet/libs/segmentation/models/segmentation_model.py
@@ -34,6 +34,10 @@ class SegmentationModel(Model):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -50,6 +54,8 @@ class SegmentationModel(Model):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -59,6 +65,8 @@ class SegmentationModel(Model):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py
index 6d9a76da8aea735d7cee850955698d71392356e3..cb1e34cfa7c07fcbafa43a59269412703f4be3ef 100644
--- a/src/mednet/libs/segmentation/models/unet.py
+++ b/src/mednet/libs/segmentation/models/unet.py
@@ -80,6 +80,10 @@ class Unet(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -98,6 +102,8 @@ class Unet(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -108,6 +114,8 @@ class Unet(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,