From 993945b7eda52e96867bdaee4d8fac69f32fa5d1 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Fri, 28 Jun 2024 15:26:17 +0200 Subject: [PATCH] [models] Add support for learning rate schedulers --- .../classification/config/models/alexnet.py | 2 ++ .../config/models/alexnet_pretrained.py | 2 ++ .../classification/config/models/densenet.py | 2 ++ .../config/models/densenet_pretrained.py | 2 ++ .../config/models/densenet_rs.py | 2 ++ .../libs/classification/config/models/pasa.py | 2 ++ .../libs/classification/models/alexnet.py | 8 +++++++ .../models/classification_model.py | 8 +++++++ .../libs/classification/models/densenet.py | 8 +++++++ src/mednet/libs/classification/models/pasa.py | 8 +++++++ src/mednet/libs/common/engine/callbacks.py | 4 +++- src/mednet/libs/common/models/model.py | 21 ++++++++++++++++++- .../libs/segmentation/config/models/driu.py | 2 ++ .../segmentation/config/models/driu_bn.py | 2 ++ .../segmentation/config/models/driu_od.py | 2 ++ .../segmentation/config/models/driu_pix.py | 2 ++ .../libs/segmentation/config/models/hed.py | 2 ++ .../libs/segmentation/config/models/lwnet.py | 3 +++ .../libs/segmentation/config/models/m2unet.py | 2 ++ .../libs/segmentation/config/models/unet.py | 2 ++ src/mednet/libs/segmentation/models/driu.py | 8 +++++++ .../libs/segmentation/models/driu_bn.py | 8 +++++++ .../libs/segmentation/models/driu_od.py | 8 +++++++ .../libs/segmentation/models/driu_pix.py | 8 +++++++ src/mednet/libs/segmentation/models/hed.py | 8 +++++++ src/mednet/libs/segmentation/models/lwnet.py | 9 +++++++- src/mednet/libs/segmentation/models/m2unet.py | 8 +++++++ .../segmentation/models/segmentation_model.py | 8 +++++++ src/mednet/libs/segmentation/models/unet.py | 8 +++++++ 29 files changed, 156 insertions(+), 3 deletions(-) diff --git a/src/mednet/libs/classification/config/models/alexnet.py b/src/mednet/libs/classification/config/models/alexnet.py index 702f9ed0..02478ed2 100644 --- a/src/mednet/libs/classification/config/models/alexnet.py +++ b/src/mednet/libs/classification/config/models/alexnet.py @@ -18,6 +18,8 @@ model = Alexnet( loss_type=BCEWithLogitsLoss, optimizer_type=SGD, optimizer_arguments=dict(lr=0.01, momentum=0.1), + scheduler_type=None, + scheduler_arguments=dict(), pretrained=False, model_transforms=[ SquareCenterPad(), diff --git a/src/mednet/libs/classification/config/models/alexnet_pretrained.py b/src/mednet/libs/classification/config/models/alexnet_pretrained.py index 8416db28..21891149 100644 --- a/src/mednet/libs/classification/config/models/alexnet_pretrained.py +++ b/src/mednet/libs/classification/config/models/alexnet_pretrained.py @@ -20,6 +20,8 @@ model = Alexnet( loss_type=BCEWithLogitsLoss, optimizer_type=SGD, optimizer_arguments=dict(lr=0.01, momentum=0.1), + scheduler_type=None, + scheduler_arguments=dict(), pretrained=True, model_transforms=[ SquareCenterPad(), diff --git a/src/mednet/libs/classification/config/models/densenet.py b/src/mednet/libs/classification/config/models/densenet.py index 65f7e90a..9ef4ff95 100644 --- a/src/mednet/libs/classification/config/models/densenet.py +++ b/src/mednet/libs/classification/config/models/densenet.py @@ -16,6 +16,8 @@ model = Densenet( loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=0.0001), + scheduler_type=None, + scheduler_arguments=dict(), pretrained=False, dropout=0.1, model_transforms=[], diff --git a/src/mednet/libs/classification/config/models/densenet_pretrained.py b/src/mednet/libs/classification/config/models/densenet_pretrained.py index 8d7103ae..0f01a23f 100644 --- a/src/mednet/libs/classification/config/models/densenet_pretrained.py +++ b/src/mednet/libs/classification/config/models/densenet_pretrained.py @@ -20,6 +20,8 @@ model = Densenet( loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=0.0001), + scheduler_type=None, + scheduler_arguments=dict(), pretrained=True, dropout=0.1, model_transforms=[ diff --git a/src/mednet/libs/classification/config/models/densenet_rs.py b/src/mednet/libs/classification/config/models/densenet_rs.py index 9d300875..48847e09 100644 --- a/src/mednet/libs/classification/config/models/densenet_rs.py +++ b/src/mednet/libs/classification/config/models/densenet_rs.py @@ -17,6 +17,8 @@ model = Densenet( loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=0.0001), + scheduler_type=None, + scheduler_arguments=dict(), pretrained=False, dropout=0.1, num_classes=14, # number of classes in NIH CXR-14 diff --git a/src/mednet/libs/classification/config/models/pasa.py b/src/mednet/libs/classification/config/models/pasa.py index 637df9a4..e8245f5b 100644 --- a/src/mednet/libs/classification/config/models/pasa.py +++ b/src/mednet/libs/classification/config/models/pasa.py @@ -20,6 +20,8 @@ model = Pasa( loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=8e-5), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ Grayscale(), SquareCenterPad(), diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py index f930b625..b400a728 100644 --- a/src/mednet/libs/classification/models/alexnet.py +++ b/src/mednet/libs/classification/models/alexnet.py @@ -37,6 +37,10 @@ class Alexnet(ClassificationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -56,6 +60,8 @@ class Alexnet(ClassificationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], pretrained: bool = False, @@ -66,6 +72,8 @@ class Alexnet(ClassificationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/classification/models/classification_model.py b/src/mednet/libs/classification/models/classification_model.py index e9654b48..b125dec9 100644 --- a/src/mednet/libs/classification/models/classification_model.py +++ b/src/mednet/libs/classification/models/classification_model.py @@ -33,6 +33,10 @@ class ClassificationModel(Model): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -49,6 +53,8 @@ class ClassificationModel(Model): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -58,6 +64,8 @@ class ClassificationModel(Model): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py index c40cc957..e58446d1 100644 --- a/src/mednet/libs/classification/models/densenet.py +++ b/src/mednet/libs/classification/models/densenet.py @@ -35,6 +35,10 @@ class Densenet(ClassificationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -56,6 +60,8 @@ class Densenet(ClassificationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], pretrained: bool = False, @@ -67,6 +73,8 @@ class Densenet(ClassificationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py index b17af172..8eea2c38 100644 --- a/src/mednet/libs/classification/models/pasa.py +++ b/src/mednet/libs/classification/models/pasa.py @@ -40,6 +40,10 @@ class Pasa(ClassificationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -56,6 +60,8 @@ class Pasa(ClassificationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -65,6 +71,8 @@ class Pasa(ClassificationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/common/engine/callbacks.py b/src/mednet/libs/common/engine/callbacks.py index b030c882..4637b3af 100644 --- a/src/mednet/libs/common/engine/callbacks.py +++ b/src/mednet/libs/common/engine/callbacks.py @@ -141,7 +141,9 @@ class LoggingCallback(lightning.pytorch.Callback): epoch_time = time.time() - self._start_training_epoch_time self._to_log["epoch-duration-seconds/train"] = epoch_time - self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"] # type: ignore + self._to_log["learning-rate"] = pl_module.trainer.lr_scheduler_configs[ + 0 + ].scheduler.optimizer.param_groups[0]["lr"] # type: ignore overall_cycle_time = time.time() - self._start_training_epoch_time self._to_log["cycle-time-seconds/train"] = overall_cycle_time diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py index cc9d5d84..b59301e2 100644 --- a/src/mednet/libs/common/models/model.py +++ b/src/mednet/libs/common/models/model.py @@ -8,6 +8,7 @@ import typing import lightning.pytorch as pl import torch import torch.nn +import torch.optim.lr_scheduler import torch.optim.optimizer import torch.utils.data import torchvision.transforms @@ -37,6 +38,10 @@ class Model(pl.LightningModule): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -53,6 +58,8 @@ class Model(pl.LightningModule): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -75,6 +82,9 @@ class Model(pl.LightningModule): self._optimizer_type = optimizer_type self._optimizer_arguments = optimizer_arguments + self._scheduler_type = scheduler_type + self._scheduler_arguments = scheduler_arguments + self.augmentation_transforms = augmentation_transforms @property @@ -146,11 +156,20 @@ class Model(pl.LightningModule): self._validation_loss = self._loss_type(**self._validation_loss_arguments) def configure_optimizers(self): - return self._optimizer_type( + optimizer = self._optimizer_type( self.parameters(), **self._optimizer_arguments, ) + if self._scheduler_type is None: + return optimizer + + scheduler = self._scheduler_type( + optimizer, + **self._scheduler_arguments, + ) + return [optimizer], [scheduler] + def to(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Self: """Move model, augmentations and losses to specified device. diff --git a/src/mednet/libs/segmentation/config/models/driu.py b/src/mednet/libs/segmentation/config/models/driu.py index 4f1be608..43bb6414 100644 --- a/src/mednet/libs/segmentation/config/models/driu.py +++ b/src/mednet/libs/segmentation/config/models/driu.py @@ -41,6 +41,8 @@ model = DRIU( weight_decay=weight_decay, amsbound=amsbound, ), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/config/models/driu_bn.py b/src/mednet/libs/segmentation/config/models/driu_bn.py index fc435cb9..7c83b1eb 100644 --- a/src/mednet/libs/segmentation/config/models/driu_bn.py +++ b/src/mednet/libs/segmentation/config/models/driu_bn.py @@ -41,6 +41,8 @@ model = DRIUBN( weight_decay=weight_decay, amsbound=amsbound, ), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/config/models/driu_od.py b/src/mednet/libs/segmentation/config/models/driu_od.py index 09318826..6a80dfa8 100644 --- a/src/mednet/libs/segmentation/config/models/driu_od.py +++ b/src/mednet/libs/segmentation/config/models/driu_od.py @@ -41,6 +41,8 @@ model = DRIUOD( weight_decay=weight_decay, amsbound=amsbound, ), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/config/models/driu_pix.py b/src/mednet/libs/segmentation/config/models/driu_pix.py index 72b2feea..0562ceb1 100644 --- a/src/mednet/libs/segmentation/config/models/driu_pix.py +++ b/src/mednet/libs/segmentation/config/models/driu_pix.py @@ -41,6 +41,8 @@ model = DRIUPix( weight_decay=weight_decay, amsbound=amsbound, ), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/config/models/hed.py b/src/mednet/libs/segmentation/config/models/hed.py index b3ce43fe..528e817e 100644 --- a/src/mednet/libs/segmentation/config/models/hed.py +++ b/src/mednet/libs/segmentation/config/models/hed.py @@ -32,6 +32,8 @@ model = HED( weight_decay=weight_decay, amsbound=amsbound, ), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/config/models/lwnet.py b/src/mednet/libs/segmentation/config/models/lwnet.py index 339fa22a..ecfef890 100644 --- a/src/mednet/libs/segmentation/config/models/lwnet.py +++ b/src/mednet/libs/segmentation/config/models/lwnet.py @@ -13,6 +13,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss from mednet.libs.segmentation.models.lwnet import LittleWNet from torch.optim import Adam +from torch.optim.lr_scheduler import CosineAnnealingLR max_lr = 0.01 # start min_lr = 1e-08 # valley @@ -24,6 +25,8 @@ model = LittleWNet( loss_type=MultiWeightedBCELogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=max_lr), + scheduler_type=CosineAnnealingLR, + scheduler_arguments=dict(T_max=cycle, eta_min=min_lr), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/config/models/m2unet.py b/src/mednet/libs/segmentation/config/models/m2unet.py index 39389461..ccfae59b 100644 --- a/src/mednet/libs/segmentation/config/models/m2unet.py +++ b/src/mednet/libs/segmentation/config/models/m2unet.py @@ -45,6 +45,8 @@ model = M2UNET( weight_decay=weight_decay, amsbound=amsbound, ), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/config/models/unet.py b/src/mednet/libs/segmentation/config/models/unet.py index 0dc6c417..31d79ec2 100644 --- a/src/mednet/libs/segmentation/config/models/unet.py +++ b/src/mednet/libs/segmentation/config/models/unet.py @@ -43,6 +43,8 @@ model = Unet( weight_decay=weight_decay, amsbound=amsbound, ), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py index e12aac71..295713cb 100644 --- a/src/mednet/libs/segmentation/models/driu.py +++ b/src/mednet/libs/segmentation/models/driu.py @@ -88,6 +88,10 @@ class DRIU(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -106,6 +110,8 @@ class DRIU(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -116,6 +122,8 @@ class DRIU(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py index 555d4ef1..bacbf291 100644 --- a/src/mednet/libs/segmentation/models/driu_bn.py +++ b/src/mednet/libs/segmentation/models/driu_bn.py @@ -91,6 +91,10 @@ class DRIUBN(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -109,6 +113,8 @@ class DRIUBN(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -119,6 +125,8 @@ class DRIUBN(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py index 3b0cff03..af4727c5 100644 --- a/src/mednet/libs/segmentation/models/driu_od.py +++ b/src/mednet/libs/segmentation/models/driu_od.py @@ -73,6 +73,10 @@ class DRIUOD(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -91,6 +95,8 @@ class DRIUOD(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -101,6 +107,8 @@ class DRIUOD(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py index 75adce0c..e9800984 100644 --- a/src/mednet/libs/segmentation/models/driu_pix.py +++ b/src/mednet/libs/segmentation/models/driu_pix.py @@ -77,6 +77,10 @@ class DRIUPix(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -95,6 +99,8 @@ class DRIUPix(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -105,6 +111,8 @@ class DRIUPix(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py index 97a663c5..12a7c591 100644 --- a/src/mednet/libs/segmentation/models/hed.py +++ b/src/mednet/libs/segmentation/models/hed.py @@ -91,6 +91,10 @@ class HED(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -109,6 +113,8 @@ class HED(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -119,6 +125,8 @@ class HED(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py index adb17512..3a552c49 100644 --- a/src/mednet/libs/segmentation/models/lwnet.py +++ b/src/mednet/libs/segmentation/models/lwnet.py @@ -294,6 +294,10 @@ class LittleWNet(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -310,6 +314,8 @@ class LittleWNet(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -319,6 +325,8 @@ class LittleWNet(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, @@ -346,5 +354,4 @@ class LittleWNet(SegmentationModel): xn = self.normalizer(x) x1 = self.unet1(xn) x2 = self.unet2(torch.cat([xn, x1], dim=1)) - return x1, x2 diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py index cf7ce5c4..fcca7d60 100644 --- a/src/mednet/libs/segmentation/models/m2unet.py +++ b/src/mednet/libs/segmentation/models/m2unet.py @@ -139,6 +139,10 @@ class M2UNET(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -157,6 +161,8 @@ class M2UNET(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -167,6 +173,8 @@ class M2UNET(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py index 8d8aa3ca..834af45f 100644 --- a/src/mednet/libs/segmentation/models/segmentation_model.py +++ b/src/mednet/libs/segmentation/models/segmentation_model.py @@ -34,6 +34,10 @@ class SegmentationModel(Model): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -50,6 +54,8 @@ class SegmentationModel(Model): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -59,6 +65,8 @@ class SegmentationModel(Model): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py index 6d9a76da..cb1e34cf 100644 --- a/src/mednet/libs/segmentation/models/unet.py +++ b/src/mednet/libs/segmentation/models/unet.py @@ -80,6 +80,10 @@ class Unet(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -98,6 +102,8 @@ class Unet(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -108,6 +114,8 @@ class Unet(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, -- GitLab