diff --git a/pyproject.toml b/pyproject.toml
index 11c85b858868c9e8f0f2d27e67b607a4b2b7bebb..3a548d8ab87f59c4e1a2e6cec6983b1d222102e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -422,6 +422,7 @@ visceral = "mednet.config.data.visceral.default"
 
 [project.entry-points."mednet.libs.segmentation.config"]
 lwnet = "mednet.libs.segmentation.config.models.lwnet"
+unet = "mednet.libs.segmentation.config.models.unet"
 
 # chase-db1 - retinography
 chasedb1 = "mednet.libs.segmentation.config.data.chasedb1.first_annotator"
@@ -546,6 +547,7 @@ exclude = [ # don't report on objects that match any of these regex
   '\.__len__$',
   '\.__getitem__$',
   '\.__iter__$',
+  '\.__setstate__$',
   '\.__exit__$',
   ]
 
diff --git a/src/mednet/libs/segmentation/config/models/unet.py b/src/mednet/libs/segmentation/config/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9336d2ef1e93549a1e468b460998a7eb49aec90e
--- /dev/null
+++ b/src/mednet/libs/segmentation/config/models/unet.py
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""U-Net for image segmentation.
+
+U-Net is an encoder-decoder convolutional network with skip connections,
+originally proposed for biomedical image segmentation by Ronneberger et al.
+(2015).  This configuration uses a VGG16 encoder and trains with the AdaBound
+optimizer and a soft Jaccard/BCE loss.
+"""
+
+from mednet.libs.segmentation.engine.adabound import AdaBound
+from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
+from mednet.libs.segmentation.models.unet import Unet
+
+lr = 0.01  # start
+alpha = 0.7
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+amsbound = False
+
+model = Unet(
+    loss_type=SoftJaccardBCELogitsLoss,
+    loss_arguments=dict(alpha=alpha),
+    optimizer_type=AdaBound,
+    optimizer_arguments=dict(
+        lr=lr,
+        betas=betas,
+        final_lr=final_lr,
+        gamma=gamma,
+        eps=eps,
+        weight_decay=weight_decay,
+        amsbound=amsbound,
+    ),
+    augmentation_transforms=[],
+    crop_size=1024,
+)
diff --git a/src/mednet/libs/segmentation/engine/adabound.py b/src/mednet/libs/segmentation/engine/adabound.py
new file mode 100644
index 0000000000000000000000000000000000000000..19f6261e1218fb3cca73555403673d38cea7d2a5
--- /dev/null
+++ b/src/mednet/libs/segmentation/engine/adabound.py
@@ -0,0 +1,314 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Implementation of the AdaBound optimizer.
+
+<https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py>::
+
+    @inproceedings{Luo2019AdaBound,
+        author = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu},
+        title = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate},
+        booktitle = {Proceedings of the 7th International Conference on Learning Representations},
+        month = {May},
+        year = {2019},
+        address = {New Orleans, Louisiana}
+    }
+"""
+
+import math
+import typing
+
+import torch
+import torch.optim
+
+
+class AdaBound(torch.optim.Optimizer):
+    """Implement the AdaBound algorithm.
+
+    Parameters
+    ----------
+    params
+        Iterable of parameters to optimize or dicts defining parameter groups.
+    lr
+        Adam learning rate.
+    betas
+        Coefficients (as a 2-tuple of floats) used for computing running
+        averages of gradient and its square.
+    final_lr
+        Final (SGD) learning rate.
+    gamma
+        Convergence speed of the bound functions.
+    eps
+        Term added to the denominator to improve numerical stability.
+ weight_decay + Weight decay (L2 penalty). + amsbound + Whether to use the AMSBound variant of this algorithm. + """ + + def __init__( + self, + params: list, + lr: float | None = 1e-3, + betas: tuple[float, float] | None = (0.9, 0.999), + final_lr: float | None = 0.1, + gamma: float | None = 1e-3, + eps: float | None = 1e-8, + weight_decay: float | None = 0, + amsbound: bool | None = False, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= final_lr: + raise ValueError(f"Invalid final learning rate: {final_lr}") + if not 0.0 <= gamma < 1.0: + raise ValueError(f"Invalid gamma parameter: {gamma}") + defaults = dict( + lr=lr, + betas=betas, + final_lr=final_lr, + gamma=gamma, + eps=eps, + weight_decay=weight_decay, + amsbound=amsbound, + ) + super().__init__(params, defaults) + + self.base_lrs = list(map(lambda group: group["lr"], self.param_groups)) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsbound", False) + + def step(self, closure: typing.Callable | None = None): + """Perform a single optimization step. + + Parameters + ---------- + closure + A closure that reevaluates the model and returns the loss. + + Returns + ------- + The loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group, base_lr in zip(self.param_groups, self.base_lrs): + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + amsbound = group["amsbound"] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + if amsbound: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + if amsbound: + max_exp_avg_sq = state["max_exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + if group["weight_decay"] != 0: + grad = grad.add(group["weight_decay"], p.data) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsbound: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. 
of gradient + denom = max_exp_avg_sq.sqrt().add_(group["eps"]) + else: + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + + # Applies bounds on actual learning rate + # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay + final_lr = group["final_lr"] * group["lr"] / base_lr + lower_bound = final_lr * (1 - 1 / (group["gamma"] * state["step"] + 1)) + upper_bound = final_lr * (1 + 1 / (group["gamma"] * state["step"])) + step_size = torch.full_like(denom, step_size) + step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg) + + p.data.add_(-step_size) + + return loss + + +class AdaBoundW(torch.optim.Optimizer): + """Implement AdaBound algorithm with Decoupled Weight Decay (See + https://arxiv.org/abs/1711.05101). + + Parameters + ---------- + params + Iterable of parameters to optimize or dicts defining parameter groups. + lr + Adam learning rate. + betas + Coefficients (as a 2-tuple of floats) used for computing running + averages of gradient and its square. + final_lr + Final (SGD) learning rate. + gamma + Convergence speed of the bound functions. + eps + Term added to the denominator to improve numerical stability. + weight_decay + Weight decay (L2 penalty). + amsbound + Whether to use the AMSBound variant of this algorithm. + """ + + def __init__( + self, + params: list, + lr: float | None = 1e-3, + betas: tuple[float, float] | None = (0.9, 0.999), + final_lr: float | None = 0.1, + gamma: float | None = 1e-3, + eps: float | None = 1e-8, + weight_decay: float | None = 0, + amsbound: bool | None = False, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= final_lr: + raise ValueError(f"Invalid final learning rate: {final_lr}") + if not 0.0 <= gamma < 1.0: + raise ValueError(f"Invalid gamma parameter: {gamma}") + defaults = dict( + lr=lr, + betas=betas, + final_lr=final_lr, + gamma=gamma, + eps=eps, + weight_decay=weight_decay, + amsbound=amsbound, + ) + super().__init__(params, defaults) + + self.base_lrs = list(map(lambda group: group["lr"], self.param_groups)) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsbound", False) + + def step(self, closure: typing.Callable | None = None): + """Perform a single optimization step. + + Parameters + ---------- + closure + A closure that reevaluates the model and returns the loss. + + Returns + ------- + The loss. 
+ """ + + loss = None + if closure is not None: + loss = closure() + + for group, base_lr in zip(self.param_groups, self.base_lrs): + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + amsbound = group["amsbound"] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + if amsbound: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + if amsbound: + max_exp_avg_sq = state["max_exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsbound: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = max_exp_avg_sq.sqrt().add_(group["eps"]) + else: + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + + # Applies bounds on actual learning rate + # lr_scheduler cannot affect final_lr, this is a workaround to + # apply lr decay + final_lr = group["final_lr"] * group["lr"] / base_lr + lower_bound = final_lr * (1 - 1 / (group["gamma"] * state["step"] + 1)) + upper_bound = final_lr * (1 + 1 / (group["gamma"] * state["step"])) + step_size = torch.full_like(denom, step_size) + step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg) + + if group["weight_decay"] != 0: + decayed_weights = torch.mul(p.data, group["weight_decay"]) + p.data.add_(-step_size) + p.data.sub_(decayed_weights) + else: + p.data.add_(-step_size) + + return loss diff --git a/src/mednet/libs/segmentation/models/backbones/__init__.py b/src/mednet/libs/segmentation/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/mednet/libs/segmentation/models/backbones/mobilenetv2.py b/src/mednet/libs/segmentation/models/backbones/mobilenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..0340b3b895603406a3e3d3c8965ee756a5a0edf5 --- /dev/null +++ b/src/mednet/libs/segmentation/models/backbones/mobilenetv2.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import torchvision.models +import torchvision.models.mobilenetv2 + +try: + # pytorch >= 1.12 + from torch.hub import load_state_dict_from_url +except ImportError: + # pytorch < 1.12 + from torchvision.models.utils import load_state_dict_from_url + + +class MobileNetV24Segmentation(torchvision.models.mobilenetv2.MobileNetV2): + """Adaptation of base MobileNetV2 functionality to U-Net style + segmentation. + + This version of MobileNetV2 is slightly modified so it can be used through + torchvision's API. 
It outputs intermediate features which are normally not + output by the base MobileNetV2 implementation, but are required for + segmentation operations. + + Parameters + ---------- + *args + Arguments to be passed to the parent MobileNetV2 model. + **kwargs + Keyword arguments to be passed to the parent MobileNetV2 model. + return_features : :py:class:`list`, Optional + A list of integers indicating the feature layers to be returned from + the original module. + """ + + def __init__(self, *args, **kwargs): + self._return_features = kwargs.pop("return_features") + super().__init__(*args, **kwargs) + + def forward(self, x): + outputs = [] + # hw of input, needed for DRIU and HED + outputs.append(x.shape[2:4]) + outputs.append(x) + for index, m in enumerate(self.features): + x = m(x) + # extract layers + if index in self._return_features: + outputs.append(x) + return outputs + + +def mobilenet_v2_for_segmentation(pretrained=False, progress=True, **kwargs): + """Create MobileNetV2 model for segmentation task. + + Parameters + ---------- + pretrained + If True, uses MobileNetV2 pretrained weights. + progress + If True, shows a progress bar when downloading the pretrained weights. + **kwargs + Keyword arguments to be passed to the parent MobileNetV2 model. + return_features : :py:class:`list`, Optional + A list of integers indicating the feature layers to be returned from + the original module. + + Returns + ------- + Instance of the MobileNetV2 model for segmentation. + """ + + model = MobileNetV24Segmentation(**kwargs) + + if pretrained: + state_dict = load_state_dict_from_url( + torchvision.models.mobilenetv2.MobileNet_V2_Weights.DEFAULT.url, + progress=progress, + ) + model.load_state_dict(state_dict) + + # erase MobileNetV2 head (for classification), not used for segmentation + delattr(model, "classifier") + + return_features = kwargs.get("return_features") + if return_features is not None: + model.features = model.features[: (max(return_features) + 1)] + + return model + + +mobilenet_v2_for_segmentation.__doc__ = torchvision.models.mobilenetv2.__doc__ diff --git a/src/mednet/libs/segmentation/models/backbones/resnet.py b/src/mednet/libs/segmentation/models/backbones/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd2b3f35a7f9e109fad85ebdac317a30abd3791 --- /dev/null +++ b/src/mednet/libs/segmentation/models/backbones/resnet.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import torchvision.models + +try: + # pytorch >= 1.12 + from torch.hub import load_state_dict_from_url +except ImportError: + # pytorch < 1.12 + from torchvision.models.utils import load_state_dict_from_url + + +class ResNet4Segmentation(torchvision.models.resnet.ResNet): + """Adaptation of base ResNet functionality to U-Net style segmentation. + + This version of ResNet is slightly modified so it can be used through + torchvision's API. It outputs intermediate features which are normally not + output by the base ResNet implementation, but are required for segmentation + operations. + + Parameters + ---------- + *args + Arguments to be passed to the parent ResNet model. + **kwargs + Keyword arguments to be passed to the parent ResNet model. + return_features : :py:class:`list`, Optional + A list of integers indicating the feature layers to be returned from + the original module. 
+ """ + + def __init__(self, *args, **kwargs): + self._return_features = kwargs.pop("return_features") + super().__init__(*args, **kwargs) + + def forward(self, x): + outputs = [] + # hardwiring of input + outputs.append(x.shape[2:4]) + for index, m in enumerate(self.features): + x = m(x) + # extract layers + if index in self.return_features: + outputs.append(x) + return outputs + + +def resnet50_for_segmentation(pretrained=False, progress=True, **kwargs): + """Create ResNet for segmentation task. + + Parameters + ---------- + pretrained + If True, uses ResNet50 pretrained weights. + progress + If True, shows a progress bar when downloading the pretrained weights. + **kwargs + Keyword arguments to be passed to the parent ResNet model. + return_features : :py:class:`list`, Optional + A list of integers indicating the feature layers to be returned from + the original module. + + Returns + ------- + Instance of the ResNet model for segmentation. + """ + model = ResNet4Segmentation( + torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], **kwargs + ) + + if pretrained: + state_dict = load_state_dict_from_url( + torchvision.models.resnet.ResNet50_Weights.DEFAULT.url, + progress=progress, + ) + model.load_state_dict(state_dict) + + # erase ResNet head (for classification), not used for segmentation + delattr(model, "avgpool") + delattr(model, "fc") + + return model + + +resnet50_for_segmentation.__doc__ = torchvision.models.resnet50.__doc__ diff --git a/src/mednet/libs/segmentation/models/backbones/vgg.py b/src/mednet/libs/segmentation/models/backbones/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..d8ed74ae0589daca96019daee33801021d7c30e2 --- /dev/null +++ b/src/mednet/libs/segmentation/models/backbones/vgg.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import torchvision.models + +try: + # pytorch >= 1.12 + from torch.hub import load_state_dict_from_url +except ImportError: + # pytorch < 1.12 + from torchvision.models.utils import load_state_dict_from_url + + +class VGG4Segmentation(torchvision.models.vgg.VGG): + """Adaptation of base VGG functionality to U-Net style segmentation. + + This version of VGG is slightly modified so it can be used through + torchvision's API. It outputs intermediate features which are normally not + output by the base VGG implementation, but are required for segmentation + operations. + + Parameters + ---------- + *args + Arguments to be passed to the parent VGG model. + **kwargs + Keyword arguments to be passed to the parent VGG model. + return_features : :py:class:`list`, Optional + A list of integers indicating the feature layers to be returned from + the original module. 
+ """ + + def __init__(self, *args, **kwargs): + self._return_features = kwargs.pop("return_features") + super().__init__(*args, **kwargs) + + def forward(self, x): + outputs = [] + # hardwiring of input + outputs.append(x.shape[2:4]) + for index, m in enumerate(self.features): + x = m(x) + # extract layers + if index in self._return_features: + outputs.append(x) + return outputs + + +def _make_vgg16_type_d_for_segmentation(pretrained, batch_norm, progress, **kwargs): + if pretrained: + kwargs["init_weights"] = False + + model = VGG4Segmentation( + torchvision.models.vgg.make_layers( + torchvision.models.vgg.cfgs["D"], + batch_norm=batch_norm, + ), + **kwargs, + ) + + if pretrained: + weights = ( + torchvision.models.vgg.VGG16_Weights.DEFAULT.url + if not batch_norm + else torchvision.models.vgg.VGG16_BN_Weights.DEFAULT.url + ) + + state_dict = load_state_dict_from_url(weights, progress=progress) + model.load_state_dict(state_dict) + + # erase VGG head (for classification), not used for segmentation + delattr(model, "classifier") + delattr(model, "avgpool") + + return model + + +def vgg16_for_segmentation(pretrained=False, progress=True, **kwargs): + """Create an instance of VGG16. + + Parameters + ---------- + pretrained + If True, usees VGG16 pretrained weights. + progress + If True, shows a progress bar when downloading weights. + **kwargs + Keyword arguments to be passed to the parent VGG model. + return_features : :py:class:`list`, Optional + A list of integers indicating the feature layers to be returned from + the original module. + + Returns + ------- + Instance of VGG16. + """ + return _make_vgg16_type_d_for_segmentation( + pretrained=pretrained, batch_norm=False, progress=progress, **kwargs + ) + + +vgg16_for_segmentation.__doc__ = torchvision.models.vgg16.__doc__ + + +def vgg16_bn_for_segmentation(pretrained=False, progress=True, **kwargs): + """Create an instance of VGG16 with batch norm. + + Parameters + ---------- + pretrained + If True, usees VGG16 pretrained weights. + progress + If True, shows a progress bar when downloading weights. + **kwargs + Keyword arguments to be passed to the parent VGG model. + return_features : :py:class:`list`, Optional + A list of integers indicating the feature layers to be returned from + the original module. + + Returns + ------- + Instance of VGG16. + """ + return _make_vgg16_type_d_for_segmentation( + pretrained=pretrained, batch_norm=True, progress=progress, **kwargs + ) + + +vgg16_bn_for_segmentation.__doc__ = torchvision.models.vgg16_bn.__doc__ diff --git a/src/mednet/libs/segmentation/models/make_layers.py b/src/mednet/libs/segmentation/models/make_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3b3228b87740da102051f4b8e70973f3db5ba7b2 --- /dev/null +++ b/src/mednet/libs/segmentation/models/make_layers.py @@ -0,0 +1,273 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import torch +import torch.nn +from torch.nn import Conv2d, ConvTranspose2d + + +def conv_with_kaiming_uniform( + in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1 +): + """Convolution layer with kaiming uniform. + + Parameters + ---------- + in_channels + Number of input channels. + out_channels + Number of output channels. + kernel_size + The kernel size. + stride + The stride. + padding + The padding. + dilation + The dilation. + + Returns + ------- + The convoluation layer. 
+ """ + conv = Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True, + ) + # Caffe2 implementation uses XavierFill, which in fact + # corresponds to kaiming_uniform_ in PyTorch + torch.nn.init.kaiming_uniform_(conv.weight, a=1) + torch.nn.init.constant_(conv.bias, 0) + return conv + + +def convtrans_with_kaiming_uniform( + in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1 +): + """Implement convtrans layer with kaiming uniform. + + Parameters + ---------- + in_channels + Number of input channels. + out_channels + Number of output channels. + kernel_size + The kernel size. + stride + The stride. + padding + The padding. + dilation + The dilation. + + Returns + ------- + The convtrans layer. + """ + conv = ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True, + ) + # Caffe2 implementation uses XavierFill, which in fact + # corresponds to kaiming_uniform_ in PyTorch + torch.nn.init.kaiming_uniform_(conv.weight, a=1) + torch.nn.init.constant_(conv.bias, 0) + return conv + + +class UpsampleCropBlock(torch.nn.Module): + """Combines Conv2d, ConvTransposed2d and Cropping. Simulates the caffe2 + crop layer in the forward function. + + Used for DRIU and HED. + + Parameters + ---------- + in_channels + Number of channels of intermediate layer. + out_channels + Number of output channels. + up_kernel_size + Kernel size for transposed convolution. + up_stride + Stride for transposed convolution. + up_padding + Padding for transposed convolution. + pixelshuffle + If True, uses PixelShuffleICNR upsampling. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + up_kernel_size: int, + up_stride: int, + up_padding: int, + pixelshuffle: bool = False, + ): + super().__init__() + # NOTE: Kaiming init, replace with torch.nn.Conv2d and torch.nn.ConvTranspose2d to get original DRIU impl. + self.conv = conv_with_kaiming_uniform(in_channels, out_channels, 3, 1, 1) + if pixelshuffle: + self.upconv = PixelShuffleICNR(out_channels, out_channels, scale=up_stride) + else: + self.upconv = convtrans_with_kaiming_uniform( + out_channels, + out_channels, + up_kernel_size, + up_stride, + up_padding, + ) + + def forward(self, x, input_res): + img_h = input_res[0] + img_w = input_res[1] + x = self.conv(x) + x = self.upconv(x) + # determine center crop + # height + up_h = x.shape[2] + h_crop = up_h - img_h + h_s = h_crop // 2 + h_e = up_h - (h_crop - h_s) + # width + up_w = x.shape[3] + w_crop = up_w - img_w + w_s = w_crop // 2 + w_e = up_w - (w_crop - w_s) + # perform crop + # needs explicit ranges for onnx export + return x[:, :, h_s:h_e, w_s:w_e] # crop to input size + + +def ifnone(a, b): + """Return ``a`` if ``a`` is not None, otherwise ``b``. + + Parameters + ---------- + a + The first parameter. + b + The second parameter. + + Returns + ------- + The parameter a if it is not None, else b. + """ + return b if a is None else a + + +def icnr(x, scale=2, init=torch.nn.init.kaiming_normal_): + """ICNR init of ``x``, with ``scale`` and ``init`` function. + + https://docs.fast.ai/layers.html#PixelShuffleICNR. + + Parameters + ---------- + x + Tensor. + scale + Scale of the upsample. + init + Function used to initialize. 
+ """ + + ni, nf, h, w = x.shape + ni2 = int(ni / (scale**2)) + k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1) + k = k.contiguous().view(ni2, nf, -1) + k = k.repeat(1, 1, scale**2) + k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1) + x.data.copy_(k) + + +class PixelShuffleICNR(torch.nn.Module): + """Upsample by ``scale`` from ``ni`` filters to ``nf`` (default + ``ni``), using ``torch.nn.PixelShuffle``, ``icnr`` init, and + ``weight_norm``. + + https://docs.fast.ai/layers.html#PixelShuffleICNR. + + Parameters + ---------- + ni + Number of initial filters. + nf + Number of final filters. + scale + Scale of the upsample. + """ + + def __init__(self, ni: int, nf: int = None, scale: int = 2): + super().__init__() + nf = ifnone(nf, ni) + self.conv = conv_with_kaiming_uniform(ni, nf * (scale**2), 1) + icnr(self.conv.weight) + self.shuf = torch.nn.PixelShuffle(scale) + # Blurring over (h*w) kernel + # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts" + # - https://arxiv.org/abs/1806.02658 + self.pad = torch.nn.ReplicationPad2d((1, 0, 1, 0)) + self.blur = torch.nn.AvgPool2d(2, stride=1) + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x): + x = self.shuf(self.relu(self.conv(x))) + return self.blur(self.pad(x)) + + +class UnetBlock(torch.nn.Module): + """Unet block implementation. + + Parameters + ---------- + up_in_c + Number of input channels. + x_in_c + Number of cat channels. + pixel_shuffle + If True, uses a PixelShuffleICNR layer for upsampling. + middle_block + If True, uses a middle block for VGG based U-Net. + """ + + def __init__(self, up_in_c, x_in_c, pixel_shuffle=False, middle_block=False): + super().__init__() + # middle block for VGG based U-Net + if middle_block: + up_out_c = up_in_c + else: + up_out_c = up_in_c // 2 + cat_channels = x_in_c + up_out_c + inner_channels = cat_channels // 2 + + if pixel_shuffle: + self.upsample = PixelShuffleICNR(up_in_c, up_out_c) + else: + self.upsample = convtrans_with_kaiming_uniform(up_in_c, up_out_c, 2, 2) + self.convtrans1 = convtrans_with_kaiming_uniform( + cat_channels, inner_channels, 3, 1, 1 + ) + self.convtrans2 = convtrans_with_kaiming_uniform( + inner_channels, inner_channels, 3, 1, 1 + ) + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, up_in, x_in): + up_out = self.upsample(up_in) + cat_x = torch.cat([up_out, x_in], dim=1) + x = self.relu(self.convtrans1(cat_x)) + return self.relu(self.convtrans2(x)) diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..23ea8878fb58f5299263d8b60dc1a78f1cdae2ac --- /dev/null +++ b/src/mednet/libs/segmentation/models/unet.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +import logging +import typing + +import torch.nn +from mednet.libs.common.data.typing import TransformSequence +from mednet.libs.common.models.model import Model +from mednet.libs.common.models.transforms import ResizeMaxSide, SquareCenterPad + +from .backbones.vgg import vgg16_for_segmentation +from .losses import SoftJaccardBCELogitsLoss +from .make_layers import UnetBlock, conv_with_kaiming_uniform + +logger = logging.getLogger("mednet") + + +class UNetHead(torch.nn.Module): + """UNet head module. + + Parameters + ---------- + in_channels_list + Number of channels for each feature map that is returned from backbone. 
+ pixel_shuffle + If True, upsample using PixelShuffleICNR. + """ + + def __init__(self, in_channels_list: list[int] = None, pixel_shuffle=False): + super().__init__() + # number of channels + c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list + + # build layers + self.decode4 = UnetBlock(c_decode5, c_decode4, pixel_shuffle, middle_block=True) + self.decode3 = UnetBlock(c_decode4, c_decode3, pixel_shuffle) + self.decode2 = UnetBlock(c_decode3, c_decode2, pixel_shuffle) + self.decode1 = UnetBlock(c_decode2, c_decode1, pixel_shuffle) + self.final = conv_with_kaiming_uniform(c_decode1, 1, 1) + + def forward(self, x: list[torch.Tensor]): + """Forward pass. + + Parameters + ---------- + x + List of tensors as returned from the backbone network. + First element: height and width of input image. + Remaining elements: feature maps for each feature level. + + Returns + ------- + OUtput of the forward pass. + """ + # NOTE: x[0]: height and width of input image not needed in U-Net architecture + decode4 = self.decode4(x[5], x[4]) + decode3 = self.decode3(decode4, x[3]) + decode2 = self.decode2(decode3, x[2]) + decode1 = self.decode1(decode2, x[1]) + return self.final(decode1) + + +class Unet(Model): + """Implementation of the Unet model. + + Parameters + ---------- + loss_type + The loss to be used for training and evaluation. + + .. warning:: + + The loss should be set to always return batch averages (as opposed + to the batch sum), as our logging system expects it so. + loss_arguments + Arguments to the loss. + optimizer_type + The type of optimizer to use for training. + optimizer_arguments + Arguments to the optimizer after ``params``. + augmentation_transforms + An optional sequence of torch modules containing transforms to be + applied on the input **before** it is fed into the network. + num_classes + Number of outputs (classes) for this model. + pretrained + If True, will use VGG16 pretrained weights. + crop_size + The size of the image after center cropping. + """ + + def __init__( + self, + loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss, + loss_arguments: dict[str, typing.Any] = {}, + optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, + optimizer_arguments: dict[str, typing.Any] = {}, + augmentation_transforms: TransformSequence = [], + num_classes: int = 1, + pretrained: bool = False, + crop_size: int = 544, + ): + super().__init__( + loss_type, + loss_arguments, + optimizer_type, + optimizer_arguments, + augmentation_transforms, + num_classes, + ) + + self.name = "unet" + + resize_transform = ResizeMaxSide(crop_size) + + self.model_transforms = [ + resize_transform, + SquareCenterPad(), + ] + + self.pretrained = pretrained + + self.backbone = vgg16_for_segmentation( + pretrained=self.pretrained, + return_features=[3, 8, 14, 22, 29], + ) + + self.head = UNetHead([64, 128, 256, 512, 512], pixel_shuffle=False) + + def forward(self, x): + if self.normalizer is not None: + x = self.normalizer(x) + x = self.backbone(x) + return self.head(x) + + def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: + """Initialize the normalizer for the current model. + + This function is NOOP if ``pretrained = True`` (normalizer set to + imagenet weights, during contruction). + + Parameters + ---------- + dataloader + A torch Dataloader from which to compute the mean and std. + Will not be used if the model is pretrained. 
+ """ + if self.pretrained: + from mednet.libs.common.models.normalizer import make_imagenet_normalizer + + logger.warning( + f"ImageNet pre-trained {self.name} model - NOT " + f"computing z-norm factors from train dataloader. " + f"Using preset factors from torchvision.", + ) + self.normalizer = make_imagenet_normalizer() + else: + self.normalizer = None + + def training_step(self, batch, batch_idx): + images = batch[0] + ground_truths = batch[1]["target"] + masks = batch[1]["mask"] + + outputs = self(self._augmentation_transforms(images)) + return self._train_loss(outputs, ground_truths, masks) + + def validation_step(self, batch, batch_idx): + images = batch[0] + ground_truths = batch[1]["target"] + masks = batch[1]["mask"] + + outputs = self(images) + return self._validation_loss(outputs, ground_truths, masks)