Commit 821f5c6f authored by André Anjos

[config.models] Restore model transforms; Unify transform strategies

parent b4d25c13
1 merge request: !46 Create common library
Showing with 186 additions and 701 deletions
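All touched configurations converge on the same pre-processing chain (SquareCenterPad, a 512-pixel Resize, then RGB) and plain Adam optimizers. A minimal sketch of that chain, assuming mednet and torchvision are installed and that RGB() replicates a single-channel input across three channels (adjust to the actual library behaviour):

import mednet.libs.common.models.transforms
import torch
import torchvision.transforms

# Pad to a square canvas, scale the square to 512 x 512, then force 3 channels.
pipeline = torchvision.transforms.Compose(
    [
        mednet.libs.common.models.transforms.SquareCenterPad(),
        torchvision.transforms.Resize(512, antialias=True),
        mednet.libs.common.models.transforms.RGB(),
    ]
)

image = torch.rand(1, 300, 400)  # single-channel, non-square example
print(pipeline(image).shape)  # expected: torch.Size([3, 512, 512])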
@@ -3,28 +3,25 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """AlexNet_, to be trained from scratch.
-This configuration contains a version of AlexNet_ (c.f. `TorchVision's
-page <alexnet-pytorch_>`_), modified for a variable number of outputs
-(defaults to 1).
+This configuration contains a version of AlexNet_ (c.f. `TorchVision's page
+<alexnet-pytorch_>`_), modified for a variable number of outputs (defaults to
+1).
 """
+import mednet.libs.classification.models.alexnet
+import mednet.libs.common.models.transforms
+import torch.nn
+import torch.optim
 import torchvision.transforms
-from mednet.libs.classification.models.alexnet import Alexnet
-from mednet.libs.common.models.transforms import RGB, SquareCenterPad
-from torch.nn import BCEWithLogitsLoss
-from torch.optim import SGD
-model = Alexnet(
-    loss_type=BCEWithLogitsLoss,
-    optimizer_type=SGD,
-    optimizer_arguments=dict(lr=0.01, momentum=0.1),
-    scheduler_type=None,
-    scheduler_arguments=dict(),
+model = mednet.libs.classification.models.alexnet.Alexnet(
+    loss_type=torch.nn.BCEWithLogitsLoss,
+    optimizer_type=torch.optim.Adam,
+    optimizer_arguments=dict(lr=0.001),
     pretrained=False,
     model_transforms=[
-        SquareCenterPad(),
+        mednet.libs.common.models.transforms.SquareCenterPad(),
         torchvision.transforms.Resize(512, antialias=True),
-        RGB(),
+        mednet.libs.common.models.transforms.RGB(),
     ],
-    augmentation_transforms=[],
 )
@@ -4,28 +4,26 @@
 """AlexNet_, to be fine-tuned. Pre-trained on ImageNet_.
 This configuration contains a version of AlexNet_ (c.f. `TorchVision's
-page <alexnet_pytorch_>`), modified for a variable number of outputs
+page <alexnet-pytorch_>`_), modified for a variable number of outputs
 (defaults to 1).
 N.B.: The output layer is **always** initialized from scratch.
 """
+import mednet.libs.classification.models.alexnet
+import mednet.libs.common.models.transforms
+import torch.nn
+import torch.optim
 import torchvision.transforms
-from mednet.libs.classification.models.alexnet import Alexnet
-from mednet.libs.common.models.transforms import RGB, SquareCenterPad
-from torch.nn import BCEWithLogitsLoss
-from torch.optim import SGD
-model = Alexnet(
-    loss_type=BCEWithLogitsLoss,
-    optimizer_type=SGD,
-    optimizer_arguments=dict(lr=0.01, momentum=0.1),
-    scheduler_type=None,
-    scheduler_arguments=dict(),
+model = mednet.libs.classification.models.alexnet.Alexnet(
+    loss_type=torch.nn.BCEWithLogitsLoss,
+    optimizer_type=torch.optim.Adam,
+    optimizer_arguments=dict(lr=0.0001),
     pretrained=True,
     model_transforms=[
-        SquareCenterPad(),
+        mednet.libs.common.models.transforms.SquareCenterPad(),
         torchvision.transforms.Resize(512, antialias=True),
-        RGB(),
+        mednet.libs.common.models.transforms.RGB(),
    ],
 )
@@ -8,18 +8,21 @@ page <densenet_pytorch_>`), modified for a variable number of outputs
 (defaults to 1).
 """
-from mednet.libs.classification.models.densenet import Densenet
-from torch.nn import BCEWithLogitsLoss
-from torch.optim import Adam
+import mednet.libs.classification.models.densenet
+import mednet.libs.common.models.transforms
+import torch.nn
+import torch.optim
+import torchvision.transforms
-model = Densenet(
-    loss_type=BCEWithLogitsLoss,
-    optimizer_type=Adam,
+model = mednet.libs.classification.models.densenet.Densenet(
+    loss_type=torch.nn.BCEWithLogitsLoss,
+    optimizer_type=torch.optim.Adam,
     optimizer_arguments=dict(lr=0.0001),
-    scheduler_type=None,
-    scheduler_arguments=dict(),
     pretrained=False,
     dropout=0.1,
-    model_transforms=[],
-    augmentation_transforms=[],
+    model_transforms=[
+        mednet.libs.common.models.transforms.SquareCenterPad(),
+        torchvision.transforms.Resize(512, antialias=True),
+        mednet.libs.common.models.transforms.RGB(),
+    ],
 )
@@ -10,28 +10,21 @@ page <alexnet_pytorch_>`), modified for a variable number of outputs
 N.B.: The output layer is **always** initialized from scratch.
 """
+import mednet.libs.classification.models.densenet
+import mednet.libs.common.models.transforms
+import torch.nn
+import torch.optim
 import torchvision.transforms
-from mednet.libs.classification.models.densenet import Densenet
-from mednet.libs.common.models.transforms import RGB, SquareCenterPad
-from torch.nn import BCEWithLogitsLoss
-from torch.optim import Adam
-model = Densenet(
-    loss_type=BCEWithLogitsLoss,
-    optimizer_type=Adam,
+model = mednet.libs.classification.models.densenet.Densenet(
+    loss_type=torch.nn.BCEWithLogitsLoss,
+    optimizer_type=torch.optim.Adam,
     optimizer_arguments=dict(lr=0.0001),
-    scheduler_type=None,
-    scheduler_arguments=dict(),
     pretrained=True,
     dropout=0.1,
     model_transforms=[
-        SquareCenterPad(),
-        torchvision.transforms.Resize(
-            512,
-            antialias=True,
-            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
-        ),
-        RGB(),
+        mednet.libs.common.models.transforms.SquareCenterPad(),
+        torchvision.transforms.Resize(512, antialias=True),
+        mednet.libs.common.models.transforms.RGB(),
     ],
-    augmentation_transforms=[],
 )
@@ -9,19 +9,22 @@ page <densenet_pytorch_>`), modified to have exactly 14 outputs
 weights from scratch for radiological sign detection.
 """
-from mednet.libs.classification.models.densenet import Densenet
-from torch.nn import BCEWithLogitsLoss
-from torch.optim import Adam
+import mednet.libs.classification.models.densenet
+import mednet.libs.common.models.transforms
+import torch.nn
+import torch.optim
+import torchvision.transforms
-model = Densenet(
-    loss_type=BCEWithLogitsLoss,
-    optimizer_type=Adam,
+model = mednet.libs.classification.models.densenet.Densenet(
+    loss_type=torch.nn.BCEWithLogitsLoss,
+    optimizer_type=torch.optim.Adam,
     optimizer_arguments=dict(lr=0.0001),
-    scheduler_type=None,
-    scheduler_arguments=dict(),
     pretrained=False,
     dropout=0.1,
     num_classes=14, # number of classes in NIH CXR-14
-    model_transforms=[],
-    augmentation_transforms=[],
+    model_transforms=[
+        mednet.libs.common.models.transforms.SquareCenterPad(),
+        torchvision.transforms.Resize(512, antialias=True),
+        mednet.libs.common.models.transforms.RGB(),
+    ],
 )
@@ -10,26 +10,20 @@ Screening and Visualization".
 Reference: [PASA-2019]_
 """
+import mednet.libs.classification.models.pasa
+import mednet.libs.common.models.transforms
+import torch.nn
+import torch.optim
 import torchvision.transforms
-from mednet.libs.classification.models.pasa import Pasa
-from mednet.libs.common.models.transforms import Grayscale, SquareCenterPad
-from torch.nn import BCEWithLogitsLoss
-from torch.optim import Adam
-model = Pasa(
-    loss_type=BCEWithLogitsLoss,
-    optimizer_type=Adam,
+model = mednet.libs.classification.models.pasa.Pasa(
+    loss_type=torch.nn.BCEWithLogitsLoss,
+    optimizer_type=torch.optim.Adam,
     optimizer_arguments=dict(lr=8e-5),
-    scheduler_type=None,
-    scheduler_arguments=dict(),
     model_transforms=[
-        Grayscale(),
-        SquareCenterPad(),
-        torchvision.transforms.Resize(
-            512,
-            antialias=True,
-            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
-        ),
+        mednet.libs.common.models.transforms.Grayscale(),
+        mednet.libs.common.models.transforms.SquareCenterPad(),
+        torchvision.transforms.Resize(512, antialias=True),
     ],
     augmentation_transforms=[],
 )
@@ -60,39 +60,6 @@ def crop_multiple_images_to_mask(
     return [crop_image_to_mask(img, mask) for img in images]
-def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor:
-    """Resize image based on the longest side while keeping the aspect ratio.
-    Parameters
-    ----------
-    tensor
-        The tensor to resize.
-    max_side
-        The new length of the largest side.
-    Returns
-    -------
-    The resized image.
-    """
-    from torchvision.transforms import InterpolationMode
-    if max_side <= 0:
-        raise ValueError(f"The new max side ({max_side}) must be positive.")
-    height, width = tensor.shape[-2:]
-    aspect_ratio = float(height) / float(width)
-    if height >= width:
-        new_size = (max_side, int(max_side / aspect_ratio))
-    else:
-        new_size = (int(max_side * aspect_ratio), max_side)
-    return torchvision.transforms.Resize(
-        new_size, interpolation=InterpolationMode.NEAREST, antialias=True
-    )(tensor)
 def square_center_pad(img: torch.Tensor) -> torch.Tensor:
     """Return a squared version of the image, centered on a canvas padded with
     zeros.
@@ -216,23 +183,6 @@ def rgb_to_grayscale(img: torch.Tensor) -> torch.Tensor:
     return torchvision.transforms.functional.rgb_to_grayscale(img)
-class ResizeMaxSide(torch.nn.Module):
-    """Resize image on the longest side while keeping the aspect ratio.
-    Parameters
-    ----------
-    max_side
-        The new length of the largest side.
-    """
-    def __init__(self, max_side: int):
-        super().__init__()
-        self.max_side = max_side
-    def forward(self, img: torch.Tensor) -> torch.Tensor:
-        return resize_max_side(img, self.max_side)
 class SquareCenterPad(torch.nn.Module):
     """Transform to a squared version of the image, centered on a canvas padded
     with zeros.
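The removed ResizeMaxSide helper becomes unnecessary under the unified strategy: padding to a square first and then applying a fixed Resize ends in the same 512x512 geometry that ResizeMaxSide followed by SquareCenterPad produced, with only the interpolation order changing. A small sketch, assuming mednet is importable:

import mednet.libs.common.models.transforms
import torch
import torchvision.transforms

image = torch.rand(3, 300, 400)
pad = mednet.libs.common.models.transforms.SquareCenterPad()
resize = torchvision.transforms.Resize(512, antialias=True)

# New order: pad 300x400 to a 400x400 square, then scale the square to 512x512.
print(resize(pad(image)).shape)  # torch.Size([3, 512, 512])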
@@ -11,41 +11,21 @@ deep Convolutional Neural Networks (CNNs).
 Reference: [MANINIS-2016]_
 """
-from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
-from mednet.libs.segmentation.engine.adabound import AdaBound
-from mednet.libs.segmentation.models.driu import DRIU
-from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
+import mednet.libs.common.models.transforms
+import mednet.libs.segmentation.models.driu
+import mednet.libs.segmentation.models.losses
+import torch.optim
+import torchvision.transforms
-lr = 0.001
-alpha = 0.7
-betas = (0.9, 0.999)
-eps = 1e-08
-weight_decay = 0
-final_lr = 0.1
-gamma = 1e-3
-eps = 1e-8
-amsbound = False
-resize_transform = ResizeMaxSide(512)
-model = DRIU(
-    loss_type=SoftJaccardBCELogitsLoss,
-    loss_arguments=dict(alpha=alpha),
-    optimizer_type=AdaBound,
-    optimizer_arguments=dict(
-        lr=lr,
-        betas=betas,
-        final_lr=final_lr,
-        gamma=gamma,
-        eps=eps,
-        weight_decay=weight_decay,
-        amsbound=amsbound,
-    ),
-    scheduler_type=None,
-    scheduler_arguments=dict(),
+model = mednet.libs.segmentation.models.driu.DRIU(
+    loss_type=mednet.libs.segmentation.models.losses.SoftJaccardBCELogitsLoss,
+    loss_arguments=dict(alpha=0.7), # 0.7 BCE + 0.3 Jaccard
+    optimizer_type=torch.optim.Adam,
+    optimizer_arguments=dict(lr=0.01),
     model_transforms=[
-        resize_transform,
-        SquareCenterPad(),
+        mednet.libs.common.models.transforms.SquareCenterPad(),
+        torchvision.transforms.Resize(512, antialias=True),
+        mednet.libs.common.models.transforms.RGB(),
     ],
-    augmentation_transforms=[],
+    pretrained=False,
 )
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-"""DRIU Network for Vessel Segmentation.
+"""DRIU with Batch Normalization Network for Vessel Segmentation.
 Deep Retinal Image Understanding (DRIU), a unified framework of retinal image
 analysis that provides both retinal vessel and optic disc segmentation using
@@ -11,41 +11,21 @@ deep Convolutional Neural Networks (CNNs).
 Reference: [MANINIS-2016]_
 """
-from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
-from mednet.libs.segmentation.engine.adabound import AdaBound
-from mednet.libs.segmentation.models.driu_bn import DRIUBN
-from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
+import mednet.libs.common.models.transforms
+import mednet.libs.segmentation.models.driu_bn
+import mednet.libs.segmentation.models.losses
+import torch.optim
+import torchvision.transforms
-lr = 0.001
-alpha = 0.7
-betas = (0.9, 0.999)
-eps = 1e-08
-weight_decay = 0
-final_lr = 0.1
-gamma = 1e-3
-eps = 1e-8
-amsbound = False
-resize_transform = ResizeMaxSide(512)
-model = DRIUBN(
-    loss_type=SoftJaccardBCELogitsLoss,
-    loss_arguments=dict(alpha=alpha),
-    optimizer_type=AdaBound,
-    optimizer_arguments=dict(
-        lr=lr,
-        betas=betas,
-        final_lr=final_lr,
-        gamma=gamma,
-        eps=eps,
-        weight_decay=weight_decay,
-        amsbound=amsbound,
-    ),
-    scheduler_type=None,
-    scheduler_arguments=dict(),
+model = mednet.libs.segmentation.models.driu_bn.DRIUBN(
+    loss_type=mednet.libs.segmentation.models.losses.SoftJaccardBCELogitsLoss,
+    loss_arguments=dict(alpha=0.7), # 0.7 BCE + 0.3 Jaccard
+    optimizer_type=torch.optim.Adam,
+    optimizer_arguments=dict(lr=0.01),
     model_transforms=[
-        resize_transform,
-        SquareCenterPad(),
+        mednet.libs.common.models.transforms.SquareCenterPad(),
+        torchvision.transforms.Resize(512, antialias=True),
+        mednet.libs.common.models.transforms.RGB(),
     ],
-    augmentation_transforms=[],
+    pretrained=False,
 )
@@ -11,41 +11,21 @@ deep Convolutional Neural Networks (CNNs).
 Reference: [MANINIS-2016]_
 """
-from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
-from mednet.libs.segmentation.engine.adabound import AdaBound
-from mednet.libs.segmentation.models.driu_od import DRIUOD
-from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
+import mednet.libs.common.models.transforms
+import mednet.libs.segmentation.models.driu_od
+import mednet.libs.segmentation.models.losses
+import torch.optim
+import torchvision.transforms
-lr = 0.001
-alpha = 0.7
-betas = (0.9, 0.999)
-eps = 1e-08
-weight_decay = 0
-final_lr = 0.1
-gamma = 1e-3
-eps = 1e-8
-amsbound = False
-resize_transform = ResizeMaxSide(512)
-model = DRIUOD(
-    loss_type=SoftJaccardBCELogitsLoss,
-    loss_arguments=dict(alpha=alpha),
-    optimizer_type=AdaBound,
-    optimizer_arguments=dict(
-        lr=lr,
-        betas=betas,
-        final_lr=final_lr,
-        gamma=gamma,
-        eps=eps,
-        weight_decay=weight_decay,
-        amsbound=amsbound,
-    ),
-    scheduler_type=None,
-    scheduler_arguments=dict(),
+model = mednet.libs.segmentation.models.driu_od.DRIUOD(
+    loss_type=mednet.libs.segmentation.models.losses.SoftJaccardBCELogitsLoss,
+    loss_arguments=dict(alpha=0.7), # 0.7 BCE + 0.3 Jaccard
+    optimizer_type=torch.optim.Adam,
+    optimizer_arguments=dict(lr=0.01),
     model_transforms=[
-        resize_transform,
-        SquareCenterPad(),
+        mednet.libs.common.models.transforms.SquareCenterPad(),
+        torchvision.transforms.Resize(512, antialias=True),
+        mednet.libs.common.models.transforms.RGB(),
     ],
-    augmentation_transforms=[],
+    pretrained=False,
 )
@@ -11,41 +11,21 @@ deep Convolutional Neural Networks (CNNs).
 Reference: [MANINIS-2016]_
 """
-from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
-from mednet.libs.segmentation.engine.adabound import AdaBound
-from mednet.libs.segmentation.models.driu_pix import DRIUPix
-from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
+import mednet.libs.common.models.transforms
+import mednet.libs.segmentation.models.driu_pix
+import mednet.libs.segmentation.models.losses
+import torch.optim
+import torchvision.transforms
-lr = 0.001
-alpha = 0.7
-betas = (0.9, 0.999)
-eps = 1e-08
-weight_decay = 0
-final_lr = 0.1
-gamma = 1e-3
-eps = 1e-8
-amsbound = False
-resize_transform = ResizeMaxSide(512)
-model = DRIUPix(
-    loss_type=SoftJaccardBCELogitsLoss,
-    loss_arguments=dict(alpha=alpha),
-    optimizer_type=AdaBound,
-    optimizer_arguments=dict(
-        lr=lr,
-        betas=betas,
-        final_lr=final_lr,
-        gamma=gamma,
-        eps=eps,
-        weight_decay=weight_decay,
-        amsbound=amsbound,
-    ),
-    scheduler_type=None,
-    scheduler_arguments=dict(),
+model = mednet.libs.segmentation.models.driu_pix.DRIUPix(
+    loss_type=mednet.libs.segmentation.models.losses.SoftJaccardBCELogitsLoss,
+    loss_arguments=dict(alpha=0.7), # 0.7 BCE + 0.3 Jaccard
+    optimizer_type=torch.optim.Adam,
+    optimizer_arguments=dict(lr=0.01),
     model_transforms=[
-        resize_transform,
-        SquareCenterPad(),
+        mednet.libs.common.models.transforms.SquareCenterPad(),
+        torchvision.transforms.Resize(512, antialias=True),
+        mednet.libs.common.models.transforms.RGB(),
     ],
-    augmentation_transforms=[],
+    pretrained=False,
 )
@@ -2,41 +2,26 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
-from mednet.libs.segmentation.engine.adabound import AdaBound
-from mednet.libs.segmentation.models.hed import HED
-from mednet.libs.segmentation.models.losses import MultiSoftJaccardBCELogitsLoss
-lr = 0.001
-alpha = 0.7
-betas = (0.9, 0.999)
-eps = 1e-08
-weight_decay = 0
-final_lr = 0.1
-gamma = 1e-3
-eps = 1e-8
-amsbound = False
-resize_transform = ResizeMaxSide(512)
-model = HED(
-    loss_type=MultiSoftJaccardBCELogitsLoss,
-    loss_arguments=dict(alpha=alpha),
-    optimizer_type=AdaBound,
-    optimizer_arguments=dict(
-        lr=lr,
-        betas=betas,
-        final_lr=final_lr,
-        gamma=gamma,
-        eps=eps,
-        weight_decay=weight_decay,
-        amsbound=amsbound,
-    ),
-    scheduler_type=None,
-    scheduler_arguments=dict(),
+"""HED Network for Segmentation.
+Reference: [XIE-2015]_
+"""
+import mednet.libs.common.models.transforms
+import mednet.libs.segmentation.models.hed
+import mednet.libs.segmentation.models.losses
+import torch.optim
+import torchvision.transforms
+model = mednet.libs.segmentation.models.hed.HED(
+    loss_type=mednet.libs.segmentation.models.losses.MultiSoftJaccardBCELogitsLoss,
+    loss_arguments=dict(alpha=0.7), # 0.7 BCE + 0.3 Jaccard
+    optimizer_type=torch.optim.Adam,
+    optimizer_arguments=dict(lr=0.01),
     model_transforms=[
-        resize_transform,
-        SquareCenterPad(),
+        mednet.libs.common.models.transforms.SquareCenterPad(),
+        torchvision.transforms.Resize(512, antialias=True),
+        mednet.libs.common.models.transforms.RGB(),
    ],
-    augmentation_transforms=[],
+    pretrained=False,
 )
@@ -9,11 +9,12 @@ closely matches (or outperforms) other more complex techniques.
 Reference: [GALDRAN-2020]_
 """
-from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
-from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
-from mednet.libs.segmentation.models.lwnet import LittleWNet
-from torch.optim import Adam
-from torch.optim.lr_scheduler import CosineAnnealingLR
+import mednet.libs.common.models.transforms
+import mednet.libs.segmentation.models.losses
+import mednet.libs.segmentation.models.lwnet
+import torch.optim
+import torch.optim.lr_scheduler
+import torchvision.transforms
 max_lr = 0.01 # start
 min_lr = 1e-08 # valley
@@ -21,17 +22,16 @@ min_lr = 1e-08 # valley
 # About 20 * len(train-data-loader)
 cycle = 100 # 1/2 epochs for a complete scheduling cycle
-resize_transform = ResizeMaxSide(512)
-model = LittleWNet(
-    loss_type=MultiWeightedBCELogitsLoss,
-    optimizer_type=Adam,
+model = mednet.libs.segmentation.models.lwnet.LittleWNet(
+    loss_type=mednet.libs.segmentation.models.losses.MultiWeightedBCELogitsLoss,
+    loss_arguments=dict(),
+    optimizer_type=torch.optim.Adam,
     optimizer_arguments=dict(lr=max_lr),
-    scheduler_type=CosineAnnealingLR,
+    scheduler_type=torch.optim.lr_scheduler.CosineAnnealingLR,
     scheduler_arguments=dict(T_max=cycle, eta_min=min_lr),
     model_transforms=[
-        resize_transform,
-        SquareCenterPad(),
+        mednet.libs.common.models.transforms.SquareCenterPad(),
+        torchvision.transforms.Resize(512, antialias=True),
+        mednet.libs.common.models.transforms.RGB(),
     ],
-    augmentation_transforms=[],
 )
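The lwnet configuration keeps its cosine learning-rate schedule. A short sketch of how the scheduler defined above evolves the learning rate, assuming one scheduler step per epoch (the actual trainer may step it differently):

import torch
import torch.nn
import torch.optim
import torch.optim.lr_scheduler

parameters = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(parameters, lr=0.01)  # max_lr
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=100, eta_min=1e-08  # cycle, min_lr
)

for epoch in range(5):
    optimizer.step()
    scheduler.step()
    # Decays from 0.01 towards 1e-08 over T_max steps, then climbs back,
    # so T_max covers half of a full oscillation.
    print(epoch, scheduler.get_last_lr())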
@@ -15,41 +15,21 @@ segmentation applications and the speed of MobileNetV2 networks.
 References: [SANDLER-2018]_, [RONNEBERGER-2015]_
 """
-from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
-from mednet.libs.segmentation.engine.adabound import AdaBound
-from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
-from mednet.libs.segmentation.models.m2unet import M2Unet
+import mednet.libs.common.models.transforms
+import mednet.libs.segmentation.models.losses
+import mednet.libs.segmentation.models.m2unet
+import torch.optim
+import torchvision.transforms
-lr = 0.001
-alpha = 0.7
-betas = (0.9, 0.999)
-eps = 1e-08
-weight_decay = 0
-final_lr = 0.1
-gamma = 1e-3
-eps = 1e-8
-amsbound = False
-resize_transform = ResizeMaxSide(512)
-model = M2Unet(
-    loss_type=SoftJaccardBCELogitsLoss,
-    loss_arguments=dict(alpha=alpha),
-    optimizer_type=AdaBound,
-    optimizer_arguments=dict(
-        lr=lr,
-        betas=betas,
-        final_lr=final_lr,
-        gamma=gamma,
-        eps=eps,
-        weight_decay=weight_decay,
-        amsbound=amsbound,
-    ),
-    scheduler_type=None,
-    scheduler_arguments=dict(),
+model = mednet.libs.segmentation.models.m2unet.M2Unet(
+    loss_type=mednet.libs.segmentation.models.losses.SoftJaccardBCELogitsLoss,
+    loss_arguments=dict(alpha=0.7), # 0.7 BCE + 0.3 Jaccard
+    optimizer_type=torch.optim.Adam,
+    optimizer_arguments=dict(lr=0.01),
     model_transforms=[
-        resize_transform,
-        SquareCenterPad(),
+        mednet.libs.common.models.transforms.SquareCenterPad(),
+        torchvision.transforms.Resize(512, antialias=True),
+        mednet.libs.common.models.transforms.RGB(),
     ],
-    augmentation_transforms=[],
+    pretrained=False,
 )
@@ -13,41 +13,21 @@ to yield more precise segmentations.
 Reference: [RONNEBERGER-2015]_
 """
-from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
-from mednet.libs.segmentation.engine.adabound import AdaBound
-from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
-from mednet.libs.segmentation.models.unet import Unet
+import mednet.libs.common.models.transforms
+import mednet.libs.segmentation.models.losses
+import mednet.libs.segmentation.models.unet
+import torch.optim
+import torchvision.transforms
-lr = 0.001
-alpha = 0.7
-betas = (0.9, 0.999)
-eps = 1e-08
-weight_decay = 0
-final_lr = 0.1
-gamma = 1e-3
-eps = 1e-8
-amsbound = False
-resize_transform = ResizeMaxSide(512)
-model = Unet(
-    loss_type=SoftJaccardBCELogitsLoss,
-    loss_arguments=dict(alpha=alpha),
-    optimizer_type=AdaBound,
-    optimizer_arguments=dict(
-        lr=lr,
-        betas=betas,
-        final_lr=final_lr,
-        gamma=gamma,
-        eps=eps,
-        weight_decay=weight_decay,
-        amsbound=amsbound,
-    ),
-    scheduler_type=None,
-    scheduler_arguments=dict(),
+model = mednet.libs.segmentation.models.unet.Unet(
+    loss_type=mednet.libs.segmentation.models.losses.SoftJaccardBCELogitsLoss,
+    loss_arguments=dict(alpha=0.7), # 0.7 BCE + 0.3 Jaccard
+    optimizer_type=torch.optim.Adam,
+    optimizer_arguments=dict(lr=0.01),
     model_transforms=[
-        resize_transform,
-        SquareCenterPad(),
+        mednet.libs.common.models.transforms.SquareCenterPad(),
+        torchvision.transforms.Resize(512, antialias=True),
+        mednet.libs.common.models.transforms.RGB(),
     ],
-    augmentation_transforms=[],
+    pretrained=False,
 )
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Implementation of the AdaBound optimizer.
<https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py>::
@inproceedings{Luo2019AdaBound,
author = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu},
title = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate},
booktitle = {Proceedings of the 7th International Conference on Learning
Representations},
month = {May},
year = {2019},
address = {New Orleans, Louisiana}
}
"""
import math
import typing
import torch
import torch.optim
class AdaBound(torch.optim.Optimizer):
"""Implement the AdaBound algorithm.
Parameters
----------
params
Iterable of parameters to optimize or dicts defining parameter groups.
lr
Adam learning rate.
betas
Coefficients (as a 2-tuple of floats) used for computing running
averages of gradient and its square.
final_lr
Final (SGD) learning rate.
gamma
Convergence speed of the bound functions.
eps
Term added to the denominator to improve numerical stability.
weight_decay
Weight decay (L2 penalty).
amsbound
Whether to use the AMSBound variant of this algorithm.
"""
def __init__(
self,
params: list,
lr: float = 1e-3,
betas: tuple[float, float] = (0.9, 0.999),
final_lr: float = 0.1,
gamma: float = 1e-3,
eps: float = 1e-8,
weight_decay: float = 0,
amsbound: bool = False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= final_lr:
raise ValueError(f"Invalid final learning rate: {final_lr}")
if not 0.0 <= gamma < 1.0:
raise ValueError(f"Invalid gamma parameter: {gamma}")
defaults = dict(
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
)
super().__init__(params, defaults)
self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("amsbound", False)
def step(self, closure: typing.Callable | None = None):
"""Perform a single optimization step.
Parameters
----------
closure
A closure that reevaluates the model and returns the loss.
Returns
-------
The loss.
"""
loss = None
if closure is not None:
loss = closure()
for group, base_lr in zip(self.param_groups, self.base_lrs):
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
"Adam does not support sparse gradients, please consider "
"SparseAdam instead"
)
amsbound = group["amsbound"]
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data)
if amsbound:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
if amsbound:
max_exp_avg_sq = state["max_exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
if group["weight_decay"] != 0:
grad = grad.add(group["weight_decay"], p.data)
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if amsbound:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group["eps"])
else:
denom = exp_avg_sq.sqrt().add_(group["eps"])
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
# Applies bounds on actual learning rate. lr_scheduler cannot
# affect final_lr, this is a workaround to apply lr decay
final_lr = group["final_lr"] * group["lr"] / base_lr
lower_bound = final_lr * (1 - 1 / (group["gamma"] * state["step"] + 1))
upper_bound = final_lr * (1 + 1 / (group["gamma"] * state["step"]))
step_size = torch.full_like(denom, step_size)
step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
p.data.add_(-step_size)
return loss
class AdaBoundW(torch.optim.Optimizer):
"""Implement AdaBound algorithm with Decoupled Weight Decay (See
https://arxiv.org/abs/1711.05101).
Parameters
----------
params
Iterable of parameters to optimize or dicts defining parameter groups.
lr
Adam learning rate.
betas
Coefficients (as a 2-tuple of floats) used for computing running
averages of gradient and its square.
final_lr
Final (SGD) learning rate.
gamma
Convergence speed of the bound functions.
eps
Term added to the denominator to improve numerical stability.
weight_decay
Weight decay (L2 penalty).
amsbound
Whether to use the AMSBound variant of this algorithm.
"""
def __init__(
self,
params: list,
lr: float | None = 1e-3,
betas: tuple[float, float] | None = (0.9, 0.999),
final_lr: float | None = 0.1,
gamma: float | None = 1e-3,
eps: float | None = 1e-8,
weight_decay: float | None = 0,
amsbound: bool | None = False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= final_lr:
raise ValueError(f"Invalid final learning rate: {final_lr}")
if not 0.0 <= gamma < 1.0:
raise ValueError(f"Invalid gamma parameter: {gamma}")
defaults = dict(
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
)
super().__init__(params, defaults)
self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("amsbound", False)
def step(self, closure: typing.Callable | None = None):
"""Perform a single optimization step.
Parameters
----------
closure
A closure that reevaluates the model and returns the loss.
Returns
-------
The loss.
"""
loss = None
if closure is not None:
loss = closure()
for group, base_lr in zip(self.param_groups, self.base_lrs):
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
"Adam does not support sparse gradients, please consider "
"SparseAdam instead"
)
amsbound = group["amsbound"]
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data)
if amsbound:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
if amsbound:
max_exp_avg_sq = state["max_exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if amsbound:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group["eps"])
else:
denom = exp_avg_sq.sqrt().add_(group["eps"])
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
# Applies bounds on actual learning rate
# lr_scheduler cannot affect final_lr, this is a workaround to
# apply lr decay
final_lr = group["final_lr"] * group["lr"] / base_lr
lower_bound = final_lr * (1 - 1 / (group["gamma"] * state["step"] + 1))
upper_bound = final_lr * (1 + 1 / (group["gamma"] * state["step"]))
step_size = torch.full_like(denom, step_size)
step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
if group["weight_decay"] != 0:
decayed_weights = torch.mul(p.data, group["weight_decay"])
p.data.add_(-step_size)
p.data.sub_(decayed_weights)
else:
p.data.add_(-step_size)
return loss
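For reference, a minimal usage sketch of the AdaBound class above (the segmentation configurations in this commit switch to torch.optim.Adam instead); the dynamic bounds clamp the Adam-style step so the optimizer gradually behaves like SGD with lr=final_lr:

import torch
import torch.nn
import torch.nn.functional

net = torch.nn.Linear(4, 1)
optimizer = AdaBound(list(net.parameters()), lr=1e-3, final_lr=0.1)

inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(net(inputs), targets)
loss.backward()
optimizer.step()   # Adam-like update, clamped between the dynamic bounds
optimizer.zero_grad()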