Skip to content
Snippets Groups Projects
Commit d36ce7c0 authored by Tim Laibacher's avatar Tim Laibacher
Browse files

Add M2U, ResUNet, cropconfigs

parent f30a4f83
No related branches found
No related tags found
No related merge requests found
Pipeline #29667 failed
Showing
with 605 additions and 23 deletions
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from bob.db.drive import Database as DRIVE
from bob.ip.binseg.data.transforms import *
from bob.ip.binseg.data.binsegdataset import BinSegDataset
#### Config ####

# Test-time pipeline: CenterCrop to (544, 544), then ToTensor.
# No random augmentation is applied on this split.
transforms = Compose(
    [
        CenterCrop((544, 544)),
        ToTensor(),
    ]
)

# bob.db.dataset init (DRIVE database, 'default' protocol)
bobdb = DRIVE(protocol='default')

# PyTorch dataset built on the 'test' split
dataset = BinSegDataset(
    bobdb,
    split='test',
    transform=transforms,
)
\ No newline at end of file
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from bob.db.drive import Database as DRIVE
from bob.ip.binseg.data.transforms import *
from bob.ip.binseg.data.binsegdataset import BinSegDataset
#### Config ####

# Training pipeline: CenterCrop to (544, 544), then the random flip /
# rotation / colour-jitter transforms (all with their default arguments),
# and finally ToTensor.
transforms = Compose(
    [
        CenterCrop((544, 544)),
        RandomHFlip(),
        RandomVFlip(),
        RandomRotation(),
        ColorJitter(),
        ToTensor(),
    ]
)

# bob.db.dataset init (DRIVE database, 'default' protocol)
bobdb = DRIVE(protocol='default')

# PyTorch dataset built on the 'train' split
dataset = BinSegDataset(
    bobdb,
    split='train',
    transform=transforms,
)
\ No newline at end of file
......@@ -4,15 +4,20 @@
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.driu import build_driu
import torch.optim as optim
from bob.ip.binseg.modeling.losses import WeightedBCELogitsLoss
from torch.nn import BCEWithLogitsLoss
from bob.ip.binseg.utils.model_zoo import modelurls
from bob.ip.binseg.modeling.losses import WeightedBCELogitsLoss
from bob.ip.binseg.engine.adabound import AdaBound
##### Config #####
lr = 0.001
betas = (0.9, 0.999)
eps = 1e-08
weight_decay = 0
amsgrad = False
final_lr = 0.1
gamma = 1e-3
eps = 1e-8
amsbound = False
scheduler_milestones = [150]
scheduler_gamma = 0.1
......@@ -24,7 +29,8 @@ model = build_driu()
pretrained_backbone = modelurls['vgg16']
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
eps=eps, weight_decay=weight_decay, amsbound=amsbound)
# criterion
criterion = WeightedBCELogitsLoss(reduction='mean')
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.driu import build_driu
import torch.optim as optim
from bob.ip.binseg.modeling.losses import WeightedBCELogitsLoss
from bob.ip.binseg.utils.model_zoo import modelurls
##### Config #####

# Hyper-parameters handed to optim.Adam below
lr = 0.001
betas = (0.9, 0.999)
eps = 1e-08
weight_decay = 0
amsgrad = False

# Hyper-parameters handed to MultiStepLR below
scheduler_milestones = [150]
scheduler_gamma = 0.1

# model
model = build_driu()

# pretrained backbone (entry looked up in `modelurls`)
pretrained_backbone = modelurls['vgg16']

# optimizer
optimizer = optim.Adam(
    model.parameters(),
    lr=lr,
    betas=betas,
    eps=eps,
    weight_decay=weight_decay,
    amsgrad=amsgrad,
)

# criterion
criterion = WeightedBCELogitsLoss(reduction='mean')

# scheduler
scheduler = MultiStepLR(
    optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)
......@@ -6,17 +6,23 @@ from bob.ip.binseg.modeling.hed import build_hed
import torch.optim as optim
from bob.ip.binseg.modeling.losses import HEDWeightedBCELogitsLoss
from bob.ip.binseg.utils.model_zoo import modelurls
from bob.ip.binseg.engine.adabound import AdaBound
##### Config #####
lr = 0.001
betas = (0.9, 0.999)
eps = 1e-08
weight_decay = 0
amsgrad = False
final_lr = 0.1
gamma = 1e-3
eps = 1e-8
amsbound = False
scheduler_milestones = [150]
scheduler_gamma = 0.1
# model
model = build_hed()
......@@ -24,8 +30,8 @@ model = build_hed()
pretrained_backbone = modelurls['vgg16']
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
eps=eps, weight_decay=weight_decay, amsbound=amsbound)
# criterion
criterion = HEDWeightedBCELogitsLoss(reduction='mean')
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.m2u import build_m2unet
import torch.optim as optim
from torch.nn import BCEWithLogitsLoss
from bob.ip.binseg.utils.model_zoo import modelurls
from bob.ip.binseg.modeling.losses import WeightedBCELogitsLoss
from bob.ip.binseg.engine.adabound import AdaBound
##### Config #####

# Hyper-parameters handed to AdaBound below.
# `eps` used to be assigned twice with the same value (1e-08 and 1e-8);
# it is now defined once.
lr = 0.001
betas = (0.9, 0.999)
eps = 1e-8
weight_decay = 0
final_lr = 0.1
gamma = 1e-3
amsbound = False

# Hyper-parameters handed to MultiStepLR below
scheduler_milestones = [150]
scheduler_gamma = 0.1

# model
model = build_m2unet()

# pretrained backbone (entry looked up in `modelurls`)
pretrained_backbone = modelurls['mobilenetv2']

# optimizer
optimizer = AdaBound(
    model.parameters(),
    lr=lr,
    betas=betas,
    final_lr=final_lr,
    gamma=gamma,
    eps=eps,
    weight_decay=weight_decay,
    amsbound=amsbound,
)

# criterion
criterion = WeightedBCELogitsLoss(reduction='mean')

# scheduler
scheduler = MultiStepLR(
    optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.resunet import build_res50unet
import torch.optim as optim
from torch.nn import BCEWithLogitsLoss
from bob.ip.binseg.utils.model_zoo import modelurls
from bob.ip.binseg.modeling.losses import WeightedBCELogitsLoss
from bob.ip.binseg.engine.adabound import AdaBound
##### Config #####

# Hyper-parameters handed to AdaBound below.
# `eps` used to be assigned twice with the same value (1e-08 and 1e-8);
# it is now defined once.
lr = 0.001
betas = (0.9, 0.999)
eps = 1e-8
weight_decay = 0
final_lr = 0.1
gamma = 1e-3
amsbound = False

# Hyper-parameters handed to MultiStepLR below
scheduler_milestones = [150]
scheduler_gamma = 0.1

# model
model = build_res50unet()

# pretrained backbone (entry looked up in `modelurls`)
pretrained_backbone = modelurls['resnet50']

# optimizer
optimizer = AdaBound(
    model.parameters(),
    lr=lr,
    betas=betas,
    final_lr=final_lr,
    gamma=gamma,
    eps=eps,
    weight_decay=weight_decay,
    amsbound=amsbound,
)

# criterion
criterion = WeightedBCELogitsLoss(reduction='mean')

# scheduler
scheduler = MultiStepLR(
    optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)
......@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
from torch.optim.lr_scheduler import MultiStepLR
from bob.ip.binseg.modeling.driu import build_driu
from bob.ip.binseg.modeling.unet import build_unet
import torch.optim as optim
from torch.nn import BCEWithLogitsLoss
from bob.ip.binseg.utils.model_zoo import modelurls
......@@ -23,13 +23,12 @@ scheduler_milestones = [150]
scheduler_gamma = 0.1
# model
model = build_driu()
model = build_unet()
# pretrained backbone
pretrained_backbone = modelurls['vgg16']
# optimizer
# TODO: Add Adabound
optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
eps=eps, weight_decay=weight_decay, amsbound=amsbound)
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Adapted from https://github.com/tonylins/pytorch-mobilenet-v2/ by @tonylins
# Ji Lin under Apache License 2.0
import torch.nn as nn
import math
def conv_bn(inp, oup, stride):
    """Build a 3x3 convolution -> BatchNorm2d -> ReLU6 stack.

    Args:
        inp: number of input channels of the convolution.
        oup: number of output channels (also the BatchNorm2d width).
        stride: stride of the convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    norm = nn.BatchNorm2d(oup)
    act = nn.ReLU6(inplace=True)
    return nn.Sequential(conv, norm, act)
def conv_1x1_bn(inp, oup):
    """Build a 1x1 convolution -> BatchNorm2d -> ReLU6 stack.

    Args:
        inp: number of input channels of the convolution.
        oup: number of output channels (also the BatchNorm2d width).

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    layers = (
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    )
    return nn.Sequential(*layers)
