diff --git a/pyproject.toml b/pyproject.toml index 432ecb6bb08a75f9a8178eabdfcc3bd8ee521248..a5680a15e74d6c39f38d0429477ebddf75c96657 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -429,7 +429,6 @@ driu-pix = "mednet.libs.segmentation.config.models.driu_pix" hed = "mednet.libs.segmentation.config.models.hed" lwnet = "mednet.libs.segmentation.config.models.lwnet" m2unet = "mednet.libs.segmentation.config.models.m2unet" -#resunet = "mednet.libs.segmentation.config.models.resunet" unet = "mednet.libs.segmentation.config.models.unet" # chase-db1 - retinography diff --git a/src/mednet/libs/segmentation/models/backbones/resnet.py b/src/mednet/libs/segmentation/models/backbones/resnet.py deleted file mode 100644 index 3bd2b3f35a7f9e109fad85ebdac317a30abd3791..0000000000000000000000000000000000000000 --- a/src/mednet/libs/segmentation/models/backbones/resnet.py +++ /dev/null @@ -1,87 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import torchvision.models - -try: - # pytorch >= 1.12 - from torch.hub import load_state_dict_from_url -except ImportError: - # pytorch < 1.12 - from torchvision.models.utils import load_state_dict_from_url - - -class ResNet4Segmentation(torchvision.models.resnet.ResNet): - """Adaptation of base ResNet functionality to U-Net style segmentation. - - This version of ResNet is slightly modified so it can be used through - torchvision's API. It outputs intermediate features which are normally not - output by the base ResNet implementation, but are required for segmentation - operations. - - Parameters - ---------- - *args - Arguments to be passed to the parent ResNet model. - **kwargs - Keyword arguments to be passed to the parent ResNet model. - return_features : :py:class:`list`, Optional - A list of integers indicating the feature layers to be returned from - the original module. - """ - - def __init__(self, *args, **kwargs): - self._return_features = kwargs.pop("return_features") - super().__init__(*args, **kwargs) - - def forward(self, x): - outputs = [] - # hardwiring of input - outputs.append(x.shape[2:4]) - for index, m in enumerate(self.features): - x = m(x) - # extract layers - if index in self.return_features: - outputs.append(x) - return outputs - - -def resnet50_for_segmentation(pretrained=False, progress=True, **kwargs): - """Create ResNet for segmentation task. - - Parameters - ---------- - pretrained - If True, uses ResNet50 pretrained weights. - progress - If True, shows a progress bar when downloading the pretrained weights. - **kwargs - Keyword arguments to be passed to the parent ResNet model. - return_features : :py:class:`list`, Optional - A list of integers indicating the feature layers to be returned from - the original module. - - Returns - ------- - Instance of the ResNet model for segmentation. - """ - model = ResNet4Segmentation( - torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], **kwargs - ) - - if pretrained: - state_dict = load_state_dict_from_url( - torchvision.models.resnet.ResNet50_Weights.DEFAULT.url, - progress=progress, - ) - model.load_state_dict(state_dict) - - # erase ResNet head (for classification), not used for segmentation - delattr(model, "avgpool") - delattr(model, "fc") - - return model - - -resnet50_for_segmentation.__doc__ = torchvision.models.resnet50.__doc__