diff --git a/src/mednet/libs/classification/config/models/alexnet.py b/src/mednet/libs/classification/config/models/alexnet.py index 702f9ed0de79fbdd18b54ca8ea507fbda6cadf63..02478ed228d324ae976436dd8ab62e4f81559929 100644 --- a/src/mednet/libs/classification/config/models/alexnet.py +++ b/src/mednet/libs/classification/config/models/alexnet.py @@ -18,6 +18,8 @@ model = Alexnet( loss_type=BCEWithLogitsLoss, optimizer_type=SGD, optimizer_arguments=dict(lr=0.01, momentum=0.1), + scheduler_type=None, + scheduler_arguments=dict(), pretrained=False, model_transforms=[ SquareCenterPad(), diff --git a/src/mednet/libs/classification/config/models/alexnet_pretrained.py b/src/mednet/libs/classification/config/models/alexnet_pretrained.py index 8416db28deba66f6df16668297670e1232a6c540..21891149277fc688fde2b0c0fac60de4c5e3a096 100644 --- a/src/mednet/libs/classification/config/models/alexnet_pretrained.py +++ b/src/mednet/libs/classification/config/models/alexnet_pretrained.py @@ -20,6 +20,8 @@ model = Alexnet( loss_type=BCEWithLogitsLoss, optimizer_type=SGD, optimizer_arguments=dict(lr=0.01, momentum=0.1), + scheduler_type=None, + scheduler_arguments=dict(), pretrained=True, model_transforms=[ SquareCenterPad(), diff --git a/src/mednet/libs/classification/config/models/densenet.py b/src/mednet/libs/classification/config/models/densenet.py index 65f7e90a0f06e41cf3efe4b6e0de035fed64fc1a..9ef4ff95437d53e074e72b7107ffebf737f50f68 100644 --- a/src/mednet/libs/classification/config/models/densenet.py +++ b/src/mednet/libs/classification/config/models/densenet.py @@ -16,6 +16,8 @@ model = Densenet( loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=0.0001), + scheduler_type=None, + scheduler_arguments=dict(), pretrained=False, dropout=0.1, model_transforms=[], diff --git a/src/mednet/libs/classification/config/models/densenet_pretrained.py b/src/mednet/libs/classification/config/models/densenet_pretrained.py index 8d7103aeeb92275e828a341faf3346ee68e69bf4..0f01a23f023d6b2e1ba04194f34bfc417fb91100 100644 --- a/src/mednet/libs/classification/config/models/densenet_pretrained.py +++ b/src/mednet/libs/classification/config/models/densenet_pretrained.py @@ -20,6 +20,8 @@ model = Densenet( loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=0.0001), + scheduler_type=None, + scheduler_arguments=dict(), pretrained=True, dropout=0.1, model_transforms=[ diff --git a/src/mednet/libs/classification/config/models/densenet_rs.py b/src/mednet/libs/classification/config/models/densenet_rs.py index 9d3008755da6f9ce02be29cdf91adfdbb28cdfc2..48847e0971c70e28a5478136cadcae3ed210cc0e 100644 --- a/src/mednet/libs/classification/config/models/densenet_rs.py +++ b/src/mednet/libs/classification/config/models/densenet_rs.py @@ -17,6 +17,8 @@ model = Densenet( loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=0.0001), + scheduler_type=None, + scheduler_arguments=dict(), pretrained=False, dropout=0.1, num_classes=14, # number of classes in NIH CXR-14 diff --git a/src/mednet/libs/classification/config/models/pasa.py b/src/mednet/libs/classification/config/models/pasa.py index 637df9a48a954d2c913d369cd495d4b349311fe2..e8245f5b42963ef63a959b053ebdb385d378906b 100644 --- a/src/mednet/libs/classification/config/models/pasa.py +++ b/src/mednet/libs/classification/config/models/pasa.py @@ -20,6 +20,8 @@ model = Pasa( loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=8e-5), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ Grayscale(), SquareCenterPad(), diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py index f930b625e96361182b8b9240e2fd2c82194328f5..b400a728d2219f19c8148bff1f39d1de88117833 100644 --- a/src/mednet/libs/classification/models/alexnet.py +++ b/src/mednet/libs/classification/models/alexnet.py @@ -37,6 +37,10 @@ class Alexnet(ClassificationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -56,6 +60,8 @@ class Alexnet(ClassificationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], pretrained: bool = False, @@ -66,6 +72,8 @@ class Alexnet(ClassificationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/classification/models/classification_model.py b/src/mednet/libs/classification/models/classification_model.py index e9654b480b67c7dc8b2ec249ed1921822383d33b..b125dec9d66f9881b7d43e4d0d6036197efc950a 100644 --- a/src/mednet/libs/classification/models/classification_model.py +++ b/src/mednet/libs/classification/models/classification_model.py @@ -33,6 +33,10 @@ class ClassificationModel(Model): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -49,6 +53,8 @@ class ClassificationModel(Model): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -58,6 +64,8 @@ class ClassificationModel(Model): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py index c40cc957c5ca8f284b6c3f4832bb1b7d54052bda..e58446d1e04794fa4ad329050176385f21a3e798 100644 --- a/src/mednet/libs/classification/models/densenet.py +++ b/src/mednet/libs/classification/models/densenet.py @@ -35,6 +35,10 @@ class Densenet(ClassificationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -56,6 +60,8 @@ class Densenet(ClassificationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], pretrained: bool = False, @@ -67,6 +73,8 @@ class Densenet(ClassificationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py index b17af1724e97c18e6b58b1c25f9db3b3e8c167cb..8eea2c38e3bff899ac9686b6aba695df04c687b9 100644 --- a/src/mednet/libs/classification/models/pasa.py +++ b/src/mednet/libs/classification/models/pasa.py @@ -40,6 +40,10 @@ class Pasa(ClassificationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -56,6 +60,8 @@ class Pasa(ClassificationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -65,6 +71,8 @@ class Pasa(ClassificationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/common/engine/callbacks.py b/src/mednet/libs/common/engine/callbacks.py index b030c8820f6a7ec748733ab9ba692f35f48f7d8f..4637b3af89ed53cf9810d019784b669bc0bc8d82 100644 --- a/src/mednet/libs/common/engine/callbacks.py +++ b/src/mednet/libs/common/engine/callbacks.py @@ -141,7 +141,9 @@ class LoggingCallback(lightning.pytorch.Callback): epoch_time = time.time() - self._start_training_epoch_time self._to_log["epoch-duration-seconds/train"] = epoch_time - self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"] # type: ignore + self._to_log["learning-rate"] = pl_module.trainer.lr_scheduler_configs[ + 0 + ].scheduler.optimizer.param_groups[0]["lr"] # type: ignore overall_cycle_time = time.time() - self._start_training_epoch_time self._to_log["cycle-time-seconds/train"] = overall_cycle_time diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py index cc9d5d84dee0a571d7809e4dfa1832f429c68de3..b59301e2b061785ae1c3d728588a767db213f06a 100644 --- a/src/mednet/libs/common/models/model.py +++ b/src/mednet/libs/common/models/model.py @@ -8,6 +8,7 @@ import typing import lightning.pytorch as pl import torch import torch.nn +import torch.optim.lr_scheduler import torch.optim.optimizer import torch.utils.data import torchvision.transforms @@ -37,6 +38,10 @@ class Model(pl.LightningModule): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -53,6 +58,8 @@ class Model(pl.LightningModule): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -75,6 +82,9 @@ class Model(pl.LightningModule): self._optimizer_type = optimizer_type self._optimizer_arguments = optimizer_arguments + self._scheduler_type = scheduler_type + self._scheduler_arguments = scheduler_arguments + self.augmentation_transforms = augmentation_transforms @property @@ -146,11 +156,20 @@ class Model(pl.LightningModule): self._validation_loss = self._loss_type(**self._validation_loss_arguments) def configure_optimizers(self): - return self._optimizer_type( + optimizer = self._optimizer_type( self.parameters(), **self._optimizer_arguments, ) + if self._scheduler_type is None: + return optimizer + + scheduler = self._scheduler_type( + optimizer, + **self._scheduler_arguments, + ) + return [optimizer], [scheduler] + def to(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Self: """Move model, augmentations and losses to specified device. diff --git a/src/mednet/libs/segmentation/config/models/driu.py b/src/mednet/libs/segmentation/config/models/driu.py index 4f1be608a32e0463ef357f5f2c62dc6c1d8938e1..43bb6414750aece5bd7f5be6eb1a36602401af6a 100644 --- a/src/mednet/libs/segmentation/config/models/driu.py +++ b/src/mednet/libs/segmentation/config/models/driu.py @@ -41,6 +41,8 @@ model = DRIU( weight_decay=weight_decay, amsbound=amsbound, ), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/config/models/driu_bn.py b/src/mednet/libs/segmentation/config/models/driu_bn.py index fc435cb94b3a4c9fd15a3636722bbd70f4f23e17..7c83b1eb5cb138f75fe0fb0e23660181fd3bd8cd 100644 --- a/src/mednet/libs/segmentation/config/models/driu_bn.py +++ b/src/mednet/libs/segmentation/config/models/driu_bn.py @@ -41,6 +41,8 @@ model = DRIUBN( weight_decay=weight_decay, amsbound=amsbound, ), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/config/models/driu_od.py b/src/mednet/libs/segmentation/config/models/driu_od.py index 0931882671709adcd2d04a2c5cbd80a01656c369..6a80dfa8187b3690d6ae35168a39490f16212472 100644 --- a/src/mednet/libs/segmentation/config/models/driu_od.py +++ b/src/mednet/libs/segmentation/config/models/driu_od.py @@ -41,6 +41,8 @@ model = DRIUOD( weight_decay=weight_decay, amsbound=amsbound, ), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/config/models/driu_pix.py b/src/mednet/libs/segmentation/config/models/driu_pix.py index 72b2feeac53658bca9847eee149746051bf5d4c9..0562ceb17285e1a847c675ff2f6a7ac7a751eb00 100644 --- a/src/mednet/libs/segmentation/config/models/driu_pix.py +++ b/src/mednet/libs/segmentation/config/models/driu_pix.py @@ -41,6 +41,8 @@ model = DRIUPix( weight_decay=weight_decay, amsbound=amsbound, ), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/config/models/hed.py b/src/mednet/libs/segmentation/config/models/hed.py index b3ce43fe6dc299fe37a157e0c1d899d60d3c9882..528e817e97ccd6028676e63f0326b8328803db92 100644 --- a/src/mednet/libs/segmentation/config/models/hed.py +++ b/src/mednet/libs/segmentation/config/models/hed.py @@ -32,6 +32,8 @@ model = HED( weight_decay=weight_decay, amsbound=amsbound, ), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/config/models/lwnet.py b/src/mednet/libs/segmentation/config/models/lwnet.py index 339fa22a1f2bde9b6439730af0b6f1c7a2cd192c..ecfef8905da42c968a8364a5ca953dadc3ed4180 100644 --- a/src/mednet/libs/segmentation/config/models/lwnet.py +++ b/src/mednet/libs/segmentation/config/models/lwnet.py @@ -13,6 +13,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss from mednet.libs.segmentation.models.lwnet import LittleWNet from torch.optim import Adam +from torch.optim.lr_scheduler import CosineAnnealingLR max_lr = 0.01 # start min_lr = 1e-08 # valley @@ -24,6 +25,8 @@ model = LittleWNet( loss_type=MultiWeightedBCELogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=max_lr), + scheduler_type=CosineAnnealingLR, + scheduler_arguments=dict(T_max=cycle, eta_min=min_lr), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/config/models/m2unet.py b/src/mednet/libs/segmentation/config/models/m2unet.py index 39389461d807da45af7eefd9cdbf3a081eb38db1..ccfae59b323b2e9c0042248c24c4029b6eb597d2 100644 --- a/src/mednet/libs/segmentation/config/models/m2unet.py +++ b/src/mednet/libs/segmentation/config/models/m2unet.py @@ -45,6 +45,8 @@ model = M2UNET( weight_decay=weight_decay, amsbound=amsbound, ), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/config/models/unet.py b/src/mednet/libs/segmentation/config/models/unet.py index 0dc6c417e4af6e6906d74a41abda640d00b4fa74..31d79ec28f221ab48c5a55361d5f8eac29f66a64 100644 --- a/src/mednet/libs/segmentation/config/models/unet.py +++ b/src/mednet/libs/segmentation/config/models/unet.py @@ -43,6 +43,8 @@ model = Unet( weight_decay=weight_decay, amsbound=amsbound, ), + scheduler_type=None, + scheduler_arguments=dict(), model_transforms=[ resize_transform, SquareCenterPad(), diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py index e12aac71517ff42a6caf6b2ce052da2be8f32db0..295713cb5ec4c22ae4ab391a88d21af41cf6d592 100644 --- a/src/mednet/libs/segmentation/models/driu.py +++ b/src/mednet/libs/segmentation/models/driu.py @@ -88,6 +88,10 @@ class DRIU(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -106,6 +110,8 @@ class DRIU(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -116,6 +122,8 @@ class DRIU(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py index 555d4ef1ffc03e6d88ba81c60326df0cfc745a74..bacbf2913f02b54f5b00b31e35600416e6a59e9b 100644 --- a/src/mednet/libs/segmentation/models/driu_bn.py +++ b/src/mednet/libs/segmentation/models/driu_bn.py @@ -91,6 +91,10 @@ class DRIUBN(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -109,6 +113,8 @@ class DRIUBN(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -119,6 +125,8 @@ class DRIUBN(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py index 3b0cff036c843835a54e566206c045569e4bcd9c..af4727c556b112020739a37337f724afe3ed1b24 100644 --- a/src/mednet/libs/segmentation/models/driu_od.py +++ b/src/mednet/libs/segmentation/models/driu_od.py @@ -73,6 +73,10 @@ class DRIUOD(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -91,6 +95,8 @@ class DRIUOD(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -101,6 +107,8 @@ class DRIUOD(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py index 75adce0c91da0fd53ff9386b3ffa08bc187fa72f..e9800984a884106e36a2aa96238bad48b784dc49 100644 --- a/src/mednet/libs/segmentation/models/driu_pix.py +++ b/src/mednet/libs/segmentation/models/driu_pix.py @@ -77,6 +77,10 @@ class DRIUPix(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -95,6 +99,8 @@ class DRIUPix(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -105,6 +111,8 @@ class DRIUPix(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py index 97a663c51cc3fa649f70f1c0902fad712babb916..12a7c591aec5d75fb971b9265a27ade8cf51806f 100644 --- a/src/mednet/libs/segmentation/models/hed.py +++ b/src/mednet/libs/segmentation/models/hed.py @@ -91,6 +91,10 @@ class HED(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -109,6 +113,8 @@ class HED(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -119,6 +125,8 @@ class HED(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py index adb17512bcd3953efd65c6dc42b382d0b99f12ea..3a552c4901cef48eed003eeb4cf127a9455666eb 100644 --- a/src/mednet/libs/segmentation/models/lwnet.py +++ b/src/mednet/libs/segmentation/models/lwnet.py @@ -294,6 +294,10 @@ class LittleWNet(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -310,6 +314,8 @@ class LittleWNet(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -319,6 +325,8 @@ class LittleWNet(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, @@ -346,5 +354,4 @@ class LittleWNet(SegmentationModel): xn = self.normalizer(x) x1 = self.unet1(xn) x2 = self.unet2(torch.cat([xn, x1], dim=1)) - return x1, x2 diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py index cf7ce5c436b8af5796f119c06940d8098c3d238b..fcca7d6072a88ad6bc6bccd4932eef2a1b0a153f 100644 --- a/src/mednet/libs/segmentation/models/m2unet.py +++ b/src/mednet/libs/segmentation/models/m2unet.py @@ -139,6 +139,10 @@ class M2UNET(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -157,6 +161,8 @@ class M2UNET(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -167,6 +173,8 @@ class M2UNET(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py index 8d8aa3ca42757bc415fc9784d309119f6f4599c4..834af45f087e4cc27d9089fc67efe51a6041e5a5 100644 --- a/src/mednet/libs/segmentation/models/segmentation_model.py +++ b/src/mednet/libs/segmentation/models/segmentation_model.py @@ -34,6 +34,10 @@ class SegmentationModel(Model): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -50,6 +54,8 @@ class SegmentationModel(Model): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -59,6 +65,8 @@ class SegmentationModel(Model): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes, diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py index 6d9a76da8aea735d7cee850955698d71392356e3..cb1e34cfa7c07fcbafa43a59269412703f4be3ef 100644 --- a/src/mednet/libs/segmentation/models/unet.py +++ b/src/mednet/libs/segmentation/models/unet.py @@ -80,6 +80,10 @@ class Unet(SegmentationModel): The type of optimizer to use for training. optimizer_arguments Arguments to the optimizer after ``params``. + scheduler_type + The type of scheduler to use for training. + scheduler_arguments + Arguments to the scheduler after ``params``. model_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. @@ -98,6 +102,8 @@ class Unet(SegmentationModel): loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, + scheduler_type: type[torch.optim.lr_scheduler] = None, + scheduler_arguments: dict[str, typing.Any] = {}, model_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [], num_classes: int = 1, @@ -108,6 +114,8 @@ class Unet(SegmentationModel): loss_arguments, optimizer_type, optimizer_arguments, + scheduler_type, + scheduler_arguments, model_transforms, augmentation_transforms, num_classes,