class InvertedResidual(nn.Module):
    """Depthwise-separable block with an optional pointwise expansion.

    The layer stack is: (optional 1x1 "pw" expansion) -> 3x3 depthwise "dw"
    convolution -> 1x1 "pw-linear" projection. The expansion stage is left out
    when ``expand_ratio == 1``. A skip connection is added around the stack
    when ``stride == 1`` and ``inp == oup``.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise convolution; must be 1 or 2.
        expand_ratio: the hidden width is ``round(inp * expand_ratio)``.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        stages = []
        if expand_ratio != 1:
            # pw
            stages.extend([
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ])
        stages.extend([
            # dw
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stack, adding ``x`` back when the skip connection is on."""
        out = self.conv(x)
        return x + out if self.use_res_connect else out
class MobileNetV2(nn.Module):
    """MobileNetV2 used as a feature-extracting backbone.

    ``forward`` does not run the classifier; it returns a list of intermediate
    tensors (see ``forward``).

    Args:
        n_class: output size of the ``classifier`` linear layer.
        input_size: nominal input resolution; must be a multiple of 32.
        width_mult: multiplier applied to every channel count.
        return_features: indices into ``self.features`` whose outputs are
            collected by ``forward``. ``None`` means no intermediate feature
            map is collected (previously ``None`` made ``forward`` raise
            ``TypeError``).
        m2u: when true, ``forward`` also returns the raw input tensor.
    """

    def __init__(self, n_class=1000, input_size=224, width_mult=1., return_features = None, m2u=True):
        super(MobileNetV2, self).__init__()
        self.return_features = return_features
        self.m2u = m2u
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        layers = [conv_bn(3, input_channel, 2)]

        # building inverted residual blocks: only the first block of each
        # group uses the group's stride `s`, the remaining ones use stride 1
        for t, c, n, s in inverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                stride = s if i == 0 else 1
                layers.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel

        # building last several layers
        layers.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*layers)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        """Run ``self.features`` and collect the requested outputs.

        Returns:
            A list whose first element is ``x.shape[2:4]`` (hw of input,
            needed for DRIU and HED), followed by the input tensor itself when
            ``self.m2u`` is true, followed by the output of every layer of
            ``self.features`` whose index is in ``self.return_features``.
        """
        # Treat a missing selection as "collect nothing" instead of crashing
        # on `index in None`.
        wanted = self.return_features if self.return_features is not None else ()

        # hw of input, needed for DRIU and HED
        outputs = [x.shape[2:4]]
        if self.m2u:
            outputs.append(x)
        for index, layer in enumerate(self.features):
            x = layer(x)
            # extract layers
            if index in wanted:
                outputs.append(x)
        return outputs

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # normal init with std sqrt(2 / (kernel area * out_channels))
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
\ No newline at end of file
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
# resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN : https://github.com/rgeirhos/texture-vs-shape
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
......@@ -17,7 +15,7 @@ model_urls = {
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
'resnet50__SIN_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar',
'resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar',
}
......@@ -115,18 +113,19 @@ class ResNet(nn.Module):
super(ResNet, self).__init__()
self.inplanes = 64
self.return_features = return_features
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
bn1 = nn.BatchNorm2d(64)
relu = nn.ReLU(inplace=True)
maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
layer1 = self._make_layer(block, 64, layers[0])
layer2 = self._make_layer(block, 128, layers[1], stride=2)
layer3 = self._make_layer(block, 256, layers[2], stride=2)
layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.features = nn.ModuleList([conv1, bn1, relu, maxpool
,layer1,layer2,layer3,layer4])
self.features = [self.conv1, self.bn1, self.relu, self.maxpool
,self.layer1,self.layer2,self.layer3,self.layer4]
for m in self.modules():
if isinstance(m, nn.Conv2d):
......@@ -205,6 +204,16 @@ def resnet50(pretrained=False, **kwargs):
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model
def shaperesnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model, pretrained on Stylized-ImageNet and ImageNet and fine-tuned on ImageNet.

    Args:
        pretrained (bool): If True, loads the weights stored under the
            'resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN' entry of ``model_urls``
        **kwargs: forwarded unchanged to the ``ResNet`` constructor
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        checkpoint = model_zoo.load_url(model_urls['resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN'])
        # NOTE(review): the file behind this URL is a ".pth.tar" training
        # checkpoint; upstream appears to load it via checkpoint["state_dict"]
        # into a DataParallel-wrapped model -- confirm against the real file.
        # A plain state dict passes through both steps below unchanged.
        state_dict = checkpoint
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        prefix = 'module.'
        if state_dict and all(key.startswith(prefix) for key in state_dict):
            state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
        model.load_state_dict(state_dict)
    return model
def resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model.
Args:
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# https://github.com/laibe/M2U-Net
from collections import OrderedDict
import torch
from torch import nn
from bob.ip.binseg.modeling.backbones.mobilenetv2 import MobileNetV2, InvertedResidual
class DecoderBlock(nn.Module):
    """
    Decoder block: upsample and concatenate with features maps from the encoder part

    Args:
        up_in_c: channels of the tensor that gets upsampled.
        x_in_c: channels of the encoder feature map it is concatenated with.
        upsamplemode: ``mode`` passed to ``nn.Upsample``.
        expand_ratio: ``expand_ratio`` passed to ``InvertedResidual``.

    The block outputs ``(x_in_c + up_in_c) // 2`` channels.
    """
    def __init__(self,up_in_c,x_in_c,upsamplemode='bilinear',expand_ratio=0.15):
        super().__init__()
        # `align_corners` may only be set for the interpolating modes; passing
        # False together with e.g. 'nearest' makes the upsampling call fail.
        interpolating = ('linear', 'bilinear', 'bicubic', 'trilinear')
        align_corners = False if upsamplemode in interpolating else None
        self.upsample = nn.Upsample(scale_factor=2,mode=upsamplemode,align_corners=align_corners) # H, W -> 2H, 2W
        self.ir1 = InvertedResidual(up_in_c+x_in_c,(x_in_c + up_in_c) // 2,stride=1,expand_ratio=expand_ratio)

    def forward(self,up_in,x_in):
        """Upsample ``up_in`` x2, concatenate with ``x_in`` on dim 1, apply ``ir1``."""
        up_out = self.upsample(up_in)
        cat_x = torch.cat([up_out, x_in] , dim=1)
        x = self.ir1(cat_x)
        return x
class LastDecoderBlock(nn.Module):
    """
    Final decoder block: upsample, concatenate and reduce to a single channel

    Args:
        x_in_c: channel count of the concatenated tensor, i.e. channels of the
            upsampled input plus channels of ``x_in``.
        upsamplemode: ``mode`` passed to ``nn.Upsample``.
        expand_ratio: ``expand_ratio`` passed to ``InvertedResidual``.
    """
    def __init__(self,x_in_c,upsamplemode='bilinear',expand_ratio=0.15):
        super().__init__()
        # `align_corners` may only be set for the interpolating modes; passing
        # False together with e.g. 'nearest' makes the upsampling call fail.
        interpolating = ('linear', 'bilinear', 'bicubic', 'trilinear')
        align_corners = False if upsamplemode in interpolating else None
        self.upsample = nn.Upsample(scale_factor=2,mode=upsamplemode,align_corners=align_corners) # H, W -> 2H, 2W
        self.ir1 = InvertedResidual(x_in_c,1,stride=1,expand_ratio=expand_ratio)

    def forward(self,up_in,x_in):
        """Upsample ``up_in`` x2, concatenate with ``x_in`` on dim 1, apply ``ir1``."""
        up_out = self.upsample(up_in)
        cat_x = torch.cat([up_out, x_in] , dim=1)
        x = self.ir1(cat_x)
        return x
class M2U(nn.Module):
    """
    M2U-Net head module

    Attributes
    ----------
    in_channels_list (list[int]): number of channels for each feature map that is returned from backbone,
        ordered from the shallowest to the deepest level. Defaults to
        ``[16, 24, 32, 96]``, which reproduces the previously hard-coded
        decoder widths (96/32, 64/24, 44/16, 33).
    upsamplemode (str): upsampling mode forwarded to every decoder block
    expand_ratio (float): expand ratio forwarded to every decoder block
    """
    def __init__(self, in_channels_list=None,upsamplemode='bilinear',expand_ratio=0.15):
        super(M2U, self).__init__()
        if in_channels_list is None:
            in_channels_list = [16, 24, 32, 96]
        c1, c2, c3, c4 = in_channels_list
        # forward() concatenates the last decoder output with x[1]; the
        # original hard-coded width of 33 (= 30 + 3) assumes 3 channels there
        image_channels = 3

        # Each DecoderBlock outputs (up_in_c + x_in_c) // 2 channels, which is
        # the `up_in_c` of the next block.
        up3 = (c4 + c3) // 2
        up2 = (up3 + c2) // 2
        up1 = (up2 + c1) // 2

        # Decoder
        self.decode4 = DecoderBlock(c4,c3,upsamplemode,expand_ratio)
        self.decode3 = DecoderBlock(up3,c2,upsamplemode,expand_ratio)
        self.decode2 = DecoderBlock(up2,c1,upsamplemode,expand_ratio)
        self.decode1 = LastDecoderBlock(up1 + image_channels,upsamplemode,expand_ratio)

        # initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-uniform init for conv weights, constant init for biases and batch norm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self,x):
        """
        Arguments:
            x (list): x[1] .. x[5] are the tensors consumed by the decoder,
            x[5] being the input of the first upsampling step. x[0] is not used.
        """
        # channel comments below are for the default in_channels_list
        decode4 = self.decode4(x[5],x[4])    # 96, 32
        decode3 = self.decode3(decode4,x[3]) # 64, 24
        decode2 = self.decode2(decode3,x[2]) # 44, 16
        decode1 = self.decode1(decode2,x[1]) # 30, 3
        return decode1
def build_m2unet():
    """Assemble the M2U-Net model.

    Returns:
        An ``nn.Sequential`` with a "backbone" entry (MobileNetV2 returning
        the features at indices 1, 3, 6 and 13, with ``m2u=True``) followed
        by a "head" entry (M2U). Its ``name`` attribute is "M2UNet".
    """
    encoder = MobileNetV2(return_features = [1,3,6,13], m2u=True)
    decoder = M2U(in_channels_list=[16, 24, 32, 96])

    stages = OrderedDict()
    stages["backbone"] = encoder
    stages["head"] = decoder

    model = nn.Sequential(stages)
    model.name = "M2UNet"
    return model
\ No newline at end of file
......@@ -88,4 +88,73 @@ class UpsampleCropBlock(nn.Module):
# needs explicit ranges for onnx export
x = x[:,:,h_s:h_e,w_s:w_e] # crop to input size
return x
def ifnone(a, b):
    "Return `a` unless it is None, in which case return `b`."
    if a is not None:
        return a
    return b
def icnr(x, scale=2, init=nn.init.kaiming_normal_):
    """
    https://docs.fast.ai/layers.html#PixelShuffle_ICNR
    ICNR init of `x`, with `scale` and `init` function.

    A tensor with ``x.shape[0] / scale**2`` entries along the first dimension
    is initialised with `init`; every entry is then repeated ``scale**2``
    times so that consecutive groups of ``scale**2`` slices along dim 0 of
    `x` start out identical. `x` is overwritten in place; nothing is returned.
    """
    ni, nf, h, w = x.shape
    factor = scale ** 2
    ni2 = int(ni / factor)

    seed = init(torch.zeros([ni2, nf, h, w]))
    kernel = seed.transpose(0, 1).contiguous().view(ni2, nf, -1)
    kernel = kernel.repeat(1, 1, factor)
    kernel = kernel.contiguous().view([nf, ni, h, w]).transpose(0, 1)

    x.data.copy_(kernel)
class PixelShuffle_ICNR(nn.Module):
    """
    https://docs.fast.ai/layers.html#PixelShuffle_ICNR
    Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle` and `icnr` init.

    Args:
        ni: number of input channels.
        nf: number of output channels; falls back to `ni` when None.
        scale: upsampling factor handed to `nn.PixelShuffle` and to `icnr`.
    """
    def __init__(self, ni:int, nf:int=None, scale:int=2):
        super().__init__()
        nf = ifnone(nf, ni)
        self.conv = conv_with_kaiming_uniform(ni, nf*(scale**2), 1)
        # The conv has nf * scale**2 output channels, so icnr must use the
        # same `scale`; with its default of 2 any other scale produced a
        # shape mismatch inside icnr.
        icnr(self.conv.weight, scale=scale)
        self.shuf = nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        self.pad = nn.ReplicationPad2d((1,0,1,0))
        self.blur = nn.AvgPool2d(2, stride=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self,x):
        """Apply conv -> ReLU -> pixel shuffle, then replication pad + 2x2 average blur."""
        x = self.shuf(self.relu(self.conv(x)))
        x = self.blur(self.pad(x))
        return x
class UnetBlock(nn.Module):
    """U-Net decoder step: upsample, concatenate with a skip tensor, two 3x3 layers.

    Args:
        up_in_c: channels of the tensor to upsample.
        x_in_c: channels of the skip tensor.
        pixel_shuffle: upsample with ``PixelShuffle_ICNR`` instead of
            ``convtrans_with_kaiming_uniform``.
        middle_block: keep ``up_in_c`` channels after upsampling instead of
            halving them (middle block for VGG based U-Net).

    The block outputs ``(x_in_c + up_out_c) // 2`` channels.
    """
    def __init__(self, up_in_c, x_in_c, pixel_shuffle=False, middle_block=False):
        super().__init__()

        # middle block for VGG based U-Net keeps the channel count
        up_out_c = up_in_c if middle_block else up_in_c // 2
        cat_channels = x_in_c + up_out_c
        inner_channels = cat_channels // 2

        self.upsample = (
            PixelShuffle_ICNR(up_in_c, up_out_c)
            if pixel_shuffle
            else convtrans_with_kaiming_uniform(up_in_c, up_out_c, 2, 2)
        )
        self.convtrans1 = convtrans_with_kaiming_uniform(cat_channels, inner_channels, 3, 1, 1)
        self.convtrans2 = convtrans_with_kaiming_uniform(inner_channels, inner_channels, 3, 1, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, up_in, x_in):
        """Upsample ``up_in``, concatenate with ``x_in`` on dim 1, apply both layers with ReLU."""
        merged = torch.cat([self.upsample(up_in), x_in], dim=1)
        hidden = self.relu(self.convtrans1(merged))
        return self.relu(self.convtrans2(hidden))
\ No newline at end of file
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch
from collections import OrderedDict
from bob.ip.binseg.modeling.make_layers import conv_with_kaiming_uniform, convtrans_with_kaiming_uniform, PixelShuffle_ICNR, UnetBlock
from bob.ip.binseg.modeling.backbones.resnet import resnet50
class ResUNet(nn.Module):
    """
    UNet head module for ResNet backbones

    Attributes
    ----------
    in_channels_list (list[int]): number of channels for each feature map that is returned from backbone
    pixel_shuffle (bool): use PixelShuffle_ICNR instead of transposed convolutions for upsampling
    """
    def __init__(self, in_channels_list=None, pixel_shuffle=False):
        super(ResUNet, self).__init__()
        # number of channels, shallowest level first
        c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list
        # number of channels for last upsampling operation
        # (this is the output width of the decode1 block below)
        c_decode0 = (c_decode1 + c_decode2 // 2) // 2

        # build layers
        self.decode4 = UnetBlock(c_decode5, c_decode4, pixel_shuffle)
        self.decode3 = UnetBlock(c_decode4, c_decode3, pixel_shuffle)
        self.decode2 = UnetBlock(c_decode3, c_decode2, pixel_shuffle)
        self.decode1 = UnetBlock(c_decode2, c_decode1, pixel_shuffle)
        self.decode0 = (
            PixelShuffle_ICNR(c_decode0, c_decode0)
            if pixel_shuffle
            else convtrans_with_kaiming_uniform(c_decode0, c_decode0, 2, 2)
        )
        self.final = conv_with_kaiming_uniform(c_decode0, 1, 1)

    def forward(self, x):
        """
        Arguments:
            x (list[Tensor]): tensor as returned from the backbone network.
                First element: height and width of input image.
                Remaining elements: feature maps for each feature level.
        """
        # NOTE: x[0]: height and width of input image not needed in U-Net architecture
        out = self.decode4(x[5], x[4])
        for block, level in ((self.decode3, 3), (self.decode2, 2), (self.decode1, 1)):
            out = block(out, x[level])
        out = self.decode0(out)
        return self.final(out)
def build_res50unet():
    """Assemble the ResNet-50 based U-Net.

    Returns:
        An ``nn.Sequential`` with a "backbone" entry (resnet50 with
        ``pretrained=False`` returning the features at indices 2, 4, 5, 6, 7)
        followed by a "head" entry (ResUNet). Its ``name`` attribute is
        "ResUNet".
    """
    encoder = resnet50(pretrained=False, return_features = [2, 4, 5, 6, 7])
    decoder = ResUNet([64, 256, 512, 1024, 2048],pixel_shuffle=False)

    stages = OrderedDict()
    stages["backbone"] = encoder
    stages["head"] = decoder

    model = nn.Sequential(stages)
    model.name = "ResUNet"
    return model
\ No newline at end of file
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch
from collections import OrderedDict
from bob.ip.binseg.modeling.make_layers import conv_with_kaiming_uniform, convtrans_with_kaiming_uniform, PixelShuffle_ICNR, UnetBlock
from bob.ip.binseg.modeling.backbones.vgg import vgg16
class UNet(nn.Module):
    """
    UNet head module

    Attributes
    ----------
    in_channels_list (list[int]): number of channels for each feature map that is returned from backbone
    pixel_shuffle (bool): use PixelShuffle_ICNR instead of transposed convolutions for upsampling
    """
    def __init__(self, in_channels_list=None, pixel_shuffle=False):
        super(UNet, self).__init__()
        # number of channels, shallowest level first
        c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list

        # build layers; only the deepest block is a "middle block"
        self.decode4 = UnetBlock(c_decode5, c_decode4, pixel_shuffle, middle_block=True)
        self.decode3 = UnetBlock(c_decode4, c_decode3, pixel_shuffle)
        self.decode2 = UnetBlock(c_decode3, c_decode2, pixel_shuffle)
        self.decode1 = UnetBlock(c_decode2, c_decode1, pixel_shuffle)
        self.final = conv_with_kaiming_uniform(c_decode1, 1, 1)

    def forward(self, x):
        """
        Arguments:
            x (list[Tensor]): tensor as returned from the backbone network.
                First element: height and width of input image.
                Remaining elements: feature maps for each feature level.
        """
        # NOTE: x[0]: height and width of input image not needed in U-Net architecture
        out = self.decode4(x[5], x[4])
        for block, level in ((self.decode3, 3), (self.decode2, 2), (self.decode1, 1)):
            out = block(out, x[level])
        return self.final(out)
def build_unet():
    """Assemble the VGG-16 based U-Net.

    Returns:
        An ``nn.Sequential`` with a "backbone" entry (vgg16 with
        ``pretrained=False`` returning the features at indices 3, 8, 14, 22,
        29) followed by a "head" entry (UNet). Its ``name`` attribute is
        "UNet".
    """
    encoder = vgg16(pretrained=False, return_features = [3, 8, 14, 22, 29])
    decoder = UNet([64, 128, 256, 512, 512], pixel_shuffle=False)

    stages = OrderedDict()
    stages["backbone"] = encoder
    stages["head"] = decoder

    model = nn.Sequential(stages)
    model.name = "UNet"
    return model
\ No newline at end of file
......@@ -34,6 +34,7 @@ modelurls = {
"resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
"resnet50_SIN_IN": "https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar",
"mobilenetv2": "https://dl.dropboxusercontent.com/s/4nie4ygivq04p8y/mobilenet_v2.pth.tar",
}
def _download_url_to_file(url, dst, hash_prefix, progress):
......
......@@ -27,6 +27,7 @@ requirements:
- setuptools {{ setuptools }}
- torchvision {{ torchvision }}
- pytorch {{ pytorch }}
- cudatoolkit=8.0
- numpy {{ numpy }}
- bob.extension
# place your other host dependencies here
......
......@@ -55,8 +55,13 @@ setup(
'bob.ip.binseg.config': [
'DRIU = bob.ip.binseg.configs.models.driu',
'HED = bob.ip.binseg.configs.models.hed',
'M2UNet = bob.ip.binseg.configs.models.m2unet',
'UNet = bob.ip.binseg.configs.models.unet',
'ResUNet = bob.ip.binseg.configs.models.resunet',
'DRIUADABOUND = bob.ip.binseg.configs.models.driuadabound',
'DRIVETRAIN = bob.ip.binseg.configs.datasets.drivetrain',
'DRIVECROPTRAIN = bob.ip.binseg.configs.datasets.drivecroptrain',
'DRIVECROPTEST = bob.ip.binseg.configs.datasets.drivecroptest',
'DRIVETEST = bob.ip.binseg.configs.datasets.drivetest',
]
},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment