Skip to content
Snippets Groups Projects
Commit ba2b492f authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Merge branch 'model-cleanup' into 'master'

Cleanup model implementation

See merge request bob/bob.ip.binseg!13
parents 32831d9e 1e54e176
No related branches found
No related tags found
1 merge request!13Cleanup model implementation
Pipeline #39947 passed
Showing
with 253 additions and 706 deletions
......@@ -11,9 +11,8 @@ Reference: [MANINIS-2016]_
"""
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.driu import build_driu
from bob.ip.binseg.utils.model_zoo import modelurls
from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
from bob.ip.binseg.models.driu import driu
from bob.ip.binseg.models.losses import SoftJaccardBCELogitsLoss
from bob.ip.binseg.engine.adabound import AdaBound
##### Config #####
......@@ -29,9 +28,7 @@ amsbound = False
scheduler_milestones = [900]
scheduler_gamma = 0.1
model = build_driu()
pretrained_backbone = modelurls["vgg16"]
model = driu()
optimizer = AdaBound(
model.parameters(),
......
......@@ -12,9 +12,8 @@ Reference: [MANINIS-2016]_
"""
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.driubn import build_driu
from bob.ip.binseg.utils.model_zoo import modelurls
from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
from bob.ip.binseg.models.driu_bn import driu_bn
from bob.ip.binseg.models.losses import SoftJaccardBCELogitsLoss
from bob.ip.binseg.engine.adabound import AdaBound
##### Config #####
......@@ -30,11 +29,7 @@ amsbound = False
scheduler_milestones = [900]
scheduler_gamma = 0.1
# model
model = build_driu()
# pretrained backbone
pretrained_backbone = modelurls["vgg16_bn"]
model = driu_bn()
# optimizer
optimizer = AdaBound(
......
......@@ -13,9 +13,8 @@ Reference: [MANINIS-2016]_
"""
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.driubn import build_driu
from bob.ip.binseg.utils.model_zoo import modelurls
from bob.ip.binseg.modeling.losses import MixJacLoss
from bob.ip.binseg.models.driu_bn import driu_bn
from bob.ip.binseg.models.losses import MixJacLoss
from bob.ip.binseg.engine.adabound import AdaBound
##### Config #####
......@@ -31,11 +30,7 @@ amsbound = False
scheduler_milestones = [900]
scheduler_gamma = 0.1
# model
model = build_driu()
# pretrained backbone
pretrained_backbone = modelurls["vgg16_bn"]
model = driu_bn()
# optimizer
optimizer = AdaBound(
......
......@@ -11,9 +11,8 @@ Reference: [MANINIS-2016]_
"""
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.driuod import build_driuod
from bob.ip.binseg.utils.model_zoo import modelurls
from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
from bob.ip.binseg.models.driu_od import driu_od
from bob.ip.binseg.models.losses import SoftJaccardBCELogitsLoss
from bob.ip.binseg.engine.adabound import AdaBound
##### Config #####
......@@ -29,11 +28,7 @@ amsbound = False
scheduler_milestones = [900]
scheduler_gamma = 0.1
# model
model = build_driuod()
# pretrained backbone
pretrained_backbone = modelurls["vgg16"]
model = driu_od()
# optimizer
optimizer = AdaBound(
......
......@@ -12,9 +12,8 @@ Reference: [MANINIS-2016]_
"""
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.driu import build_driu
from bob.ip.binseg.utils.model_zoo import modelurls
from bob.ip.binseg.modeling.losses import MixJacLoss
from bob.ip.binseg.models.driu import driu
from bob.ip.binseg.models.losses import MixJacLoss
from bob.ip.binseg.engine.adabound import AdaBound
##### Config #####
......@@ -30,11 +29,7 @@ amsbound = False
scheduler_milestones = [900]
scheduler_gamma = 0.1
# model
model = build_driu()
# pretrained backbone
pretrained_backbone = modelurls["vgg16"]
model = driu()
# optimizer
optimizer = AdaBound(
......
......@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
"""HED Network for Vessel Segmentation
"""HED Network for image segmentation
Holistically-nested edge detection (HED), turns pixel-wise edge classification
into image-to-image prediction by means of a deep learning model that leverages
......@@ -13,9 +13,8 @@ Reference: [XIE-2015]_
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.hed import build_hed
from bob.ip.binseg.modeling.losses import HEDSoftJaccardBCELogitsLoss
from bob.ip.binseg.utils.model_zoo import modelurls
from bob.ip.binseg.models.hed import hed
from bob.ip.binseg.models.losses import HEDSoftJaccardBCELogitsLoss
from bob.ip.binseg.engine.adabound import AdaBound
......@@ -32,12 +31,7 @@ amsbound = False
scheduler_milestones = [900]
scheduler_gamma = 0.1
# model
model = build_hed()
# pretrained backbone
pretrained_backbone = modelurls["vgg16"]
model = hed()
# optimizer
optimizer = AdaBound(
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""MobileNetV2 U-Net Model for Vessel Segmentation
"""MobileNetV2 U-Net model for image segmentation
The MobileNetV2 architecture is based on an inverted residual structure where
the input and output of the residual block are thin bottleneck layers opposite
......@@ -15,9 +15,8 @@ References: [SANDLER-2018]_, [RONNEBERGER-2015]_
"""
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.m2u import build_m2unet
from bob.ip.binseg.utils.model_zoo import modelurls
from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
from bob.ip.binseg.models.m2unet import m2unet
from bob.ip.binseg.models.losses import SoftJaccardBCELogitsLoss
from bob.ip.binseg.engine.adabound import AdaBound
##### Config #####
......@@ -33,11 +32,7 @@ amsbound = False
scheduler_milestones = [900]
scheduler_gamma = 0.1
# model
model = build_m2unet()
# pretrained backbone
pretrained_backbone = modelurls["mobilenetv2"]
model = m2unet()
# optimizer
optimizer = AdaBound(
......
......@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
"""MobileNetV2 U-Net Model for Vessel Segmentation using SSL
"""MobileNetV2 U-Net model for image segmentation using SSL
The MobileNetV2 architecture is based on an inverted residual structure where
the input and output of the residual block are thin bottleneck layers opposite
......@@ -18,9 +18,8 @@ References: [SANDLER-2018]_, [RONNEBERGER-2015]_
"""
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.m2u import build_m2unet
from bob.ip.binseg.utils.model_zoo import modelurls
from bob.ip.binseg.modeling.losses import MixJacLoss
from bob.ip.binseg.models.m2unet import m2unet
from bob.ip.binseg.models.losses import MixJacLoss
from bob.ip.binseg.engine.adabound import AdaBound
##### Config #####
......@@ -36,11 +35,7 @@ amsbound = False
scheduler_milestones = [900]
scheduler_gamma = 0.1
# model
model = build_m2unet()
# pretrained backbone
pretrained_backbone = modelurls["mobilenetv2"]
model = m2unet()
# optimizer
optimizer = AdaBound(
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Residual U-Net for Vessel Segmentation
"""Residual U-Net for image segmentation
A semantic segmentation neural network which combines the strengths of residual
learning and U-Net is proposed for road area extraction. The network is built
......@@ -15,9 +15,8 @@ Reference: [ZHANG-2017]_
"""
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.resunet import build_res50unet
from bob.ip.binseg.utils.model_zoo import modelurls
from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
from bob.ip.binseg.models.resunet import resunet50
from bob.ip.binseg.models.losses import SoftJaccardBCELogitsLoss
from bob.ip.binseg.engine.adabound import AdaBound
##### Config #####
......@@ -33,11 +32,7 @@ amsbound = False
scheduler_milestones = [900]
scheduler_gamma = 0.1
# model
model = build_res50unet()
# pretrained backbone
pretrained_backbone = modelurls["resnet50"]
model = resunet50()
# optimizer
optimizer = AdaBound(
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""U-Net for Vessel Segmentation
"""U-Net for image segmentation
U-Net is a convolutional neural network that was developed for biomedical image
segmentation at the Computer Science Department of the University of Freiburg,
......@@ -13,9 +13,8 @@ Reference: [RONNEBERGER-2015]_
"""
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.unet import build_unet
from bob.ip.binseg.utils.model_zoo import modelurls
from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
from bob.ip.binseg.models.unet import unet
from bob.ip.binseg.models.losses import SoftJaccardBCELogitsLoss
from bob.ip.binseg.engine.adabound import AdaBound
##### Config #####
......@@ -31,11 +30,7 @@ amsbound = False
scheduler_milestones = [900]
scheduler_gamma = 0.1
# model
model = build_unet()
# pretrained backbone
pretrained_backbone = modelurls["vgg16"]
model = unet()
# optimizer
optimizer = AdaBound(
......
......@@ -207,7 +207,7 @@ def run(
scheduler : :py:mod:`torch.optim`
learning rate scheduler
checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.Checkpointer`
checkpointer implementation
checkpoint_period : int
......
......@@ -93,7 +93,7 @@ def run(
scheduler : :py:mod:`torch.optim`
learning rate scheduler
checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.Checkpointer`
checkpointer implementation
checkpoint_period : int
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Adopted from https://github.com/tonylins/pytorch-mobilenet-v2/ by @tonylins
# Ji Lin under Apache License 2.0
import torch.nn
import math
def conv_bn(inp, oup, stride):
    """3x3 convolution (padding 1, given stride) + batch-norm + ReLU6."""
    conv = torch.nn.Conv2d(inp, oup, 3, stride, 1, bias=False)
    norm = torch.nn.BatchNorm2d(oup)
    act = torch.nn.ReLU6(inplace=True)
    return torch.nn.Sequential(conv, norm, act)
def conv_1x1_bn(inp, oup):
    """1x1 pointwise convolution + batch-norm + ReLU6 (spatial size kept)."""
    layers = [
        torch.nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        torch.nn.BatchNorm2d(oup),
        torch.nn.ReLU6(inplace=True),
    ]
    return torch.nn.Sequential(*layers)
class InvertedResidual(torch.nn.Module):
    """MobileNetV2 inverted-residual block.

    Optional 1x1 expansion (skipped when ``expand_ratio == 1``), depthwise
    3x3 convolution, then a linear 1x1 projection.  A residual shortcut is
    used only when the block keeps both resolution and channel count.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        # identity shortcut only when shape is fully preserved
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw: 1x1 expansion from inp to hidden_dim channels
            layers += [
                torch.nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                torch.nn.BatchNorm2d(hidden_dim),
                torch.nn.ReLU6(inplace=True),
            ]
        layers += [
            # dw: depthwise 3x3 (one group per channel)
            torch.nn.Conv2d(
                hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False
            ),
            torch.nn.BatchNorm2d(hidden_dim),
            torch.nn.ReLU6(inplace=True),
            # pw-linear: 1x1 projection, deliberately without activation
            torch.nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            torch.nn.BatchNorm2d(oup),
        ]
        self.conv = torch.nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out
class MobileNetV2(torch.nn.Module):
    """MobileNetV2 backbone that returns intermediate feature maps.

    The classification head and the two deepest inverted-residual stages of
    the reference implementation are commented out below: this variant is
    only used as a feature extractor for segmentation decoders.
    """

    def __init__(
        self,
        n_class=1000,
        input_size=224,
        width_mult=1.0,
        return_features=None,
        m2u=True,
    ):
        # n_class is effectively unused (the classifier is commented out);
        # kept for signature compatibility with the reference implementation
        super(MobileNetV2, self).__init__()
        # indices into self.features whose outputs forward() should return.
        # NOTE(review): if this stays None, forward() raises on
        # ``index in self.return_features`` - callers apparently always pass
        # a list; confirm
        self.return_features = return_features
        # if True, forward() also returns the raw input tensor (M2U-Net use)
        self.m2u = m2u
        block = InvertedResidual
        input_channel = 32
        #last_channel = 1280
        # t: expansion factor, c: output channels, n: repeats, s: stride of
        # the first block in the stage
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            # [6, 160, 3, 2],
            # [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        # self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    # only the first block of a stage applies the stage stride
                    self.features.append(
                        block(input_channel, output_channel, s, expand_ratio=t)
                    )
                else:
                    self.features.append(
                        block(input_channel, output_channel, 1, expand_ratio=t)
                    )
                input_channel = output_channel
        # building last several layers
        # self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it torch.nn.Sequential
        self.features = torch.nn.Sequential(*self.features)

        # building classifier
        # self.classifier = torch.nn.Sequential(
        #     torch.nn.Dropout(0.2),
        #     torch.nn.Linear(self.last_channel, n_class),
        # )

        self._initialize_weights()

    def forward(self, x):
        """Returns [input (h, w), (input if m2u), selected feature maps]."""
        outputs = []
        # hw of input, needed for DRIU and HED
        outputs.append(x.shape[2:4])
        if self.m2u:
            outputs.append(x)
        for index, m in enumerate(self.features):
            x = m(x)
            # extract layers
            if index in self.return_features:
                outputs.append(x)
        return outputs

    def _initialize_weights(self):
        # He-style init for convolutions; unit weight / zero bias for BN;
        # small gaussian for any linear layer
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, torch.nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, torch.nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
# Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
# resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN : https://github.com/rgeirhos/texture-vs-shape
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
# Download locations of the ImageNet-pretrained checkpoints distributed with
# torchvision; the last entry is Geirhos et al.'s texture-vs-shape ResNet-50
# (trained on Stylized-ImageNet + ImageNet, then fine-tuned on ImageNet).
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
    "resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN": "https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar",
}
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    kwargs = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **kwargs)
class BasicBlock(nn.Module):
    """Two 3x3 convolutions with an identity (or projected) shortcut."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # the first convolution carries any spatial stride
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # shortcut: identity, or a projection when shapes change
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck block with a 4x channel expansion."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        # the middle 3x3 convolution carries any spatial stride
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # shortcut: identity, or a projection when shapes change
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y += shortcut
        return self.relu(y)
class ResNet(nn.Module):
    """Generic ResNet feature extractor with selectable layer outputs.

    Unlike torchvision's ResNet, this network keeps no classification head;
    :py:meth:`forward` returns a list of intermediate activations instead of
    a single logits tensor.
    """

    def __init__(self, block, layers, return_features, zero_init_residual=False):
        """
        Generic ResNet network with layer return.

        Attributes
        ----------

        block : type
            residual block implementation (:py:class:`BasicBlock` or
            :py:class:`Bottleneck`)

        layers : list
            number of residual blocks in each of the four stages (length 4)

        return_features : list
            indices into ``self.features`` (8 entries: stem conv, bn, relu,
            max-pool and the four residual stages) of the layers to return

        zero_init_residual : bool
            if True, zero-initialize the last batch-norm of every residual
            branch (see comment below)
        """
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.return_features = return_features
        # standard ResNet stem: strided 7x7 conv + BN + ReLU + 3x3 max-pool
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # flat module list walked by forward() while collecting outputs
        self.features = [
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
        ]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        # Stacks ``blocks`` residual blocks; only the first one may change
        # resolution (stride) or width, in which case the shortcut uses a
        # 1x1 projection + BN.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Returns [input (h, w), selected feature maps in network order]."""
        outputs = []
        # hw of input, needed for DRIU and HED
        outputs.append(x.shape[2:4])
        for index, m in enumerate(self.features):
            x = m(x)
            # extract layers
            if index in self.return_features:
                outputs.append(x)
        return outputs
def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        # strict=False: the torchvision checkpoint also contains the "fc.*"
        # classifier weights, which this headless ResNet does not define
        # (a strict load would raise on the unexpected keys)
        model.load_state_dict(
            model_zoo.load_url(model_urls["resnet18"]), strict=False
        )
    return model
def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        # strict=False: the torchvision checkpoint also contains the "fc.*"
        # classifier weights, which this headless ResNet does not define
        model.load_state_dict(
            model_zoo.load_url(model_urls["resnet34"]), strict=False
        )
    return model
def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        # strict=False: the torchvision checkpoint also contains the "fc.*"
        # classifier weights, which this headless ResNet does not define
        model.load_state_dict(
            model_zoo.load_url(model_urls["resnet50"]), strict=False
        )
    return model
def shaperesnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model, pretrained on Stylized-ImageNet and
    ImageNet and fine-tuned on ImageNet.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        # strict=False: the checkpoint also contains classifier ("fc.*")
        # weights, which this headless ResNet does not define
        model.load_state_dict(
            model_zoo.load_url(
                model_urls["resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN"]
            ),
            strict=False,
        )
    return model
def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        # strict=False: the torchvision checkpoint also contains the "fc.*"
        # classifier weights, which this headless ResNet does not define
        model.load_state_dict(
            model_zoo.load_url(model_urls["resnet101"]), strict=False
        )
    return model
def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        # strict=False: the torchvision checkpoint also contains the "fc.*"
        # classifier weights, which this headless ResNet does not define
        model.load_state_dict(
            model_zoo.load_url(model_urls["resnet152"]), strict=False
        )
    return model
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
# Download locations of the ImageNet-pretrained VGG checkpoints distributed
# with torchvision (the "_bn" suffix marks batch-normalized variants).
model_urls = {
    "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
    "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
    "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
    "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
    "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
    "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
    "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
    "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
}
class VGG(nn.Module):
    """VGG feature extractor returning selected intermediate activations.

    Instead of class scores, :py:meth:`forward` returns a list holding the
    input's spatial size followed by the output of every feature layer whose
    index appears in ``return_features``.
    """

    def __init__(self, features, return_features, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.return_features = return_features
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # first entry: (height, width) of the input, needed by DRIU and HED
        retval = [x.shape[2:4]]
        for idx, layer in enumerate(self.features):
            x = layer(x)
            if idx in self.return_features:
                retval.append(x)
        return retval

    def _initialize_weights(self):
        # He init for convolutions, unit/zero for BN, small gaussian for FC
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)
def make_layers(cfg, batch_norm=False):
    """Builds a VGG feature stack from a compact configuration list.

    Integers in ``cfg`` become 3x3 convolutions (+ optional batch-norm) each
    followed by a ReLU; the string ``"M"`` becomes a 2x2 max-pooling.
    """
    modules = []
    channels = 3
    for item in cfg:
        if item == "M":
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, item, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(item))
        modules.append(nn.ReLU(inplace=True))
        channels = item
    return nn.Sequential(*modules)
# VGG layer layouts in torchvision's compact notation: integers are 3x3-conv
# output channel counts, "M" marks a 2x2 max-pooling.
# "A" = VGG-11, "B" = VGG-13, "D" = VGG-16, "E" = VGG-19 (see the vgg*
# constructors below).
_cfg = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "D": [
        64,
        64,
        "M",
        128,
        128,
        "M",
        256,
        256,
        256,
        "M",
        512,
        512,
        512,
        "M",
        512,
        512,
        512,
        "M",
    ],
    "E": [
        64,
        64,
        "M",
        128,
        128,
        "M",
        256,
        256,
        256,
        256,
        "M",
        512,
        512,
        512,
        512,
        "M",
        512,
        512,
        512,
        512,
        "M",
    ],
}
def vgg11(pretrained=False, **kwargs):
    """VGG 11-layer model (configuration "A")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        # the pretrained checkpoint supplies the weights
        kwargs["init_weights"] = False
    model = VGG(make_layers(_cfg["A"]), **kwargs)
    if pretrained:
        # strict=False: the checkpoint also contains "classifier.*" weights,
        # which this headless VGG does not define (vgg16() below already
        # loads this way)
        model.load_state_dict(model_zoo.load_url(model_urls["vgg11"]), strict=False)
    return model
def vgg11_bn(pretrained=False, **kwargs):
    """VGG 11-layer model (configuration "A") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        # the pretrained checkpoint supplies the weights
        kwargs["init_weights"] = False
    model = VGG(make_layers(_cfg["A"], batch_norm=True), **kwargs)
    if pretrained:
        # strict=False: the checkpoint also contains "classifier.*" weights,
        # which this headless VGG does not define
        model.load_state_dict(
            model_zoo.load_url(model_urls["vgg11_bn"]), strict=False
        )
    return model
def vgg13(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        # the pretrained checkpoint supplies the weights
        kwargs["init_weights"] = False
    model = VGG(make_layers(_cfg["B"]), **kwargs)
    if pretrained:
        # strict=False: the checkpoint also contains "classifier.*" weights,
        # which this headless VGG does not define
        model.load_state_dict(model_zoo.load_url(model_urls["vgg13"]), strict=False)
    return model
def vgg13_bn(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        # the pretrained checkpoint supplies the weights
        kwargs["init_weights"] = False
    model = VGG(make_layers(_cfg["B"], batch_norm=True), **kwargs)
    if pretrained:
        # strict=False: the checkpoint also contains "classifier.*" weights,
        # which this headless VGG does not define
        model.load_state_dict(
            model_zoo.load_url(model_urls["vgg13_bn"]), strict=False
        )
    return model
def vgg16(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        # the pretrained checkpoint supplies the weights
        kwargs["init_weights"] = False
    features = make_layers(_cfg["D"])
    model = VGG(features, **kwargs)
    if pretrained:
        weights = model_zoo.load_url(model_urls["vgg16"])
        # non-strict load: the checkpoint carries classifier weights that
        # this feature-only VGG has no modules for
        model.load_state_dict(weights, strict=False)
    return model
def vgg16_bn(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        # the pretrained checkpoint supplies the weights
        kwargs["init_weights"] = False
    model = VGG(make_layers(_cfg["D"], batch_norm=True), **kwargs)
    if pretrained:
        # strict=False: the checkpoint also contains "classifier.*" weights,
        # which this headless VGG does not define (matches vgg16() above)
        model.load_state_dict(
            model_zoo.load_url(model_urls["vgg16_bn"]), strict=False
        )
    return model
def vgg19(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration "E")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        # the pretrained checkpoint supplies the weights
        kwargs["init_weights"] = False
    model = VGG(make_layers(_cfg["E"]), **kwargs)
    if pretrained:
        # strict=False: the checkpoint also contains "classifier.*" weights,
        # which this headless VGG does not define
        model.load_state_dict(model_zoo.load_url(model_urls["vgg19"]), strict=False)
    return model
def vgg19_bn(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration 'E') with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        # the pretrained checkpoint supplies the weights
        kwargs["init_weights"] = False
    model = VGG(make_layers(_cfg["E"], batch_norm=True), **kwargs)
    if pretrained:
        # strict=False: the checkpoint also contains "classifier.*" weights,
        # which this headless VGG does not define
        model.load_state_dict(
            model_zoo.load_url(model_urls["vgg19_bn"]), strict=False
        )
    return model
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
import torchvision.models.mobilenet
class MobileNetV24Segmentation(torchvision.models.mobilenet.MobileNetV2):
    """MobileNetV2 variant emitting intermediate activations for segmentation.

    Thin wrapper over torchvision's MobileNetV2 whose :py:meth:`forward`
    returns, instead of class scores, a list with the input's spatial size,
    the raw input tensor and the activations of every feature layer whose
    index appears in ``return_features``.


    Parameters
    ==========

    return_features : :py:class:`list`, Optional
        A list of integers indicating the feature layers to be returned from
        the original module.

    """

    def __init__(self, *args, **kwargs):
        # consume our extra keyword before delegating to torchvision
        self._return_features = kwargs.pop("return_features")
        super(MobileNetV24Segmentation, self).__init__(*args, **kwargs)

    def forward(self, x):
        # the (height, width) of the input and the input itself come first -
        # decoders (DRIU/HED/M2U-Net style) consume both
        retval = [x.shape[2:4], x]
        for idx, layer in enumerate(self.features):
            x = layer(x)
            if idx in self._return_features:
                retval.append(x)
        return retval
def mobilenet_v2_for_segmentation(pretrained=False, progress=True, **kwargs):
    # no docstring here on purpose: __doc__ is assigned below from
    # torchvision's mobilenet_v2; "return_features" in kwargs is consumed by
    # MobileNetV24Segmentation.__init__
    model = MobileNetV24Segmentation(**kwargs)

    if pretrained:
        state_dict = torchvision.models.mobilenet.load_state_dict_from_url(
            torchvision.models.mobilenet.model_urls["mobilenet_v2"],
            progress=progress,
        )
        model.load_state_dict(state_dict)

    # erase MobileNetV2 head (for classification), not used for segmentation
    delattr(model, 'classifier')
    # drop any feature stages past the deepest one requested, to save
    # memory/compute (``**kwargs`` unpacking copies the dict, so the pop in
    # __init__ did not remove the key from our local ``kwargs``)
    return_features = kwargs.get("return_features")
    if return_features is not None:
        model.features = model.features[:(max(return_features)+1)]

    return model


mobilenet_v2_for_segmentation.__doc__ = (
    torchvision.models.mobilenet.mobilenet_v2.__doc__
)
#!/usr/bin/env python
# coding=utf-8
import torchvision.models.resnet
class ResNet4Segmentation(torchvision.models.resnet.ResNet):
    """Adaptation of base ResNet functionality to U-Net style segmentation

    This version of ResNet is slightly modified so it can be used through
    torchvision's API. It outputs intermediate features which are normally not
    output by the base ResNet implementation, but are required for
    segmentation operations.


    Parameters
    ==========

    return_features : :py:class:`list`, Optional
        A list of integers indicating the feature layers to be returned from
        the original module.

    """

    def __init__(self, *args, **kwargs):
        # pop our extra keyword before delegating to torchvision's ResNet
        self._return_features = kwargs.pop("return_features")
        super(ResNet4Segmentation, self).__init__(*args, **kwargs)

    def forward(self, x):
        outputs = []
        # keep the input spatial size (height, width) - decoders need it
        outputs.append(x.shape[2:4])
        # NOTE(review): torchvision's ResNet does not define ``self.features``
        # itself - presumably it is set up elsewhere before forward() is
        # called; confirm against the model builder
        for index, m in enumerate(self.features):
            x = m(x)
            # extract layers
            # fix: was ``self.return_features``, an attribute that is never
            # set (__init__ stores ``self._return_features``) and raised
            # AttributeError on every forward pass
            if index in self._return_features:
                outputs.append(x)
        return outputs
def _resnet_for_segmentation(
    arch, block, layers, pretrained, progress, **kwargs
):
    """Builds a ResNet4Segmentation, optionally loading pretrained weights.

    ``arch`` keys into ``torchvision.models.resnet.model_urls``; ``block``
    and ``layers`` define the ResNet geometry, as in torchvision.
    """
    model = ResNet4Segmentation(block, layers, **kwargs)

    if pretrained:
        state_dict = torchvision.models.resnet.load_state_dict_from_url(
            torchvision.models.resnet.model_urls[arch], progress=progress
        )
        model.load_state_dict(state_dict)

    # erase ResNet head (for classification), not used for segmentation
    delattr(model, 'avgpool')
    delattr(model, 'fc')

    return model
def resnet50_for_segmentation(pretrained=False, progress=True, **kwargs):
    # thin wrapper: ResNet-50 geometry (Bottleneck blocks, stages [3, 4, 6, 3])
    block = torchvision.models.resnet.Bottleneck
    stage_sizes = [3, 4, 6, 3]
    return _resnet_for_segmentation(
        "resnet50", block, stage_sizes, pretrained, progress, **kwargs
    )


resnet50_for_segmentation.__doc__ = torchvision.models.resnet.resnet50.__doc__
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torchvision.models.vgg
class VGG4Segmentation(torchvision.models.vgg.VGG):
    """VGG variant emitting intermediate activations for segmentation.

    Thin wrapper over torchvision's VGG whose :py:meth:`forward` returns,
    instead of class scores, a list with the input's spatial size followed by
    the activations of every feature layer whose index appears in
    ``return_features``.


    Parameters
    ==========

    return_features : :py:class:`list`, Optional
        A list of integers indicating the feature layers to be returned from
        the original module.

    """

    def __init__(self, *args, **kwargs):
        # consume our extra keyword before delegating to torchvision
        self._return_features = kwargs.pop("return_features")
        super(VGG4Segmentation, self).__init__(*args, **kwargs)

    def forward(self, x):
        # first entry: (height, width) of the input, needed by the decoders
        retval = [x.shape[2:4]]
        for idx, layer in enumerate(self.features):
            x = layer(x)
            if idx in self._return_features:
                retval.append(x)
        return retval
def _vgg_for_segmentation(
    arch, cfg, batch_norm, pretrained, progress, **kwargs
):
    """Builds a VGG4Segmentation, optionally loading pretrained weights.

    ``arch`` keys into ``torchvision.models.vgg.model_urls``; ``cfg`` keys
    into ``torchvision.models.vgg.cfgs`` (e.g. "D" = 16-layer layout).
    """

    if pretrained:
        # the pretrained checkpoint supplies the weights
        kwargs["init_weights"] = False

    model = VGG4Segmentation(
        torchvision.models.vgg.make_layers(
            torchvision.models.vgg.cfgs[cfg], batch_norm=batch_norm
        ),
        **kwargs
    )

    if pretrained:
        state_dict = torchvision.models.vgg.load_state_dict_from_url(
            torchvision.models.vgg.model_urls[arch], progress=progress
        )
        model.load_state_dict(state_dict)

    # erase VGG head (for classification), not used for segmentation
    delattr(model, 'classifier')
    delattr(model, 'avgpool')

    return model
def vgg16_for_segmentation(pretrained=False, progress=True, **kwargs):
    # configuration "D" without batch normalization = VGG-16
    return _vgg_for_segmentation("vgg16", "D", False, pretrained, progress, **kwargs)


vgg16_for_segmentation.__doc__ = torchvision.models.vgg16.__doc__
def vgg16_bn_for_segmentation(pretrained=False, progress=True, **kwargs):
    # configuration "D" with batch normalization = VGG-16-BN
    return _vgg_for_segmentation("vgg16_bn", "D", True, pretrained, progress, **kwargs)


vgg16_bn_for_segmentation.__doc__ = torchvision.models.vgg16_bn.__doc__
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment