diff --git a/bob/ip/binseg/configs/datasets/drivecroptest.py b/bob/ip/binseg/configs/datasets/drivecroptest.py
new file mode 100644
index 0000000000000000000000000000000000000000..230598dce92a39276e05dd4b4f842643428546b4
--- /dev/null
+++ b/bob/ip/binseg/configs/datasets/drivecroptest.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from bob.db.drive import Database as DRIVE
+from bob.ip.binseg.data.transforms import *
+from bob.ip.binseg.data.binsegdataset import BinSegDataset
+
+#### Config ####
+
+transforms = Compose([  
+                        CenterCrop((544,544))
+                        ,ToTensor()
+                    ])
+
+# bob.db.dataset init
+bobdb = DRIVE(protocol = 'default')
+
+# PyTorch dataset
+dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
diff --git a/bob/ip/binseg/configs/datasets/drivecroptrain.py b/bob/ip/binseg/configs/datasets/drivecroptrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b6fa356b6f944f6e7e29e1f44b3dda3ec80d9a8
--- /dev/null
+++ b/bob/ip/binseg/configs/datasets/drivecroptrain.py
@@ -0,0 +1,23 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from bob.db.drive import Database as DRIVE
+from bob.ip.binseg.data.transforms import *
+from bob.ip.binseg.data.binsegdataset import BinSegDataset
+
+#### Config ####
+
+transforms = Compose([  
+                        CenterCrop((544,544))
+                        ,RandomHFlip()
+                        ,RandomVFlip()
+                        ,RandomRotation()
+                        ,ColorJitter()
+                        ,ToTensor()
+                    ])
+
+# bob.db.dataset init
+bobdb = DRIVE(protocol = 'default')
+
+# PyTorch dataset
+dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
diff --git a/bob/ip/binseg/configs/models/driu.py b/bob/ip/binseg/configs/models/driu.py
index c801206b9c4538fa4ce08d75d7451ed4a73556db..37911ed7837dd05f8e61906878782f07ad6c325c 100644
--- a/bob/ip/binseg/configs/models/driu.py
+++ b/bob/ip/binseg/configs/models/driu.py
@@ -4,15 +4,20 @@
 from torch.optim.lr_scheduler import MultiStepLR
 from bob.ip.binseg.modeling.driu import build_driu
 import torch.optim as optim
-from bob.ip.binseg.modeling.losses import WeightedBCELogitsLoss
+from torch.nn import BCEWithLogitsLoss
 from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.modeling.losses import WeightedBCELogitsLoss
+from bob.ip.binseg.engine.adabound import AdaBound
 
 ##### Config #####
 lr = 0.001
 betas = (0.9, 0.999)
 eps = 1e-08
 weight_decay = 0
-amsgrad = False
+final_lr = 0.1
+gamma = 1e-3
+eps = 1e-8
+amsbound = False
 
 scheduler_milestones = [150]
 scheduler_gamma = 0.1
@@ -24,7 +29,8 @@ model = build_driu()
 pretrained_backbone = modelurls['vgg16']
 
 # optimizer
-optimizer = optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
+optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
+                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
     
 # criterion
 criterion = WeightedBCELogitsLoss(reduction='mean')
diff --git a/bob/ip/binseg/configs/models/driuadam.py b/bob/ip/binseg/configs/models/driuadam.py
new file mode 100644
index 0000000000000000000000000000000000000000..c801206b9c4538fa4ce08d75d7451ed4a73556db
--- /dev/null
+++ b/bob/ip/binseg/configs/models/driuadam.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from torch.optim.lr_scheduler import MultiStepLR
+from bob.ip.binseg.modeling.driu import build_driu
+import torch.optim as optim
+from bob.ip.binseg.modeling.losses import WeightedBCELogitsLoss
+from bob.ip.binseg.utils.model_zoo import modelurls
+
+##### Config #####
+lr = 0.001
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+amsgrad = False
+
+scheduler_milestones = [150]
+scheduler_gamma = 0.1
+
+# model
+model = build_driu()
+
+# pretrained backbone
+pretrained_backbone = modelurls['vgg16']
+
+# optimizer
+optimizer = optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
+    
+# criterion
+criterion = WeightedBCELogitsLoss(reduction='mean')
+
+# scheduler
+scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
diff --git a/bob/ip/binseg/configs/models/hed.py b/bob/ip/binseg/configs/models/hed.py
index 0677ef59f39ac1c6fe67b69151fdd39e1fd974fc..ed79474505d02f453e789c0263d0cde2cebae396 100644
--- a/bob/ip/binseg/configs/models/hed.py
+++ b/bob/ip/binseg/configs/models/hed.py
@@ -6,17 +6,23 @@ from bob.ip.binseg.modeling.hed import build_hed
 import torch.optim as optim
 from bob.ip.binseg.modeling.losses import HEDWeightedBCELogitsLoss
 from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.engine.adabound import AdaBound
