Commit 0049aab9 authored by Daniel CARRON, committed by André Anjos

[segmentation.models] Add unet model

parent 07080b2e
1 merge request: !46 Create common library
@@ -422,6 +422,7 @@ visceral = "mednet.config.data.visceral.default"
[project.entry-points."mednet.libs.segmentation.config"]
lwnet = "mednet.libs.segmentation.config.models.lwnet"
unet = "mednet.libs.segmentation.config.models.unet"
# chase-db1 - retinography
chasedb1 = "mednet.libs.segmentation.config.data.chasedb1.first_annotator"
@@ -546,6 +547,7 @@ exclude = [ # don't report on objects that match any of these regex
'\.__len__$',
'\.__getitem__$',
'\.__iter__$',
'\.__setstate__$',
'\.__exit__$',
]
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Little W-Net for image segmentation.
The Little W-Net architecture contains roughly around 70k parameters and
closely matches (or outperforms) other more complex techniques.
Reference: [GALDRAN-2020]_
"""
from mednet.libs.segmentation.engine.adabound import AdaBound
from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
from mednet.libs.segmentation.models.unet import Unet
lr = 0.01  # initial (Adam-phase) learning rate
alpha = 0.7
betas = (0.9, 0.999)
eps = 1e-08
weight_decay = 0
final_lr = 0.1
gamma = 1e-3
amsbound = False
model = Unet(
loss_type=SoftJaccardBCELogitsLoss,
loss_arguments=dict(alpha=alpha),
optimizer_type=AdaBound,
optimizer_arguments=dict(
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
),
augmentation_transforms=[],
crop_size=1024,
)
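# How this configuration is resolved (a sketch, not part of the module's
# runtime code): the "unet" entry point added to pyproject.toml above maps a
# short name to this module, so tooling can load it by name. Assuming
# Python >= 3.10 and an installed mednet package::
#
#   from importlib.metadata import entry_points
#
#   config = entry_points(group="mednet.libs.segmentation.config")["unet"].load()
#   config.model  # the Unet instance constructed above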
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Implementation of the AdaBound optimizer.
Adapted from the reference implementation at
<https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py>::
@inproceedings{Luo2019AdaBound,
author = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu},
title = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate},
booktitle = {Proceedings of the 7th International Conference on Learning Representations},
month = {May},
year = {2019},
address = {New Orleans, Louisiana}
}
"""
import math
import typing
import torch
import torch.optim
class AdaBound(torch.optim.Optimizer):
"""Implement the AdaBound algorithm.
Parameters
----------
params
Iterable of parameters to optimize or dicts defining parameter groups.
lr
Adam learning rate.
betas
Coefficients (as a 2-tuple of floats) used for computing running
averages of gradient and its square.
final_lr
Final (SGD) learning rate.
gamma
Convergence speed of the bound functions.
eps
Term added to the denominator to improve numerical stability.
weight_decay
Weight decay (L2 penalty).
amsbound
Whether to use the AMSBound variant of this algorithm.
"""
def __init__(
self,
params: list,
lr: float = 1e-3,
betas: tuple[float, float] = (0.9, 0.999),
final_lr: float = 0.1,
gamma: float = 1e-3,
eps: float = 1e-8,
weight_decay: float = 0,
amsbound: bool = False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= final_lr:
raise ValueError(f"Invalid final learning rate: {final_lr}")
if not 0.0 <= gamma < 1.0:
raise ValueError(f"Invalid gamma parameter: {gamma}")
defaults = dict(
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
)
super().__init__(params, defaults)
self.base_lrs = [group["lr"] for group in self.param_groups]
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("amsbound", False)
def step(self, closure: typing.Callable | None = None):
"""Perform a single optimization step.
Parameters
----------
closure
A closure that reevaluates the model and returns the loss.
Returns
-------
The loss.
"""
loss = None
if closure is not None:
loss = closure()
for group, base_lr in zip(self.param_groups, self.base_lrs):
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
"Adam does not support sparse gradients, please consider SparseAdam instead"
)
amsbound = group["amsbound"]
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data)
if amsbound:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
if amsbound:
max_exp_avg_sq = state["max_exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
if group["weight_decay"] != 0:
grad = grad.add(p.data, alpha=group["weight_decay"])
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if amsbound:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group["eps"])
else:
denom = exp_avg_sq.sqrt().add_(group["eps"])
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
# Apply bounds on the actual learning rate; lr_scheduler cannot affect
# final_lr, so this is a workaround to apply lr decay
final_lr = group["final_lr"] * group["lr"] / base_lr
lower_bound = final_lr * (1 - 1 / (group["gamma"] * state["step"] + 1))
upper_bound = final_lr * (1 + 1 / (group["gamma"] * state["step"]))
step_size = torch.full_like(denom, step_size)
step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
p.data.add_(-step_size)
return loss
class AdaBoundW(torch.optim.Optimizer):
"""Implement AdaBound algorithm with Decoupled Weight Decay (See
https://arxiv.org/abs/1711.05101).
Parameters
----------
params
Iterable of parameters to optimize or dicts defining parameter groups.
lr
Adam learning rate.
betas
Coefficients (as a 2-tuple of floats) used for computing running
averages of gradient and its square.
final_lr
Final (SGD) learning rate.
gamma
Convergence speed of the bound functions.
eps
Term added to the denominator to improve numerical stability.
weight_decay
Weight decay (L2 penalty).
amsbound
Whether to use the AMSBound variant of this algorithm.
"""
def __init__(
self,
params: list,
lr: float = 1e-3,
betas: tuple[float, float] = (0.9, 0.999),
final_lr: float = 0.1,
gamma: float = 1e-3,
eps: float = 1e-8,
weight_decay: float = 0,
amsbound: bool = False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= final_lr:
raise ValueError(f"Invalid final learning rate: {final_lr}")
if not 0.0 <= gamma < 1.0:
raise ValueError(f"Invalid gamma parameter: {gamma}")
defaults = dict(
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
eps=eps,
weight_decay=weight_decay,
amsbound=amsbound,
)
super().__init__(params, defaults)
self.base_lrs = [group["lr"] for group in self.param_groups]
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("amsbound", False)
def step(self, closure: typing.Callable | None = None):
"""Perform a single optimization step.
Parameters
----------
closure
A closure that reevaluates the model and returns the loss.
Returns
-------
The loss.
"""
loss = None
if closure is not None:
loss = closure()
for group, base_lr in zip(self.param_groups, self.base_lrs):
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
"Adam does not support sparse gradients, please consider SparseAdam instead"
)
amsbound = group["amsbound"]
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data)
if amsbound:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
if amsbound:
max_exp_avg_sq = state["max_exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if amsbound:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group["eps"])
else:
denom = exp_avg_sq.sqrt().add_(group["eps"])
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
# Apply bounds on the actual learning rate; lr_scheduler cannot affect
# final_lr, so this is a workaround to apply lr decay
final_lr = group["final_lr"] * group["lr"] / base_lr
lower_bound = final_lr * (1 - 1 / (group["gamma"] * state["step"] + 1))
upper_bound = final_lr * (1 + 1 / (group["gamma"] * state["step"]))
step_size = torch.full_like(denom, step_size)
step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
if group["weight_decay"] != 0:
decayed_weights = torch.mul(p.data, group["weight_decay"])
p.data.add_(-step_size)
p.data.sub_(decayed_weights)
else:
p.data.add_(-step_size)
return loss
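if __name__ == "__main__":
    # Minimal usage sketch (toy model, illustration only): early on the bounds
    # are loose and AdaBound behaves like Adam; as ``step`` grows, both bounds
    # converge to final_lr and the update approaches SGD.
    net = torch.nn.Linear(4, 1)
    optimizer = AdaBound(net.parameters(), lr=1e-3, final_lr=0.1, gamma=1e-3)
    loss = net(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()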
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import torchvision.models
import torchvision.models.mobilenetv2
try:
# pytorch >= 1.12
from torch.hub import load_state_dict_from_url
except ImportError:
# pytorch < 1.12
from torchvision.models.utils import load_state_dict_from_url
class MobileNetV24Segmentation(torchvision.models.mobilenetv2.MobileNetV2):
"""Adaptation of base MobileNetV2 functionality to U-Net style
segmentation.
This version of MobileNetV2 is slightly modified so it can be used through
torchvision's API. It outputs intermediate features which are normally not
output by the base MobileNetV2 implementation, but are required for
segmentation operations.
Parameters
----------
*args
Arguments to be passed to the parent MobileNetV2 model.
**kwargs
Keyword arguments to be passed to the parent MobileNetV2 model.
return_features : :py:class:`list`, Optional
A list of integers indicating the feature layers to be returned from
the original module.
"""
def __init__(self, *args, **kwargs):
self._return_features = kwargs.pop("return_features")
super().__init__(*args, **kwargs)
def forward(self, x):
outputs = []
# height/width of input, needed for DRIU and HED
outputs.append(x.shape[2:4])
outputs.append(x)
for index, m in enumerate(self.features):
x = m(x)
# extract layers
if index in self._return_features:
outputs.append(x)
return outputs
def mobilenet_v2_for_segmentation(pretrained=False, progress=True, **kwargs):
"""Create MobileNetV2 model for segmentation task.
Parameters
----------
pretrained
If True, uses MobileNetV2 pretrained weights.
progress
If True, shows a progress bar when downloading the pretrained weights.
**kwargs
Keyword arguments to be passed to the parent MobileNetV2 model.
return_features : :py:class:`list`, Optional
A list of integers indicating the feature layers to be returned from
the original module.
Returns
-------
Instance of the MobileNetV2 model for segmentation.
"""
model = MobileNetV24Segmentation(**kwargs)
if pretrained:
state_dict = load_state_dict_from_url(
torchvision.models.mobilenetv2.MobileNet_V2_Weights.DEFAULT.url,
progress=progress,
)
model.load_state_dict(state_dict)
# erase MobileNetV2 head (for classification), not used for segmentation
delattr(model, "classifier")
return_features = kwargs.get("return_features")
if return_features is not None:
model.features = model.features[: (max(return_features) + 1)]
return model
mobilenet_v2_for_segmentation.__doc__ = torchvision.models.mobilenet_v2.__doc__
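if __name__ == "__main__":
    # Contract sketch for this backbone (the feature indices are illustrative,
    # not mandated by this commit): the first output is the input
    # height/width, the second is the input tensor itself, and the remaining
    # entries are the intermediate feature maps selected via
    # ``return_features``.
    import torch
    backbone = mobilenet_v2_for_segmentation(
        pretrained=False, return_features=[1, 3, 6, 13]
    )
    outputs = backbone(torch.rand(1, 3, 224, 224))
    assert outputs[0] == torch.Size([224, 224])
    assert len(outputs) == 6  # input size + input tensor + 4 feature maps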
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import torchvision.models
try:
# pytorch >= 1.12
from torch.hub import load_state_dict_from_url
except ImportError:
# pytorch < 1.12
from torchvision.models.utils import load_state_dict_from_url
class ResNet4Segmentation(torchvision.models.resnet.ResNet):
"""Adaptation of base ResNet functionality to U-Net style segmentation.
This version of ResNet is slightly modified so it can be used through
torchvision's API. It outputs intermediate features which are normally not
output by the base ResNet implementation, but are required for segmentation
operations.
Parameters
----------
*args
Arguments to be passed to the parent ResNet model.
**kwargs
Keyword arguments to be passed to the parent ResNet model.
return_features : :py:class:`list`, Optional
A list of integers indicating the feature layers to be returned from
the original module.
"""
def __init__(self, *args, **kwargs):
self._return_features = kwargs.pop("return_features")
super().__init__(*args, **kwargs)
def forward(self, x):
outputs = []
# height/width of input
outputs.append(x.shape[2:4])
# ResNet does not define ``self.features``; iterate over child modules in
# definition order instead (conv1, bn1, relu, maxpool, layer1-4; avgpool
# and fc are removed by the factory function below)
for index, m in enumerate(self.children()):
x = m(x)
# extract layers
if index in self._return_features:
outputs.append(x)
return outputs
def resnet50_for_segmentation(pretrained=False, progress=True, **kwargs):
"""Create ResNet for segmentation task.
Parameters
----------
pretrained
If True, uses ResNet50 pretrained weights.
progress
If True, shows a progress bar when downloading the pretrained weights.
**kwargs
Keyword arguments to be passed to the parent ResNet model.
return_features : :py:class:`list`, Optional
A list of integers indicating the feature layers to be returned from
the original module.
Returns
-------
Instance of the ResNet model for segmentation.
"""
model = ResNet4Segmentation(
torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], **kwargs
)
if pretrained:
state_dict = load_state_dict_from_url(
torchvision.models.resnet.ResNet50_Weights.DEFAULT.url,
progress=progress,
)
model.load_state_dict(state_dict)
# erase ResNet head (for classification), not used for segmentation
delattr(model, "avgpool")
delattr(model, "fc")
return model
resnet50_for_segmentation.__doc__ = torchvision.models.resnet50.__doc__
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import torchvision.models
try:
# pytorch >= 1.12
from torch.hub import load_state_dict_from_url
except ImportError:
# pytorch < 1.12
from torchvision.models.utils import load_state_dict_from_url
class VGG4Segmentation(torchvision.models.vgg.VGG):
"""Adaptation of base VGG functionality to U-Net style segmentation.
This version of VGG is slightly modified so it can be used through
torchvision's API. It outputs intermediate features which are normally not
output by the base VGG implementation, but are required for segmentation
operations.
Parameters
----------
*args
Arguments to be passed to the parent VGG model.
**kwargs
Keyword arguments to be passed to the parent VGG model.
return_features : :py:class:`list`, Optional
A list of integers indicating the feature layers to be returned from
the original module.
"""
def __init__(self, *args, **kwargs):
self._return_features = kwargs.pop("return_features")
super().__init__(*args, **kwargs)
def forward(self, x):
outputs = []
# height/width of input
outputs.append(x.shape[2:4])
for index, m in enumerate(self.features):
x = m(x)
# extract layers
if index in self._return_features:
outputs.append(x)
return outputs
def _make_vgg16_type_d_for_segmentation(pretrained, batch_norm, progress, **kwargs):
if pretrained:
kwargs["init_weights"] = False
model = VGG4Segmentation(
torchvision.models.vgg.make_layers(
torchvision.models.vgg.cfgs["D"],
batch_norm=batch_norm,
),
**kwargs,
)
if pretrained:
weights = (
torchvision.models.vgg.VGG16_Weights.DEFAULT.url
if not batch_norm
else torchvision.models.vgg.VGG16_BN_Weights.DEFAULT.url
)
state_dict = load_state_dict_from_url(weights, progress=progress)
model.load_state_dict(state_dict)
# erase VGG head (for classification), not used for segmentation
delattr(model, "classifier")
delattr(model, "avgpool")
return model
def vgg16_for_segmentation(pretrained=False, progress=True, **kwargs):
"""Create an instance of VGG16.
Parameters
----------
pretrained
If True, uses VGG16 pretrained weights.
progress
If True, shows a progress bar when downloading weights.
**kwargs
Keyword arguments to be passed to the parent VGG model.
return_features : :py:class:`list`, Optional
A list of integers indicating the feature layers to be returned from
the original module.
Returns
-------
Instance of VGG16.
"""
return _make_vgg16_type_d_for_segmentation(
pretrained=pretrained, batch_norm=False, progress=progress, **kwargs
)
vgg16_for_segmentation.__doc__ = torchvision.models.vgg16.__doc__
def vgg16_bn_for_segmentation(pretrained=False, progress=True, **kwargs):
"""Create an instance of VGG16 with batch norm.
Parameters
----------
pretrained
If True, uses VGG16 pretrained weights.
progress
If True, shows a progress bar when downloading weights.
**kwargs
Keyword arguments to be passed to the parent VGG model.
return_features : :py:class:`list`, Optional
A list of integers indicating the feature layers to be returned from
the original module.
Returns
-------
Instance of VGG16 with batch normalization.
"""
return _make_vgg16_type_d_for_segmentation(
pretrained=pretrained, batch_norm=True, progress=progress, **kwargs
)
vgg16_bn_for_segmentation.__doc__ = torchvision.models.vgg16_bn.__doc__
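if __name__ == "__main__":
    # Sketch of the feature maps the Unet model (further below) requests from
    # this backbone: 64, 128, 256, 512 and 512 channels at decreasing
    # resolutions.
    import torch
    backbone = vgg16_for_segmentation(
        pretrained=False, return_features=[3, 8, 14, 22, 29]
    )
    outputs = backbone(torch.rand(1, 3, 224, 224))
    assert [o.shape[1] for o in outputs[1:]] == [64, 128, 256, 512, 512]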
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import torch
import torch.nn
from torch.nn import Conv2d, ConvTranspose2d
def conv_with_kaiming_uniform(
in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
):
"""Convolution layer with kaiming uniform.
Parameters
----------
in_channels
Number of input channels.
out_channels
Number of output channels.
kernel_size
The kernel size.
stride
The stride.
padding
The padding.
dilation
The dilation.
Returns
-------
The convolution layer.
"""
conv = Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=True,
)
# Caffe2 implementation uses XavierFill, which in fact
# corresponds to kaiming_uniform_ in PyTorch
torch.nn.init.kaiming_uniform_(conv.weight, a=1)
torch.nn.init.constant_(conv.bias, 0)
return conv
def convtrans_with_kaiming_uniform(
in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
):
"""Implement convtrans layer with kaiming uniform.
Parameters
----------
in_channels
Number of input channels.
out_channels
Number of output channels.
kernel_size
The kernel size.
stride
The stride.
padding
The padding.
dilation
The dilation.
Returns
-------
The transposed-convolution layer.
"""
conv = ConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=True,
)
# Caffe2 implementation uses XavierFill, which in fact
# corresponds to kaiming_uniform_ in PyTorch
torch.nn.init.kaiming_uniform_(conv.weight, a=1)
torch.nn.init.constant_(conv.bias, 0)
return conv
class UpsampleCropBlock(torch.nn.Module):
"""Combines Conv2d, ConvTransposed2d and Cropping. Simulates the caffe2
crop layer in the forward function.
Used for DRIU and HED.
Parameters
----------
in_channels
Number of channels of intermediate layer.
out_channels
Number of output channels.
up_kernel_size
Kernel size for transposed convolution.
up_stride
Stride for transposed convolution.
up_padding
Padding for transposed convolution.
pixelshuffle
If True, uses PixelShuffleICNR upsampling.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
up_kernel_size: int,
up_stride: int,
up_padding: int,
pixelshuffle: bool = False,
):
super().__init__()
# NOTE: Kaiming init, replace with torch.nn.Conv2d and torch.nn.ConvTranspose2d to get original DRIU impl.
self.conv = conv_with_kaiming_uniform(in_channels, out_channels, 3, 1, 1)
if pixelshuffle:
self.upconv = PixelShuffleICNR(out_channels, out_channels, scale=up_stride)
else:
self.upconv = convtrans_with_kaiming_uniform(
out_channels,
out_channels,
up_kernel_size,
up_stride,
up_padding,
)
def forward(self, x, input_res):
img_h = input_res[0]
img_w = input_res[1]
x = self.conv(x)
x = self.upconv(x)
# determine center crop
# height
up_h = x.shape[2]
h_crop = up_h - img_h
h_s = h_crop // 2
h_e = up_h - (h_crop - h_s)
# width
up_w = x.shape[3]
w_crop = up_w - img_w
w_s = w_crop // 2
w_e = up_w - (w_crop - w_s)
# perform crop
# needs explicit ranges for onnx export
return x[:, :, h_s:h_e, w_s:w_e] # crop to input size
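# A quick shape sketch for this block (hypothetical sizes): with
# ``up_kernel_size=4, up_stride=2, up_padding=0``, a 50x50 map upsamples to
# (50 - 1) * 2 + 4 = 102 pixels per side, and the forward pass center-crops
# the result back to the requested ``input_res``::
#
#   block = UpsampleCropBlock(64, 16, 4, 2, 0)
#   y = block(torch.rand(1, 64, 50, 50), (97, 97))  # y: (1, 16, 97, 97)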
def ifnone(a, b):
"""Return ``a`` if ``a`` is not None, otherwise ``b``.
Parameters
----------
a
The first parameter.
b
The second parameter.
Returns
-------
The parameter a if it is not None, else b.
"""
return b if a is None else a
def icnr(x, scale=2, init=torch.nn.init.kaiming_normal_):
"""ICNR init of ``x``, with ``scale`` and ``init`` function.
https://docs.fast.ai/layers.html#PixelShuffleICNR.
Parameters
----------
x
Tensor.
scale
Scale of the upsample.
init
Function used to initialize.
"""
ni, nf, h, w = x.shape
ni2 = int(ni / (scale**2))
k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
k = k.contiguous().view(ni2, nf, -1)
k = k.repeat(1, 1, scale**2)
k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
x.data.copy_(k)
class PixelShuffleICNR(torch.nn.Module):
"""Upsample by ``scale`` from ``ni`` filters to ``nf`` (default
``ni``), using ``torch.nn.PixelShuffle``, ``icnr`` init, and
``weight_norm``.
https://docs.fast.ai/layers.html#PixelShuffleICNR.
Parameters
----------
ni
Number of initial filters.
nf
Number of final filters.
scale
Scale of the upsample.
"""
def __init__(self, ni: int, nf: int | None = None, scale: int = 2):
super().__init__()
nf = ifnone(nf, ni)
self.conv = conv_with_kaiming_uniform(ni, nf * (scale**2), 1)
icnr(self.conv.weight)
self.shuf = torch.nn.PixelShuffle(scale)
# Blurring over (h*w) kernel
# "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
# - https://arxiv.org/abs/1806.02658
self.pad = torch.nn.ReplicationPad2d((1, 0, 1, 0))
self.blur = torch.nn.AvgPool2d(2, stride=1)
self.relu = torch.nn.ReLU(inplace=True)
def forward(self, x):
x = self.shuf(self.relu(self.conv(x)))
return self.blur(self.pad(x))
class UnetBlock(torch.nn.Module):
"""Unet block implementation.
Parameters
----------
up_in_c
Number of input channels.
x_in_c
Number of cat channels.
pixel_shuffle
If True, uses a PixelShuffleICNR layer for upsampling.
middle_block
If True, uses a middle block for VGG based U-Net.
"""
def __init__(self, up_in_c, x_in_c, pixel_shuffle=False, middle_block=False):
super().__init__()
# middle block for VGG based U-Net
if middle_block:
up_out_c = up_in_c
else:
up_out_c = up_in_c // 2
cat_channels = x_in_c + up_out_c
inner_channels = cat_channels // 2
if pixel_shuffle:
self.upsample = PixelShuffleICNR(up_in_c, up_out_c)
else:
self.upsample = convtrans_with_kaiming_uniform(up_in_c, up_out_c, 2, 2)
self.convtrans1 = convtrans_with_kaiming_uniform(
cat_channels, inner_channels, 3, 1, 1
)
self.convtrans2 = convtrans_with_kaiming_uniform(
inner_channels, inner_channels, 3, 1, 1
)
self.relu = torch.nn.ReLU(inplace=True)
def forward(self, up_in, x_in):
up_out = self.upsample(up_in)
cat_x = torch.cat([up_out, x_in], dim=1)
x = self.relu(self.convtrans1(cat_x))
return self.relu(self.convtrans2(x))
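if __name__ == "__main__":
    # Shape sketch for UnetBlock (illustrative sizes): the upsampling path
    # doubles the spatial resolution of ``up_in`` so that it can be
    # concatenated with the encoder skip connection ``x_in``.
    block = UnetBlock(up_in_c=512, x_in_c=512, middle_block=True)
    up_in = torch.rand(1, 512, 16, 16)
    x_in = torch.rand(1, 512, 32, 32)
    assert block(up_in, x_in).shape == (1, 512, 32, 32)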
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import typing
import torch.nn
from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model
from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad
from .backbones.vgg import vgg16_for_segmentation
from .losses import SoftJaccardBCELogitsLoss
from .make_layers import UnetBlock, conv_with_kaiming_uniform
logger = logging.getLogger("mednet")
class UNetHead(torch.nn.Module):
"""UNet head module.
Parameters
----------
in_channels_list
Number of channels for each feature map that is returned from backbone.
pixel_shuffle
If True, upsample using PixelShuffleICNR.
"""
def __init__(self, in_channels_list: list[int] | None = None, pixel_shuffle: bool = False):
super().__init__()
# number of channels
c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list
# build layers
self.decode4 = UnetBlock(c_decode5, c_decode4, pixel_shuffle, middle_block=True)
self.decode3 = UnetBlock(c_decode4, c_decode3, pixel_shuffle)
self.decode2 = UnetBlock(c_decode3, c_decode2, pixel_shuffle)
self.decode1 = UnetBlock(c_decode2, c_decode1, pixel_shuffle)
self.final = conv_with_kaiming_uniform(c_decode1, 1, 1)
def forward(self, x: list[torch.Tensor]):
"""Forward pass.
Parameters
----------
x
List of tensors as returned from the backbone network.
First element: height and width of input image.
Remaining elements: feature maps for each feature level.
Returns
-------
Output of the forward pass.
"""
# NOTE: x[0]: height and width of input image not needed in U-Net architecture
decode4 = self.decode4(x[5], x[4])
decode3 = self.decode3(decode4, x[3])
decode2 = self.decode2(decode3, x[2])
decode1 = self.decode1(decode2, x[1])
return self.final(decode1)
class Unet(Model):
"""Implementation of the Unet model.
Parameters
----------
loss_type
The loss to be used for training and evaluation.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
loss_arguments
Arguments to the loss.
optimizer_type
The type of optimizer to use for training.
optimizer_arguments
Arguments to the optimizer after ``params``.
augmentation_transforms
An optional sequence of torch modules containing transforms to be
applied on the input **before** it is fed into the network.
num_classes
Number of outputs (classes) for this model.
pretrained
If True, will use VGG16 pretrained weights.
crop_size
The size of the image after center cropping.
"""
def __init__(
self,
loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
num_classes: int = 1,
pretrained: bool = False,
crop_size: int = 544,
):
super().__init__(
loss_type,
loss_arguments,
optimizer_type,
optimizer_arguments,
augmentation_transforms,
num_classes,
)
self.name = "unet"
resize_transform = ResizeMaxSide(crop_size)
self.model_transforms = [
resize_transform,
SquareCenterPad(),
]
self.pretrained = pretrained
self.backbone = vgg16_for_segmentation(
pretrained=self.pretrained,
return_features=[3, 8, 14, 22, 29],
)
self.head = UNetHead([64, 128, 256, 512, 512], pixel_shuffle=False)
def forward(self, x):
if self.normalizer is not None:
x = self.normalizer(x)
x = self.backbone(x)
return self.head(x)
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initialize the normalizer for the current model.
If ``pretrained = True``, sets the normalizer to the preset ImageNet
statistics from torchvision; otherwise, no input normalization is applied.
Parameters
----------
dataloader
A torch Dataloader from which to compute the mean and std.
Currently unused: pretrained models use preset ImageNet factors, and
non-pretrained models apply no normalization.
"""
if self.pretrained:
from mednet.libs.common.models.normalizer import make_imagenet_normalizer
logger.warning(
f"ImageNet pre-trained {self.name} model - NOT "
f"computing z-norm factors from train dataloader. "
f"Using preset factors from torchvision.",
)
self.normalizer = make_imagenet_normalizer()
else:
self.normalizer = None
def training_step(self, batch, batch_idx):
images = batch[0]
ground_truths = batch[1]["target"]
masks = batch[1]["mask"]
outputs = self(self._augmentation_transforms(images))
return self._train_loss(outputs, ground_truths, masks)
def validation_step(self, batch, batch_idx):
images = batch[0]
ground_truths = batch[1]["target"]
masks = batch[1]["mask"]
outputs = self(images)
return self._validation_loss(outputs, ground_truths, masks)
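if __name__ == "__main__":
    # End-to-end shape sketch for the decoder (random tensors standing in for
    # the VGG16 backbone outputs requested in Unet.__init__ above):
    head = UNetHead([64, 128, 256, 512, 512])
    backbone_outputs = [
        (544, 544),  # input height/width (unused by the U-Net head)
        torch.rand(1, 64, 544, 544),
        torch.rand(1, 128, 272, 272),
        torch.rand(1, 256, 136, 136),
        torch.rand(1, 512, 68, 68),
        torch.rand(1, 512, 34, 34),
    ]
    assert head(backbone_outputs).shape == (1, 1, 544, 544)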