From 993945b7eda52e96867bdaee4d8fac69f32fa5d1 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Fri, 28 Jun 2024 15:26:17 +0200
Subject: [PATCH] [models] Add support for learning rate schedulers

---
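Notes: with this change, every model constructor accepts two new keyword
arguments, scheduler_type and scheduler_arguments. A minimal sketch of a
config enabling a scheduler (StepLR and its argument values here are
illustrative and not part of this patch; the rest mirrors
config/models/pasa.py):

    from mednet.libs.classification.models.pasa import Pasa
    from torch.nn import BCEWithLogitsLoss
    from torch.optim import Adam
    from torch.optim.lr_scheduler import StepLR

    model = Pasa(
        loss_type=BCEWithLogitsLoss,
        optimizer_type=Adam,
        optimizer_arguments=dict(lr=8e-5),
        scheduler_type=StepLR,
        # passed to StepLR after the optimizer: halve the lr every 10 epochs
        scheduler_arguments=dict(step_size=10, gamma=0.5),
    )

Leaving scheduler_type=None (the default) keeps the previous behaviour of a
constant learning rate.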
 .../classification/config/models/alexnet.py   |  2 ++
 .../config/models/alexnet_pretrained.py       |  2 ++
 .../classification/config/models/densenet.py  |  2 ++
 .../config/models/densenet_pretrained.py      |  2 ++
 .../config/models/densenet_rs.py              |  2 ++
 .../libs/classification/config/models/pasa.py |  2 ++
 .../libs/classification/models/alexnet.py     |  8 +++++++
 .../models/classification_model.py            |  8 +++++++
 .../libs/classification/models/densenet.py    |  8 +++++++
 src/mednet/libs/classification/models/pasa.py |  8 +++++++
 src/mednet/libs/common/engine/callbacks.py    |  4 +++-
 src/mednet/libs/common/models/model.py        | 21 ++++++++++++++++++-
 .../libs/segmentation/config/models/driu.py   |  2 ++
 .../segmentation/config/models/driu_bn.py     |  2 ++
 .../segmentation/config/models/driu_od.py     |  2 ++
 .../segmentation/config/models/driu_pix.py    |  2 ++
 .../libs/segmentation/config/models/hed.py    |  2 ++
 .../libs/segmentation/config/models/lwnet.py  |  3 +++
 .../libs/segmentation/config/models/m2unet.py |  2 ++
 .../libs/segmentation/config/models/unet.py   |  2 ++
 src/mednet/libs/segmentation/models/driu.py   |  8 +++++++
 .../libs/segmentation/models/driu_bn.py       |  8 +++++++
 .../libs/segmentation/models/driu_od.py       |  8 +++++++
 .../libs/segmentation/models/driu_pix.py      |  8 +++++++
 src/mednet/libs/segmentation/models/hed.py    |  8 +++++++
 src/mednet/libs/segmentation/models/lwnet.py  |  9 +++++++-
 src/mednet/libs/segmentation/models/m2unet.py |  8 +++++++
 .../segmentation/models/segmentation_model.py |  8 +++++++
 src/mednet/libs/segmentation/models/unet.py   |  8 +++++++
 29 files changed, 156 insertions(+), 3 deletions(-)
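
Reviewer note: Model.configure_optimizers() now uses Lightning's two-list
return convention when a scheduler is configured. A self-contained sketch of
that contract (the toy parameter and the scheduler values are illustrative;
the scheduler type mirrors the lwnet config below):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params, lr=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=10, eta_min=1e-08
    )
    # what configure_optimizers() returns when scheduler_type is set;
    # with scheduler_type=None it returns just the optimizer, as before
    optimizers, schedulers = [optimizer], [scheduler]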

diff --git a/src/mednet/libs/classification/config/models/alexnet.py b/src/mednet/libs/classification/config/models/alexnet.py
index 702f9ed0..02478ed2 100644
--- a/src/mednet/libs/classification/config/models/alexnet.py
+++ b/src/mednet/libs/classification/config/models/alexnet.py
@@ -18,6 +18,8 @@ model = Alexnet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=False,
     model_transforms=[
         SquareCenterPad(),
diff --git a/src/mednet/libs/classification/config/models/alexnet_pretrained.py b/src/mednet/libs/classification/config/models/alexnet_pretrained.py
index 8416db28..21891149 100644
--- a/src/mednet/libs/classification/config/models/alexnet_pretrained.py
+++ b/src/mednet/libs/classification/config/models/alexnet_pretrained.py
@@ -20,6 +20,8 @@ model = Alexnet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=True,
     model_transforms=[
         SquareCenterPad(),
diff --git a/src/mednet/libs/classification/config/models/densenet.py b/src/mednet/libs/classification/config/models/densenet.py
index 65f7e90a..9ef4ff95 100644
--- a/src/mednet/libs/classification/config/models/densenet.py
+++ b/src/mednet/libs/classification/config/models/densenet.py
@@ -16,6 +16,8 @@ model = Densenet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=False,
     dropout=0.1,
     model_transforms=[],
diff --git a/src/mednet/libs/classification/config/models/densenet_pretrained.py b/src/mednet/libs/classification/config/models/densenet_pretrained.py
index 8d7103ae..0f01a23f 100644
--- a/src/mednet/libs/classification/config/models/densenet_pretrained.py
+++ b/src/mednet/libs/classification/config/models/densenet_pretrained.py
@@ -20,6 +20,8 @@ model = Densenet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=True,
     dropout=0.1,
     model_transforms=[
diff --git a/src/mednet/libs/classification/config/models/densenet_rs.py b/src/mednet/libs/classification/config/models/densenet_rs.py
index 9d300875..48847e09 100644
--- a/src/mednet/libs/classification/config/models/densenet_rs.py
+++ b/src/mednet/libs/classification/config/models/densenet_rs.py
@@ -17,6 +17,8 @@ model = Densenet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=False,
     dropout=0.1,
     num_classes=14,  # number of classes in NIH CXR-14
diff --git a/src/mednet/libs/classification/config/models/pasa.py b/src/mednet/libs/classification/config/models/pasa.py
index 637df9a4..e8245f5b 100644
--- a/src/mednet/libs/classification/config/models/pasa.py
+++ b/src/mednet/libs/classification/config/models/pasa.py
@@ -20,6 +20,8 @@ model = Pasa(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=8e-5),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         Grayscale(),
         SquareCenterPad(),
diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py
index f930b625..b400a728 100644
--- a/src/mednet/libs/classification/models/alexnet.py
+++ b/src/mednet/libs/classification/models/alexnet.py
@@ -37,6 +37,10 @@ class Alexnet(ClassificationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training, if any.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -56,6 +60,8 @@ class Alexnet(ClassificationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         pretrained: bool = False,
@@ -66,6 +72,8 @@ class Alexnet(ClassificationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/classification/models/classification_model.py b/src/mednet/libs/classification/models/classification_model.py
index e9654b48..b125dec9 100644
--- a/src/mednet/libs/classification/models/classification_model.py
+++ b/src/mednet/libs/classification/models/classification_model.py
@@ -33,6 +33,10 @@ class ClassificationModel(Model):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training, if any.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -49,6 +53,8 @@ class ClassificationModel(Model):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -58,6 +64,8 @@ class ClassificationModel(Model):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py
index c40cc957..e58446d1 100644
--- a/src/mednet/libs/classification/models/densenet.py
+++ b/src/mednet/libs/classification/models/densenet.py
@@ -35,6 +35,10 @@ class Densenet(ClassificationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training, if any.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -56,6 +60,8 @@ class Densenet(ClassificationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         pretrained: bool = False,
@@ -67,6 +73,8 @@ class Densenet(ClassificationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py
index b17af172..8eea2c38 100644
--- a/src/mednet/libs/classification/models/pasa.py
+++ b/src/mednet/libs/classification/models/pasa.py
@@ -40,6 +40,10 @@ class Pasa(ClassificationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training, if any.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -56,6 +60,8 @@ class Pasa(ClassificationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -65,6 +71,8 @@ class Pasa(ClassificationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/common/engine/callbacks.py b/src/mednet/libs/common/engine/callbacks.py
index b030c882..4637b3af 100644
--- a/src/mednet/libs/common/engine/callbacks.py
+++ b/src/mednet/libs/common/engine/callbacks.py
@@ -141,7 +141,9 @@ class LoggingCallback(lightning.pytorch.Callback):
         epoch_time = time.time() - self._start_training_epoch_time
 
         self._to_log["epoch-duration-seconds/train"] = epoch_time
-        self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"]  # type: ignore
+        # read the live lr from the optimizer's param groups, which schedulers
+        # update in-place (lr_scheduler_configs is empty when scheduler_type=None)
+        self._to_log["learning-rate"] = pl_module.optimizers().param_groups[0]["lr"]  # type: ignore
 
         overall_cycle_time = time.time() - self._start_training_epoch_time
         self._to_log["cycle-time-seconds/train"] = overall_cycle_time
diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py
index cc9d5d84..b59301e2 100644
--- a/src/mednet/libs/common/models/model.py
+++ b/src/mednet/libs/common/models/model.py
@@ -8,6 +8,7 @@ import typing
 import lightning.pytorch as pl
 import torch
 import torch.nn
+import torch.optim.lr_scheduler
 import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
@@ -37,6 +38,10 @@ class Model(pl.LightningModule):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training, if any.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -53,6 +58,8 @@ class Model(pl.LightningModule):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -75,6 +82,9 @@ class Model(pl.LightningModule):
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments
 
+        self._scheduler_type = scheduler_type
+        self._scheduler_arguments = scheduler_arguments
+
         self.augmentation_transforms = augmentation_transforms
 
     @property
@@ -146,11 +156,20 @@ class Model(pl.LightningModule):
         self._validation_loss = self._loss_type(**self._validation_loss_arguments)
 
     def configure_optimizers(self):
-        return self._optimizer_type(
+        optimizer = self._optimizer_type(
             self.parameters(),
             **self._optimizer_arguments,
         )
 
+        if self._scheduler_type is None:
+            return optimizer
+
+        scheduler = self._scheduler_type(
+            optimizer,
+            **self._scheduler_arguments,
+        )
+        return [optimizer], [scheduler]
+
     def to(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Self:
         """Move model, augmentations and losses to specified device.
 
diff --git a/src/mednet/libs/segmentation/config/models/driu.py b/src/mednet/libs/segmentation/config/models/driu.py
index 4f1be608..43bb6414 100644
--- a/src/mednet/libs/segmentation/config/models/driu.py
+++ b/src/mednet/libs/segmentation/config/models/driu.py
@@ -41,6 +41,8 @@ model = DRIU(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/config/models/driu_bn.py b/src/mednet/libs/segmentation/config/models/driu_bn.py
index fc435cb9..7c83b1eb 100644
--- a/src/mednet/libs/segmentation/config/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/config/models/driu_bn.py
@@ -41,6 +41,8 @@ model = DRIUBN(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/config/models/driu_od.py b/src/mednet/libs/segmentation/config/models/driu_od.py
index 09318826..6a80dfa8 100644
--- a/src/mednet/libs/segmentation/config/models/driu_od.py
+++ b/src/mednet/libs/segmentation/config/models/driu_od.py
@@ -41,6 +41,8 @@ model = DRIUOD(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/config/models/driu_pix.py b/src/mednet/libs/segmentation/config/models/driu_pix.py
index 72b2feea..0562ceb1 100644
--- a/src/mednet/libs/segmentation/config/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/config/models/driu_pix.py
@@ -41,6 +41,8 @@ model = DRIUPix(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/config/models/hed.py b/src/mednet/libs/segmentation/config/models/hed.py
index b3ce43fe..528e817e 100644
--- a/src/mednet/libs/segmentation/config/models/hed.py
+++ b/src/mednet/libs/segmentation/config/models/hed.py
@@ -32,6 +32,8 @@ model = HED(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/config/models/lwnet.py b/src/mednet/libs/segmentation/config/models/lwnet.py
index 339fa22a..ecfef890 100644
--- a/src/mednet/libs/segmentation/config/models/lwnet.py
+++ b/src/mednet/libs/segmentation/config/models/lwnet.py
@@ -13,6 +13,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
 from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
 from mednet.libs.segmentation.models.lwnet import LittleWNet
 from torch.optim import Adam
+from torch.optim.lr_scheduler import CosineAnnealingLR
 
 max_lr = 0.01  # start
 min_lr = 1e-08  # valley
@@ -24,6 +25,8 @@ model = LittleWNet(
     loss_type=MultiWeightedBCELogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=max_lr),
+    scheduler_type=CosineAnnealingLR,
+    scheduler_arguments=dict(T_max=cycle, eta_min=min_lr),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/config/models/m2unet.py b/src/mednet/libs/segmentation/config/models/m2unet.py
index 39389461..ccfae59b 100644
--- a/src/mednet/libs/segmentation/config/models/m2unet.py
+++ b/src/mednet/libs/segmentation/config/models/m2unet.py
@@ -45,6 +45,8 @@ model = M2UNET(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/config/models/unet.py b/src/mednet/libs/segmentation/config/models/unet.py
index 0dc6c417..31d79ec2 100644
--- a/src/mednet/libs/segmentation/config/models/unet.py
+++ b/src/mednet/libs/segmentation/config/models/unet.py
@@ -43,6 +43,8 @@ model = Unet(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index e12aac71..295713cb 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -88,6 +88,10 @@ class DRIU(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training, if any.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -106,6 +110,8 @@ class DRIU(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -116,6 +122,8 @@ class DRIU(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py
index 555d4ef1..bacbf291 100644
--- a/src/mednet/libs/segmentation/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/models/driu_bn.py
@@ -91,6 +91,10 @@ class DRIUBN(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training, if any.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -109,6 +113,8 @@ class DRIUBN(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -119,6 +125,8 @@ class DRIUBN(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py
index 3b0cff03..af4727c5 100644
--- a/src/mednet/libs/segmentation/models/driu_od.py
+++ b/src/mednet/libs/segmentation/models/driu_od.py
@@ -73,6 +73,10 @@ class DRIUOD(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training, if any.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -91,6 +95,8 @@ class DRIUOD(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -101,6 +107,8 @@ class DRIUOD(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py
index 75adce0c..e9800984 100644
--- a/src/mednet/libs/segmentation/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/models/driu_pix.py
@@ -77,6 +77,10 @@ class DRIUPix(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training, if any.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -95,6 +99,8 @@ class DRIUPix(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -105,6 +111,8 @@ class DRIUPix(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
index 97a663c5..12a7c591 100644
--- a/src/mednet/libs/segmentation/models/hed.py
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -91,6 +91,10 @@ class HED(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training, if any.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -109,6 +113,8 @@ class HED(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -119,6 +125,8 @@ class HED(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index adb17512..3a552c49 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -294,6 +294,10 @@ class LittleWNet(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training, if any.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -310,6 +314,8 @@ class LittleWNet(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -319,6 +325,8 @@ class LittleWNet(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
@@ -346,5 +354,4 @@ class LittleWNet(SegmentationModel):
         xn = self.normalizer(x)
         x1 = self.unet1(xn)
         x2 = self.unet2(torch.cat([xn, x1], dim=1))
-
         return x1, x2
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
index cf7ce5c4..fcca7d60 100644
--- a/src/mednet/libs/segmentation/models/m2unet.py
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -139,6 +139,10 @@ class M2UNET(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training, if any.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -157,6 +161,8 @@ class M2UNET(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -167,6 +173,8 @@ class M2UNET(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py
index 8d8aa3ca..834af45f 100644
--- a/src/mednet/libs/segmentation/models/segmentation_model.py
+++ b/src/mednet/libs/segmentation/models/segmentation_model.py
@@ -34,6 +34,10 @@ class SegmentationModel(Model):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training, if any.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -50,6 +54,8 @@ class SegmentationModel(Model):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -59,6 +65,8 @@ class SegmentationModel(Model):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py
index 6d9a76da..cb1e34cf 100644
--- a/src/mednet/libs/segmentation/models/unet.py
+++ b/src/mednet/libs/segmentation/models/unet.py
@@ -80,6 +80,10 @@ class Unet(SegmentationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training, if any.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -98,6 +102,8 @@ class Unet(SegmentationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -108,6 +114,8 @@ class Unet(SegmentationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
-- 
GitLab