+
 
 ##### Config #####
 lr = 0.001
 betas = (0.9, 0.999)
 eps = 1e-08
 weight_decay = 0
-amsgrad = False
+final_lr = 0.1
+gamma = 1e-3
+eps = 1e-8
+amsbound = False
 
 scheduler_milestones = [150]
 scheduler_gamma = 0.1
 
+
 # model
 model = build_hed()
 
@@ -24,8 +30,8 @@ model = build_hed()
 pretrained_backbone = modelurls['vgg16']
 
 # optimizer
-optimizer = optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
-    
+optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
+                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
 # criterion
 criterion = HEDWeightedBCELogitsLoss(reduction='mean')
 
diff --git a/bob/ip/binseg/configs/models/m2unet.py b/bob/ip/binseg/configs/models/m2unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..471ce372bf93b8b7302af0370731e75ccea18c14
--- /dev/null
+++ b/bob/ip/binseg/configs/models/m2unet.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from torch.optim.lr_scheduler import MultiStepLR
+from bob.ip.binseg.modeling.m2u import build_m2unet
+import torch.optim as optim
+from torch.nn import BCEWithLogitsLoss
+from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.modeling.losses import WeightedBCELogitsLoss
+from bob.ip.binseg.engine.adabound import AdaBound
+
+##### Config #####
+lr = 0.001
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+eps = 1e-8
+amsbound = False
+
+scheduler_milestones = [150]
+scheduler_gamma = 0.1
+
+# model
+model = build_m2unet()
+
+# pretrained backbone
+pretrained_backbone = modelurls['mobilenetv2']
+
+# optimizer
+optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
+                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
+    
+# criterion
+criterion = WeightedBCELogitsLoss(reduction='mean')
+
+# scheduler
+scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
diff --git a/bob/ip/binseg/configs/models/resunet.py b/bob/ip/binseg/configs/models/resunet.py
new file mode 100644
index 0000000000000000000000000000000000000000..0725443096d9970a5c8b835bcf051214cec99ecc
--- /dev/null
+++ b/bob/ip/binseg/configs/models/resunet.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from torch.optim.lr_scheduler import MultiStepLR
+from bob.ip.binseg.modeling.resunet import build_res50unet
+import torch.optim as optim
+from torch.nn import BCEWithLogitsLoss
+from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.modeling.losses import WeightedBCELogitsLoss
+from bob.ip.binseg.engine.adabound import AdaBound
+
+##### Config #####
+lr = 0.001
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+eps = 1e-8
+amsbound = False
+
+scheduler_milestones = [150]
+scheduler_gamma = 0.1
+
+# model
+model = build_res50unet()
+
+# pretrained backbone
+pretrained_backbone = modelurls['resnet50']
+
+# optimizer
+optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
+                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
+    
+# criterion
+criterion = WeightedBCELogitsLoss(reduction='mean')
+
+# scheduler
+scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
diff --git a/bob/ip/binseg/configs/models/driuadabound.py b/bob/ip/binseg/configs/models/unet.py
similarity index 91%
rename from bob/ip/binseg/configs/models/driuadabound.py
rename to bob/ip/binseg/configs/models/unet.py
index b17e1fac1cb8cf524e5fc54fd92154ec02aa083a..ccc0f2c554b72eaf4f16b0a137998fbb0467c6c1 100644
--- a/bob/ip/binseg/configs/models/driuadabound.py
+++ b/bob/ip/binseg/configs/models/unet.py
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 
 from torch.optim.lr_scheduler import MultiStepLR
