Skip to content
Snippets Groups Projects
Commit 4db9f0a7 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmentation.models] Remove resnet model

ResNet was ununsed and non functional in deepdraw, there is no need to
port it over.
parent 73dfc4e5
No related branches found
No related tags found
1 merge request!46Create common library
......@@ -429,7 +429,6 @@ driu-pix = "mednet.libs.segmentation.config.models.driu_pix"
hed = "mednet.libs.segmentation.config.models.hed"
lwnet = "mednet.libs.segmentation.config.models.lwnet"
m2unet = "mednet.libs.segmentation.config.models.m2unet"
#resunet = "mednet.libs.segmentation.config.models.resunet"
unet = "mednet.libs.segmentation.config.models.unet"
# chase-db1 - retinography
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import torchvision.models
try:
# pytorch >= 1.12
from torch.hub import load_state_dict_from_url
except ImportError:
# pytorch < 1.12
from torchvision.models.utils import load_state_dict_from_url
class ResNet4Segmentation(torchvision.models.resnet.ResNet):
"""Adaptation of base ResNet functionality to U-Net style segmentation.
This version of ResNet is slightly modified so it can be used through
torchvision's API. It outputs intermediate features which are normally not
output by the base ResNet implementation, but are required for segmentation
operations.
Parameters
----------
*args
Arguments to be passed to the parent ResNet model.
**kwargs
Keyword arguments to be passed to the parent ResNet model.
return_features : :py:class:`list`, Optional
A list of integers indicating the feature layers to be returned from
the original module.
"""
def __init__(self, *args, **kwargs):
self._return_features = kwargs.pop("return_features")
super().__init__(*args, **kwargs)
def forward(self, x):
outputs = []
# hardwiring of input
outputs.append(x.shape[2:4])
for index, m in enumerate(self.features):
x = m(x)
# extract layers
if index in self.return_features:
outputs.append(x)
return outputs
def resnet50_for_segmentation(pretrained=False, progress=True, **kwargs):
"""Create ResNet for segmentation task.
Parameters
----------
pretrained
If True, uses ResNet50 pretrained weights.
progress
If True, shows a progress bar when downloading the pretrained weights.
**kwargs
Keyword arguments to be passed to the parent ResNet model.
return_features : :py:class:`list`, Optional
A list of integers indicating the feature layers to be returned from
the original module.
Returns
-------
Instance of the ResNet model for segmentation.
"""
model = ResNet4Segmentation(
torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], **kwargs
)
if pretrained:
state_dict = load_state_dict_from_url(
torchvision.models.resnet.ResNet50_Weights.DEFAULT.url,
progress=progress,
)
model.load_state_dict(state_dict)
# erase ResNet head (for classification), not used for segmentation
delattr(model, "avgpool")
delattr(model, "fc")
return model
resnet50_for_segmentation.__doc__ = torchvision.models.resnet50.__doc__
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment