Commit 993945b7 authored by Daniel CARRON, committed by André Anjos

[models] Add support for learning rate schedulers

parent a3cc8884
1 merge request: !46 Create common library
Showing 20 changed files with 84 additions and 2 deletions
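The new scheduler_type/scheduler_arguments pair mirrors the existing optimizer configuration: the scheduler class is instantiated with the optimizer as its first positional argument, followed by the given keyword arguments. A minimal usage sketch, assuming the Pasa import path and the StepLR choice (both illustrative, not part of this commit):

    from torch.nn import BCEWithLogitsLoss
    from torch.optim import Adam
    from torch.optim.lr_scheduler import StepLR

    from mednet.libs.classification.models.pasa import Pasa  # assumed path

    model = Pasa(
        loss_type=BCEWithLogitsLoss,
        optimizer_type=Adam,
        optimizer_arguments=dict(lr=8e-5),
        scheduler_type=StepLR,  # any torch.optim.lr_scheduler class should fit
        scheduler_arguments=dict(step_size=10, gamma=0.5),  # halve lr every 10 epochs
    )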
@@ -18,6 +18,8 @@ model = Alexnet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=False,
     model_transforms=[
         SquareCenterPad(),
...
@@ -20,6 +20,8 @@ model = Alexnet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=True,
     model_transforms=[
         SquareCenterPad(),
...
@@ -16,6 +16,8 @@ model = Densenet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=False,
     dropout=0.1,
     model_transforms=[],
...
@@ -20,6 +20,8 @@ model = Densenet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=True,
     dropout=0.1,
     model_transforms=[
...
@@ -17,6 +17,8 @@ model = Densenet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     pretrained=False,
     dropout=0.1,
     num_classes=14,  # number of classes in NIH CXR-14
...
@@ -20,6 +20,8 @@ model = Pasa(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=8e-5),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         Grayscale(),
         SquareCenterPad(),
...
@@ -37,6 +37,10 @@ class Alexnet(ClassificationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -56,6 +60,8 @@ class Alexnet(ClassificationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         pretrained: bool = False,
@@ -66,6 +72,8 @@ class Alexnet(ClassificationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
...
@@ -33,6 +33,10 @@ class ClassificationModel(Model):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -49,6 +53,8 @@ class ClassificationModel(Model):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -58,6 +64,8 @@ class ClassificationModel(Model):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
...
@@ -35,6 +35,10 @@ class Densenet(ClassificationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -56,6 +60,8 @@ class Densenet(ClassificationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         pretrained: bool = False,
@@ -67,6 +73,8 @@ class Densenet(ClassificationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
...
@@ -40,6 +40,10 @@ class Pasa(ClassificationModel):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -56,6 +60,8 @@ class Pasa(ClassificationModel):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -65,6 +71,8 @@ class Pasa(ClassificationModel):
             loss_arguments,
             optimizer_type,
             optimizer_arguments,
+            scheduler_type,
+            scheduler_arguments,
             model_transforms,
             augmentation_transforms,
             num_classes,
...
@@ -141,7 +141,9 @@ class LoggingCallback(lightning.pytorch.Callback):
         epoch_time = time.time() - self._start_training_epoch_time
         self._to_log["epoch-duration-seconds/train"] = epoch_time
-        self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"]  # type: ignore
+        self._to_log["learning-rate"] = pl_module.trainer.lr_scheduler_configs[
+            0
+        ].scheduler.optimizer.param_groups[0]["lr"]  # type: ignore
         overall_cycle_time = time.time() - self._start_training_epoch_time
         self._to_log["cycle-time-seconds/train"] = overall_cycle_time
...
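The logged learning rate is now read from the first configured scheduler's optimizer, since optimizer.defaults only records the constructor arguments and never changes as the scheduler steps. A sketch of the lookup with a fallback for scheduler-less runs (the helper name is hypothetical; assumes lightning.pytorch 2.x):

    def _current_learning_rate(pl_module) -> float:
        # param_groups hold the live value updated by scheduler.step()
        configs = pl_module.trainer.lr_scheduler_configs
        if configs:
            return configs[0].scheduler.optimizer.param_groups[0]["lr"]
        # no scheduler configured: fall back to the optimizer itself
        return pl_module.optimizers().param_groups[0]["lr"]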
@@ -8,6 +8,7 @@ import typing
 import lightning.pytorch as pl
 import torch
 import torch.nn
+import torch.optim.lr_scheduler
 import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
@@ -37,6 +38,10 @@ class Model(pl.LightningModule):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of learning rate scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``optimizer``.
     model_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -53,6 +58,8 @@ class Model(pl.LightningModule):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
@@ -75,6 +82,9 @@ class Model(pl.LightningModule):
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments

+        self._scheduler_type = scheduler_type
+        self._scheduler_arguments = scheduler_arguments
+
         self.augmentation_transforms = augmentation_transforms

     @property
@@ -146,11 +156,20 @@ class Model(pl.LightningModule):
         self._validation_loss = self._loss_type(**self._validation_loss_arguments)

     def configure_optimizers(self):
-        return self._optimizer_type(
+        optimizer = self._optimizer_type(
             self.parameters(),
             **self._optimizer_arguments,
         )

+        if self._scheduler_type is None:
+            return optimizer
+
+        scheduler = self._scheduler_type(
+            optimizer,
+            **self._scheduler_arguments,
+        )
+
+        return [optimizer], [scheduler]
+
     def to(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Self:
         """Move model, augmentations and losses to specified device.
...
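With no scheduler configured, configure_optimizers keeps its previous behaviour and returns the bare optimizer; otherwise it uses the two-list form, which Lightning steps once per epoch by default. For per-step updates or metric-driven schedulers, Lightning also accepts a dict return value; a sketch for reference (not used by this commit):

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # update every batch instead of every epoch
                "frequency": 1,
            },
        }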
@@ -41,6 +41,8 @@ model = DRIU(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
...
@@ -41,6 +41,8 @@ model = DRIUBN(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
...
@@ -41,6 +41,8 @@ model = DRIUOD(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
...
@@ -41,6 +41,8 @@ model = DRIUPix(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
...
@@ -32,6 +32,8 @@ model = HED(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
...
@@ -13,6 +13,7 @@ from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
 from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
 from mednet.libs.segmentation.models.lwnet import LittleWNet
 from torch.optim import Adam
+from torch.optim.lr_scheduler import CosineAnnealingLR

 max_lr = 0.01  # start
 min_lr = 1e-08  # valley
@@ -24,6 +25,8 @@ model = LittleWNet(
     loss_type=MultiWeightedBCELogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=max_lr),
+    scheduler_type=CosineAnnealingLR,
+    scheduler_arguments=dict(T_max=cycle, eta_min=min_lr),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
...
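With these arguments, CosineAnnealingLR glides the learning rate from max_lr down to eta_min over T_max epochs along a half-cosine. A closed-form sketch for intuition (the t_max default below is a placeholder, since cycle is defined outside the excerpt above):

    import math

    def cosine_lr(t: int, base_lr: float = 0.01, eta_min: float = 1e-08, t_max: int = 50) -> float:
        # eta_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / t_max)) / 2
        return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2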
@@ -45,6 +45,8 @@ model = M2UNET(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
...
@@ -43,6 +43,8 @@ model = Unet(
         weight_decay=weight_decay,
         amsbound=amsbound,
     ),
+    scheduler_type=None,
+    scheduler_arguments=dict(),
     model_transforms=[
         resize_transform,
         SquareCenterPad(),
...