-from bob.ip.binseg.modeling.driu import build_driu
+from bob.ip.binseg.modeling.unet import build_unet
 import torch.optim as optim
 from torch.nn import BCEWithLogitsLoss
 from bob.ip.binseg.utils.model_zoo import modelurls
@@ -23,13 +23,12 @@ scheduler_milestones = [150]
 scheduler_gamma = 0.1
 
 # model
-model = build_driu()
+model = build_unet()
 
 # pretrained backbone
 pretrained_backbone = modelurls['vgg16']
 
 # optimizer
-# TODO: Add Adabound
 optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
                  eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
     
diff --git a/bob/ip/binseg/modeling/backbones/mobilenetv2.py b/bob/ip/binseg/modeling/backbones/mobilenetv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..d821430560ebc3ec319a5cb2104b77b7e6fef953
--- /dev/null
+++ b/bob/ip/binseg/modeling/backbones/mobilenetv2.py
@@ -0,0 +1,140 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+
+# Adapted from https://github.com/tonylins/pytorch-mobilenet-v2/ by @tonylins 
+# Ji Lin under Apache License 2.0
+
+import torch.nn as nn
+import math
+
+
+def conv_bn(inp, oup, stride):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+        nn.BatchNorm2d(oup),
+        nn.ReLU6(inplace=True)
+    )
+
+
+def conv_1x1_bn(inp, oup):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+        nn.BatchNorm2d(oup),
+        nn.ReLU6(inplace=True)
+    )
+
+
+class InvertedResidual(nn.Module):
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super(InvertedResidual, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+
+        hidden_dim = round(inp * expand_ratio)
+        self.use_res_connect = self.stride == 1 and inp == oup
+
+        if expand_ratio == 1:
+            self.conv = nn.Sequential(
+                # dw
+                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                nn.ReLU6(inplace=True),
+                # pw-linear
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(oup),
+            )
+        else:
+            self.conv = nn.Sequential(
+                # pw
+                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                nn.ReLU6(inplace=True),
+                # dw
+                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                nn.ReLU6(inplace=True),
+                # pw-linear
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(oup),
+            )
+
+    def forward(self, x):
+        if self.use_res_connect:
+            return x + self.conv(x)
+        else:
+            return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+    def __init__(self, n_class=1000, input_size=224, width_mult=1., return_features = None, m2u=True):
+        super(MobileNetV2, self).__init__()
+        self.return_features = return_features 
+        self.m2u = m2u 
+        block = InvertedResidual
+        input_channel = 32
+        last_channel = 1280
+        interverted_residual_setting = [
+            # t, c, n, s
+            [1, 16, 1, 1],
+            [6, 24, 2, 2],
+            [6, 32, 3, 2],
+            [6, 64, 4, 2],
+            [6, 96, 3, 1],
+            [6, 160, 3, 2],
+            [6, 320, 1, 1],
+        ]
+
+        # building first layer
+        assert input_size % 32 == 0
+        input_channel = int(input_channel * width_mult)
+        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
+        self.features = [conv_bn(3, input_channel, 2)]
+        # building inverted residual blocks
+        for t, c, n, s in interverted_residual_setting:
+            output_channel = int(c * width_mult)
+            for i in range(n):
+                if i == 0:
+                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
+                else:
+                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
+                input_channel = output_channel
+        # building last several layers
+        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
+        # make it nn.Sequential
+        self.features = nn.Sequential(*self.features)
+
+        # building classifier
+        self.classifier = nn.Sequential(
+            nn.Dropout(0.2),
+            nn.Linear(self.last_channel, n_class),
+        )
+
+        self._initialize_weights()
+
+    def forward(self, x):
+        outputs = []
+        # hw of input, needed for DRIU and HED
+        outputs.append(x.shape[2:4])
+        if self.m2u:
+            outputs.append(x)
+        for index,m in enumerate(self.features):
+            x = m(x)
+            # extract layers
+            if index in self.return_features:
+                outputs.append(x)
+        return outputs
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                n = m.weight.size(1)
+                m.weight.data.normal_(0, 0.01)
+                m.bias.data.zero_()
\ No newline at end of file
diff --git a/bob/ip/binseg/modeling/backbones/resnet.py b/bob/ip/binseg/modeling/backbones/resnet.py
index 008811eb0710b71f9289e665023529b81247f52b..5881652e571c94cd0aff20082a90986feffa96db 100644
--- a/bob/ip/binseg/modeling/backbones/resnet.py
+++ b/bob/ip/binseg/modeling/backbones/resnet.py
@@ -1,7 +1,5 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
 # Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 
+# resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN : https://github.com/rgeirhos/texture-vs-shape
 
 import torch.nn as nn
 import torch.utils.model_zoo as model_zoo
@@ -17,7 +15,7 @@ model_urls = {
     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
-    'resnet50__SIN_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar',
+    'resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar',
 }
 
 
@@ -115,18 +113,19 @@ class ResNet(nn.Module):
         super(ResNet, self).__init__()
         self.inplanes = 64
         self.return_features = return_features
-        conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                bias=False)
-        bn1 = nn.BatchNorm2d(64)
-        relu = nn.ReLU(inplace=True)
-        maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        layer1 = self._make_layer(block, 64, layers[0])
-        layer2 = self._make_layer(block, 128, layers[1], stride=2)
-        layer3 = self._make_layer(block, 256, layers[2], stride=2)
-        layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
         
-        self.features = nn.ModuleList([conv1, bn1, relu, maxpool
-                                    ,layer1,layer2,layer3,layer4])
+        self.features = [self.conv1, self.bn1, self.relu, self.maxpool
+                                    ,self.layer1,self.layer2,self.layer3,self.layer4]
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -205,6 +204,16 @@ def resnet50(pretrained=False, **kwargs):
         model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
     return model
 
+def shaperesnet50(pretrained=False, **kwargs):
+    """Constructs a ResNet-50 model, pretrained on Stylized-ImageNe and ImageNet and fine-tuned on ImageNet.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN']))
+    return model
+
 def resnet101(pretrained=False, **kwargs):
     """Constructs a ResNet-101 model.
     Args:
diff --git a/bob/ip/binseg/modeling/m2u.py b/bob/ip/binseg/modeling/m2u.py
new file mode 100644
index 0000000000000000000000000000000000000000..e393625220a2ed1cb7e6f14f9867f2f9263bfed6
--- /dev/null
+++ b/bob/ip/binseg/modeling/m2u.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+
+# https://github.com/laibe/M2U-Net
+
+from collections import OrderedDict
+import torch
+from torch import nn
+from bob.ip.binseg.modeling.backbones.mobilenetv2 import MobileNetV2, InvertedResidual
+
+class DecoderBlock(nn.Module):
+    """
+    Decoder block: upsample and concatenate with feature maps from the encoder part
+    """
+    def __init__(self,up_in_c,x_in_c,upsamplemode='bilinear',expand_ratio=0.15):
+        super().__init__()
+        self.upsample = nn.Upsample(scale_factor=2,mode=upsamplemode,align_corners=False) # H, W -> 2H, 2W
+        self.ir1 = InvertedResidual(up_in_c+x_in_c,(x_in_c + up_in_c) // 2,stride=1,expand_ratio=expand_ratio)
+
+    def forward(self,up_in,x_in):
+        up_out = self.upsample(up_in)
+        cat_x = torch.cat([up_out, x_in] , dim=1)
+        x = self.ir1(cat_x)
+        return x
+    
+class LastDecoderBlock(nn.Module):
+    def __init__(self,x_in_c,upsamplemode='bilinear',expand_ratio=0.15):
+        super().__init__()
+        self.upsample = nn.Upsample(scale_factor=2,mode=upsamplemode,align_corners=False) # H, W -> 2H, 2W
+        self.ir1 = InvertedResidual(x_in_c,1,stride=1,expand_ratio=expand_ratio)
+
+    def forward(self,up_in,x_in):
+        up_out = self.upsample(up_in)
+        cat_x = torch.cat([up_out, x_in] , dim=1)
+        x = self.ir1(cat_x)
+        return x
+
+
+
+class M2U(nn.Module):
+    """
+    M2U-Net head module
+    Attributes
+    ----------
+        in_channels_list (list[int]): number of channels for each feature map that is returned from backbone
+    """
+    def __init__(self, in_channels_list=None,upsamplemode='bilinear',expand_ratio=0.15):
+        super(M2U, self).__init__()
+
+        # Decoder
+        self.decode4 = DecoderBlock(96,32,upsamplemode,expand_ratio)
+        self.decode3 = DecoderBlock(64,24,upsamplemode,expand_ratio)
+        self.decode2 = DecoderBlock(44,16,upsamplemode,expand_ratio)
+        self.decode1 = LastDecoderBlock(33,upsamplemode,expand_ratio)
+        
+        # initialize weights 
+        self._initialize_weights()
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_uniform_(m.weight, a=1)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+    
+    def forward(self,x):
+        decode4 = self.decode4(x[5],x[4])    # 96, 32
+        decode3 = self.decode3(decode4,x[3]) # 64, 24
+        decode2 = self.decode2(decode3,x[2]) # 44, 16
+        decode1 = self.decode1(decode2,x[1]) # 30, 3
+        
+        return decode1
+
+def build_m2unet():
+    backbone = MobileNetV2(return_features = [1,3,6,13], m2u=True)
+    m2u_head = M2U(in_channels_list=[16, 24, 32, 96])
+
+    model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", m2u_head)]))
+    model.name = "M2UNet"
+    return model
\ No newline at end of file
diff --git a/bob/ip/binseg/modeling/make_layers.py b/bob/ip/binseg/modeling/make_layers.py
index 14e3b47f365f38156a348e21e905ae7d08400a0c..fbe40fd3a0eb3d2ae024cc51f46848f587133d65 100644
--- a/bob/ip/binseg/modeling/make_layers.py
+++ b/bob/ip/binseg/modeling/make_layers.py
@@ -88,4 +88,73 @@ class UpsampleCropBlock(nn.Module):
         # needs explicit ranges for onnx export 
         x = x[:,:,h_s:h_e,w_s:w_e] # crop to input size 
         
+        return x
+
+
+
+def ifnone(a, b):
+    "`a` if `a` is not None, otherwise `b`."
+    return b if a is None else a
+
+def icnr(x, scale=2, init=nn.init.kaiming_normal_):
+    """
+    https://docs.fast.ai/layers.html#PixelShuffle_ICNR
+    ICNR init of `x`, with `scale` and `init` function.
+    """
+    ni,nf,h,w = x.shape
+    ni2 = int(ni/(scale**2))
+    k = init(torch.zeros([ni2,nf,h,w])).transpose(0, 1)
+    k = k.contiguous().view(ni2, nf, -1)
+    k = k.repeat(1, 1, scale**2)
+    k = k.contiguous().view([nf,ni,h,w]).transpose(0, 1)
+    x.data.copy_(k)
+
+class PixelShuffle_ICNR(nn.Module):
+    """
+    https://docs.fast.ai/layers.html#PixelShuffle_ICNR 
+    Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`.
+    """
+    def __init__(self, ni:int, nf:int=None, scale:int=2):
+        super().__init__()
+        nf = ifnone(nf, ni)
+        self.conv = conv_with_kaiming_uniform(ni, nf*(scale**2), 1)
+        icnr(self.conv.weight)
+        self.shuf = nn.PixelShuffle(scale)
+        # Blurring over (h*w) kernel
+        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
+        # - https://arxiv.org/abs/1806.02658
+        self.pad = nn.ReplicationPad2d((1,0,1,0))
+        self.blur = nn.AvgPool2d(2, stride=1)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self,x):
+        x = self.shuf(self.relu(self.conv(x)))
+        x = self.blur(self.pad(x))
+        return x
+
+class UnetBlock(nn.Module):
+    def __init__(self, up_in_c, x_in_c, pixel_shuffle=False, middle_block=False):
+        super().__init__()
+
+        # middle block for VGG based U-Net
+        if middle_block:
+            up_out_c =  up_in_c
+        else:
+            up_out_c =  up_in_c // 2
+        cat_channels = x_in_c + up_out_c
+        inner_channels = cat_channels // 2
+        
+        if pixel_shuffle:
+            self.upsample = PixelShuffle_ICNR( up_in_c, up_out_c )
+        else:
+            self.upsample = convtrans_with_kaiming_uniform( up_in_c, up_out_c, 2, 2)
+        self.convtrans1 = convtrans_with_kaiming_uniform( cat_channels, inner_channels, 3, 1, 1)
+        self.convtrans2 = convtrans_with_kaiming_uniform( inner_channels, inner_channels, 3, 1, 1)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, up_in, x_in):
+        up_out = self.upsample(up_in)
+        cat_x = torch.cat([up_out, x_in] , dim=1)
+        x = self.relu(self.convtrans1(cat_x))
+        x = self.relu(self.convtrans2(x))
         return x
\ No newline at end of file
diff --git a/bob/ip/binseg/modeling/resunet.py b/bob/ip/binseg/modeling/resunet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ee949617b1d4ea1a58d60df375b665522999677
--- /dev/null
+++ b/bob/ip/binseg/modeling/resunet.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import torch.nn as nn
+import torch
+from collections import OrderedDict
+from bob.ip.binseg.modeling.make_layers  import conv_with_kaiming_uniform, convtrans_with_kaiming_uniform, PixelShuffle_ICNR, UnetBlock
+from bob.ip.binseg.modeling.backbones.resnet import resnet50
+
+
+
+class ResUNet(nn.Module):
+    """
+    UNet head module for ResNet backbones
+    Attributes
+    ----------
+        in_channels_list (list[int]): number of channels for each feature map that is returned from backbone
+    """
+    def __init__(self, in_channels_list=None, pixel_shuffle=False):
+        super(ResUNet, self).__init__()
+        # number of channels
+        c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list
+        # number of channels for last upsampling operation
+        c_decode0 = (c_decode1 + c_decode2//2)//2
+
+        # build layers
+        self.decode4 = UnetBlock(c_decode5, c_decode4, pixel_shuffle)
+        self.decode3 = UnetBlock(c_decode4, c_decode3, pixel_shuffle)
+        self.decode2 = UnetBlock(c_decode3, c_decode2, pixel_shuffle)
+        self.decode1 = UnetBlock(c_decode2, c_decode1, pixel_shuffle)
+        if pixel_shuffle:
+            self.decode0 = PixelShuffle_ICNR(c_decode0, c_decode0)
+        else:
+            self.decode0 = convtrans_with_kaiming_uniform(c_decode0, c_decode0, 2, 2)
+        self.final = conv_with_kaiming_uniform(c_decode0, 1, 1)
+
+    def forward(self,x):
+        """
+        Arguments:
+            x (list[Tensor]): tensor as returned from the backbone network.
+                First element: height and width of input image. 
+                Remaining elements: feature maps for each feature level.
+        """
+        # NOTE: x[0]: height and width of input image not needed in U-Net architecture
+        decode4 = self.decode4(x[5], x[4])  
+        decode3 = self.decode3(decode4, x[3]) 
+        decode2 = self.decode2(decode3, x[2]) 
+        decode1 = self.decode1(decode2, x[1]) 
+        decode0 = self.decode0(decode1)
+        out = self.final(decode0)
+        return out
+
+def build_res50unet():
+    backbone = resnet50(pretrained=False, return_features = [2, 4, 5, 6, 7])
+    unet_head  = ResUNet([64, 256, 512, 1024, 2048],pixel_shuffle=False)
+    model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", unet_head)]))
+    model.name = "ResUNet"
+    return model
\ No newline at end of file
diff --git a/bob/ip/binseg/modeling/unet.py b/bob/ip/binseg/modeling/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9be2b498aa37c610686aa2e1daaea3be34c27d1e
--- /dev/null
+++ b/bob/ip/binseg/modeling/unet.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import torch.nn as nn
+import torch
+from collections import OrderedDict
+from bob.ip.binseg.modeling.make_layers  import conv_with_kaiming_uniform, convtrans_with_kaiming_uniform, PixelShuffle_ICNR, UnetBlock
+from bob.ip.binseg.modeling.backbones.vgg import vgg16
+
+
+
+class UNet(nn.Module):
+    """
+    UNet head module
+    Attributes
+    ----------
+        in_channels_list (list[int]): number of channels for each feature map that is returned from backbone
+    """
+    def __init__(self, in_channels_list=None, pixel_shuffle=False):
+        super(UNet, self).__init__()
+        # number of channels
+        c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list
+        
+        # build layers
+        self.decode4 = UnetBlock(c_decode5, c_decode4, pixel_shuffle, middle_block=True)
+        self.decode3 = UnetBlock(c_decode4, c_decode3, pixel_shuffle)
+        self.decode2 = UnetBlock(c_decode3, c_decode2, pixel_shuffle)
+        self.decode1 = UnetBlock(c_decode2, c_decode1, pixel_shuffle)
+        self.final = conv_with_kaiming_uniform(c_decode1, 1, 1)
+
+    def forward(self,x):
+        """
+        Arguments:
+            x (list[Tensor]): tensor as returned from the backbone network.
+                First element: height and width of input image. 
+                Remaining elements: feature maps for each feature level.
+        """
+        # NOTE: x[0]: height and width of input image not needed in U-Net architecture
+        decode4 = self.decode4(x[5], x[4])  
+        decode3 = self.decode3(decode4, x[3]) 
+        decode2 = self.decode2(decode3, x[2]) 
+        decode1 = self.decode1(decode2, x[1]) 
+        out = self.final(decode1)
+        return out
+
+def build_unet():
+    backbone = vgg16(pretrained=False, return_features = [3, 8, 14, 22, 29])
+    unet_head = UNet([64, 128, 256, 512, 512], pixel_shuffle=False)
+
+    model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", unet_head)]))
+    model.name = "UNet"
+    return model
\ No newline at end of file
diff --git a/bob/ip/binseg/utils/model_zoo.py b/bob/ip/binseg/utils/model_zoo.py
index e8f94abf1c6dcccc1cd44ec965940bff4bcb70fb..8bc7c931e0e05d682e03cc52c22413c2f185965e 100644
--- a/bob/ip/binseg/utils/model_zoo.py
+++ b/bob/ip/binseg/utils/model_zoo.py
@@ -34,6 +34,7 @@ modelurls = {
     "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
     "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
     "resnet50_SIN_IN": "https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar",
+    "mobilenetv2": "https://dl.dropboxusercontent.com/s/4nie4ygivq04p8y/mobilenet_v2.pth.tar",
     }
 
 def _download_url_to_file(url, dst, hash_prefix, progress):
diff --git a/conda/meta.yaml b/conda/meta.yaml
index b85fcc209bee58e5baddff85524e69decb709df0..a942971b19d98526c014e1d15dacb36d8810e50f 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -27,6 +27,7 @@ requirements:
     - setuptools {{ setuptools }}
     - torchvision  {{ torchvision }}
     - pytorch {{ pytorch }}
+    - cudatoolkit=8.0
     - numpy {{ numpy }}
     - bob.extension
     # place your other host dependencies here
diff --git a/setup.py b/setup.py
index 2a2a92f569ea9f69880db2ac91d3703ac555d25d..b3423a5dd6a2931690d985eb0c3c5bbd8a1ff434 100644
--- a/setup.py
+++ b/setup.py
@@ -55,8 +55,13 @@ setup(
         'bob.ip.binseg.config': [
           'DRIU = bob.ip.binseg.configs.models.driu',
           'HED = bob.ip.binseg.configs.models.hed',
+          'M2UNet = bob.ip.binseg.configs.models.m2unet',
+          'UNet = bob.ip.binseg.configs.models.unet',
+          'ResUNet = bob.ip.binseg.configs.models.resunet',
           'DRIUADABOUND = bob.ip.binseg.configs.models.driuadabound',
           'DRIVETRAIN = bob.ip.binseg.configs.datasets.drivetrain',
+          'DRIVECROPTRAIN = bob.ip.binseg.configs.datasets.drivecroptrain',
+          'DRIVECROPTEST = bob.ip.binseg.configs.datasets.drivecroptest',
           'DRIVETEST = bob.ip.binseg.configs.datasets.drivetest',
           ]
     },