From 805e43a1f67d8fbe5d34efb1932986d4259cf706 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 18 May 2020 19:59:47 +0200
Subject: [PATCH] [models,configs/models] Clean up model implementation to
 re-use all backbones from torchvision; simplify checkpointer; remove own
 implementation of model_zoo; implement normalization for torchvision-based
 backbones; enable PyTorch installation on macOS

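Model backbones are now thin wrappers around the implementations
shipped with torchvision (VGG4Segmentation, ResNet4Segmentation and
MobileNetV24Segmentation).  The wrappers only re-implement forward() so
that intermediate feature maps, required by the segmentation heads, are
also returned.  Model constructors were renamed (e.g. build_driu() ->
driu()) and now accept "pretrained_backbone" and "progress" arguments,
replacing the per-config "pretrained_backbone = modelurls[...]"
entries.  When a pre-trained backbone is requested, a
TorchVisionNormalizer module is prepended to the returned
torch.nn.Sequential so that inputs are normalized the way torchvision
pre-trained models expect.  The Detectron-derived checkpointer was
simplified into a single Checkpointer class.

Example of the new model-construction API (illustrative only; it
mirrors what the updated configs in this patch do):

    from bob.ip.binseg.models.driu import driu

    # downloads ImageNet pre-trained VGG-16 weights for the backbone
    # and prepends a TorchVisionNormalizer to the returned sequence
    model = driu(pretrained_backbone=True, progress=True)
    assert model.name == "driu"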
---
 bob/ip/binseg/configs/models/driu.py          |   9 +-
 bob/ip/binseg/configs/models/driu_bn.py       |  11 +-
 bob/ip/binseg/configs/models/driu_bn_ssl.py   |  11 +-
 bob/ip/binseg/configs/models/driu_od.py       |  11 +-
 bob/ip/binseg/configs/models/driu_ssl.py      |  11 +-
 bob/ip/binseg/configs/models/hed.py           |  14 +-
 bob/ip/binseg/configs/models/m2unet.py        |  13 +-
 bob/ip/binseg/configs/models/m2unet_ssl.py    |  13 +-
 bob/ip/binseg/configs/models/resunet.py       |  13 +-
 bob/ip/binseg/configs/models/unet.py          |  13 +-
 bob/ip/binseg/engine/ssltrainer.py            |   2 +-
 bob/ip/binseg/engine/trainer.py               |   2 +-
 .../binseg/modeling/backbones/mobilenetv2.py  | 155 -----------
 bob/ip/binseg/modeling/backbones/resnet.py    | 244 ------------------
 bob/ip/binseg/modeling/backbones/vgg.py       | 221 ----------------
 .../binseg/{modeling => models}/__init__.py   |   0
 .../backbones/__init__.py                     |   0
 bob/ip/binseg/models/backbones/mobilenetv2.py |  64 +++++
 bob/ip/binseg/models/backbones/resnet.py      |  69 +++++
 bob/ip/binseg/models/backbones/vgg.py         |  83 ++++++
 bob/ip/binseg/{modeling => models}/driu.py    |  57 ++--
 .../{modeling/driubn.py => models/driu_bn.py} |  62 +++--
 .../{modeling/driuod.py => models/driu_od.py} |  64 ++---
 .../driupix.py => models/driu_pix.py}         |  62 +++--
 bob/ip/binseg/{modeling => models}/hed.py     |  52 +++-
 bob/ip/binseg/{modeling => models}/losses.py  |   0
 .../{modeling/m2u.py => models/m2unet.py}     |  58 ++++-
 .../{modeling => models}/make_layers.py       |   0
 bob/ip/binseg/models/normalizer.py            |  33 +++
 bob/ip/binseg/{modeling => models}/resunet.py |  72 ++++--
 bob/ip/binseg/{modeling => models}/unet.py    |  53 ++--
 bob/ip/binseg/script/experiment.py            |  11 -
 bob/ip/binseg/script/predict.py               |   9 +-
 bob/ip/binseg/script/train.py                 |  19 +-
 bob/ip/binseg/test/test_checkpointer.py       |  33 ++-
 bob/ip/binseg/test/test_cli.py                |   6 +-
 bob/ip/binseg/test/test_models.py             | 140 ++++++++++
 bob/ip/binseg/utils/checkpointer.py           | 158 +++++-------
 bob/ip/binseg/utils/model_serialization.py    |  41 +--
 bob/ip/binseg/utils/model_zoo.py              | 116 ---------
 conda/meta.yaml                               |   4 +-
 doc/api.rst                                   |  36 +--
 doc/extras.inv                                | Bin 513 -> 576 bytes
 doc/extras.txt                                |   3 +
 44 files changed, 895 insertions(+), 1153 deletions(-)
 delete mode 100644 bob/ip/binseg/modeling/backbones/mobilenetv2.py
 delete mode 100644 bob/ip/binseg/modeling/backbones/resnet.py
 delete mode 100644 bob/ip/binseg/modeling/backbones/vgg.py
 rename bob/ip/binseg/{modeling => models}/__init__.py (100%)
 rename bob/ip/binseg/{modeling => models}/backbones/__init__.py (100%)
 create mode 100644 bob/ip/binseg/models/backbones/mobilenetv2.py
 create mode 100644 bob/ip/binseg/models/backbones/resnet.py
 create mode 100644 bob/ip/binseg/models/backbones/vgg.py
 rename bob/ip/binseg/{modeling => models}/driu.py (60%)
 rename bob/ip/binseg/{modeling/driubn.py => models/driu_bn.py} (56%)
 rename bob/ip/binseg/{modeling/driuod.py => models/driu_od.py} (57%)
 rename bob/ip/binseg/{modeling/driupix.py => models/driu_pix.py} (60%)
 rename bob/ip/binseg/{modeling => models}/hed.py (66%)
 rename bob/ip/binseg/{modeling => models}/losses.py (100%)
 rename bob/ip/binseg/{modeling/m2u.py => models/m2unet.py} (66%)
 rename bob/ip/binseg/{modeling => models}/make_layers.py (100%)
 create mode 100755 bob/ip/binseg/models/normalizer.py
 rename bob/ip/binseg/{modeling => models}/resunet.py (50%)
 rename bob/ip/binseg/{modeling => models}/unet.py (55%)
 create mode 100755 bob/ip/binseg/test/test_models.py
 delete mode 100644 bob/ip/binseg/utils/model_zoo.py

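Note: the new file bob/ip/binseg/models/normalizer.py is listed above,
but its contents are not reproduced in this excerpt.  It provides the
TorchVisionNormalizer module that the model constructors below prepend
whenever a pre-trained (torchvision) backbone is requested.  A minimal
sketch of what such a module looks like, assuming the standard ImageNet
statistics used by torchvision pre-trained models (an assumption of
this note, not a quote from the patch; consult the actual file for the
exact implementation):

    import torch

    class TorchVisionNormalizer(torch.nn.Module):
        """Normalizes inputs as torchvision pre-trained backbones expect

        Sketch only: the ImageNet mean/std below are assumed, not
        quoted from this patch.
        """

        def __init__(self):
            super().__init__()
            # shape (1, C, 1, 1) so values broadcast over (N, C, H, W)
            mean = torch.as_tensor([0.485, 0.456, 0.406])[None, :, None, None]
            std = torch.as_tensor([0.229, 0.224, 0.225])[None, :, None, None]
            self.register_buffer("mean", mean)
            self.register_buffer("std", std)

        def forward(self, inputs):
            return inputs.sub(self.mean).div(self.std)
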
diff --git a/bob/ip/binseg/configs/models/driu.py b/bob/ip/binseg/configs/models/driu.py
index cdc9cb89..eda8f1e7 100644
--- a/bob/ip/binseg/configs/models/driu.py
+++ b/bob/ip/binseg/configs/models/driu.py
@@ -11,9 +11,8 @@ Reference: [MANINIS-2016]_
 """
 
 from torch.optim.lr_scheduler import MultiStepLR
-from bob.ip.binseg.modeling.driu import build_driu
-from bob.ip.binseg.utils.model_zoo import modelurls
-from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
+from bob.ip.binseg.models.driu import driu
+from bob.ip.binseg.models.losses import SoftJaccardBCELogitsLoss
 from bob.ip.binseg.engine.adabound import AdaBound
 
 ##### Config #####
@@ -29,9 +28,7 @@ amsbound = False
 scheduler_milestones = [900]
 scheduler_gamma = 0.1
 
-model = build_driu()
-
-pretrained_backbone = modelurls["vgg16"]
+model = driu()
 
 optimizer = AdaBound(
     model.parameters(),
diff --git a/bob/ip/binseg/configs/models/driu_bn.py b/bob/ip/binseg/configs/models/driu_bn.py
index 4e3a4b3c..8eb17242 100644
--- a/bob/ip/binseg/configs/models/driu_bn.py
+++ b/bob/ip/binseg/configs/models/driu_bn.py
@@ -12,9 +12,8 @@ Reference: [MANINIS-2016]_
 """
 
 from torch.optim.lr_scheduler import MultiStepLR
-from bob.ip.binseg.modeling.driubn import build_driu
-from bob.ip.binseg.utils.model_zoo import modelurls
-from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
+from bob.ip.binseg.models.driu_bn import driu_bn
+from bob.ip.binseg.models.losses import SoftJaccardBCELogitsLoss
 from bob.ip.binseg.engine.adabound import AdaBound
 
 ##### Config #####
@@ -30,11 +29,7 @@ amsbound = False
 scheduler_milestones = [900]
 scheduler_gamma = 0.1
 
-# model
-model = build_driu()
-
-# pretrained backbone
-pretrained_backbone = modelurls["vgg16_bn"]
+model = driu_bn()
 
 # optimizer
 optimizer = AdaBound(
diff --git a/bob/ip/binseg/configs/models/driu_bn_ssl.py b/bob/ip/binseg/configs/models/driu_bn_ssl.py
index a73d4ebe..273529bd 100644
--- a/bob/ip/binseg/configs/models/driu_bn_ssl.py
+++ b/bob/ip/binseg/configs/models/driu_bn_ssl.py
@@ -13,9 +13,8 @@ Reference: [MANINIS-2016]_
 """
 
 from torch.optim.lr_scheduler import MultiStepLR
-from bob.ip.binseg.modeling.driubn import build_driu
-from bob.ip.binseg.utils.model_zoo import modelurls
-from bob.ip.binseg.modeling.losses import MixJacLoss
+from bob.ip.binseg.models.driu_bn import driu_bn
+from bob.ip.binseg.models.losses import MixJacLoss
 from bob.ip.binseg.engine.adabound import AdaBound
 
 ##### Config #####
@@ -31,11 +30,7 @@ amsbound = False
 scheduler_milestones = [900]
 scheduler_gamma = 0.1
 
-# model
-model = build_driu()
-
-# pretrained backbone
-pretrained_backbone = modelurls["vgg16_bn"]
+model = driu_bn()
 
 # optimizer
 optimizer = AdaBound(
diff --git a/bob/ip/binseg/configs/models/driu_od.py b/bob/ip/binseg/configs/models/driu_od.py
index 9535c89a..6cf27f8a 100644
--- a/bob/ip/binseg/configs/models/driu_od.py
+++ b/bob/ip/binseg/configs/models/driu_od.py
@@ -11,9 +11,8 @@ Reference: [MANINIS-2016]_
 """
 
 from torch.optim.lr_scheduler import MultiStepLR
-from bob.ip.binseg.modeling.driuod import build_driuod
-from bob.ip.binseg.utils.model_zoo import modelurls
-from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
+from bob.ip.binseg.models.driu_od import driu_od
+from bob.ip.binseg.models.losses import SoftJaccardBCELogitsLoss
 from bob.ip.binseg.engine.adabound import AdaBound
 
 ##### Config #####
@@ -29,11 +28,7 @@ amsbound = False
 scheduler_milestones = [900]
 scheduler_gamma = 0.1
 
-# model
-model = build_driuod()
-
-# pretrained backbone
-pretrained_backbone = modelurls["vgg16"]
+model = driu_od()
 
 # optimizer
 optimizer = AdaBound(
diff --git a/bob/ip/binseg/configs/models/driu_ssl.py b/bob/ip/binseg/configs/models/driu_ssl.py
index 45194f6d..d5023894 100644
--- a/bob/ip/binseg/configs/models/driu_ssl.py
+++ b/bob/ip/binseg/configs/models/driu_ssl.py
@@ -12,9 +12,8 @@ Reference: [MANINIS-2016]_
 """
 
 from torch.optim.lr_scheduler import MultiStepLR
-from bob.ip.binseg.modeling.driu import build_driu
-from bob.ip.binseg.utils.model_zoo import modelurls
-from bob.ip.binseg.modeling.losses import MixJacLoss
+from bob.ip.binseg.models.driu import driu
+from bob.ip.binseg.models.losses import MixJacLoss
 from bob.ip.binseg.engine.adabound import AdaBound
 
 ##### Config #####
@@ -30,11 +29,7 @@ amsbound = False
 scheduler_milestones = [900]
 scheduler_gamma = 0.1
 
-# model
-model = build_driu()
-
-# pretrained backbone
-pretrained_backbone = modelurls["vgg16"]
+model = driu()
 
 # optimizer
 optimizer = AdaBound(
diff --git a/bob/ip/binseg/configs/models/hed.py b/bob/ip/binseg/configs/models/hed.py
index 6a9d7e82..7300dee5 100644
--- a/bob/ip/binseg/configs/models/hed.py
+++ b/bob/ip/binseg/configs/models/hed.py
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 
 
-"""HED Network for Vessel Segmentation
+"""HED Network for image segmentation
 
 Holistically-nested edge detection (HED), turns pixel-wise edge classification
 into image-to-image prediction by means of a deep learning model that leverages
@@ -13,9 +13,8 @@ Reference: [XIE-2015]_
 
 
 from torch.optim.lr_scheduler import MultiStepLR
-from bob.ip.binseg.modeling.hed import build_hed
-from bob.ip.binseg.modeling.losses import HEDSoftJaccardBCELogitsLoss
-from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.models.hed import hed
+from bob.ip.binseg.models.losses import HEDSoftJaccardBCELogitsLoss
 from bob.ip.binseg.engine.adabound import AdaBound
 
 
@@ -32,12 +31,7 @@ amsbound = False
 scheduler_milestones = [900]
 scheduler_gamma = 0.1
 
-
-# model
-model = build_hed()
-
-# pretrained backbone
-pretrained_backbone = modelurls["vgg16"]
+model = hed()
 
 # optimizer
 optimizer = AdaBound(
diff --git a/bob/ip/binseg/configs/models/m2unet.py b/bob/ip/binseg/configs/models/m2unet.py
index 2edc0372..025fcbb1 100644
--- a/bob/ip/binseg/configs/models/m2unet.py
+++ b/bob/ip/binseg/configs/models/m2unet.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-"""MobileNetV2 U-Net Model for Vessel Segmentation
+"""MobileNetV2 U-Net model for image segmentation
 
 The MobileNetV2 architecture is based on an inverted residual structure where
 the input and output of the residual block are thin bottleneck layers opposite
@@ -15,9 +15,8 @@ References: [SANDLER-2018]_, [RONNEBERGER-2015]_
 """
 
 from torch.optim.lr_scheduler import MultiStepLR
-from bob.ip.binseg.modeling.m2u import build_m2unet
-from bob.ip.binseg.utils.model_zoo import modelurls
-from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
+from bob.ip.binseg.models.m2unet import m2unet
+from bob.ip.binseg.models.losses import SoftJaccardBCELogitsLoss
 from bob.ip.binseg.engine.adabound import AdaBound
 
 ##### Config #####
@@ -33,11 +32,7 @@ amsbound = False
 scheduler_milestones = [900]
 scheduler_gamma = 0.1
 
-# model
-model = build_m2unet()
-
-# pretrained backbone
-pretrained_backbone = modelurls["mobilenetv2"]
+model = m2unet()
 
 # optimizer
 optimizer = AdaBound(
diff --git a/bob/ip/binseg/configs/models/m2unet_ssl.py b/bob/ip/binseg/configs/models/m2unet_ssl.py
index 9a456d86..8fea09ce 100644
--- a/bob/ip/binseg/configs/models/m2unet_ssl.py
+++ b/bob/ip/binseg/configs/models/m2unet_ssl.py
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 
 
-"""MobileNetV2 U-Net Model for Vessel Segmentation using SSL
+"""MobileNetV2 U-Net model for image segmentation using SSL
 
 The MobileNetV2 architecture is based on an inverted residual structure where
 the input and output of the residual block are thin bottleneck layers opposite
@@ -18,9 +18,8 @@ References: [SANDLER-2018]_, [RONNEBERGER-2015]_
 """
 
 from torch.optim.lr_scheduler import MultiStepLR
-from bob.ip.binseg.modeling.m2u import build_m2unet
-from bob.ip.binseg.utils.model_zoo import modelurls
-from bob.ip.binseg.modeling.losses import MixJacLoss
+from bob.ip.binseg.models.m2unet import m2unet
+from bob.ip.binseg.models.losses import MixJacLoss
 from bob.ip.binseg.engine.adabound import AdaBound
 
 ##### Config #####
@@ -36,11 +35,7 @@ amsbound = False
 scheduler_milestones = [900]
 scheduler_gamma = 0.1
 
-# model
-model = build_m2unet()
-
-# pretrained backbone
-pretrained_backbone = modelurls["mobilenetv2"]
+model = m2unet()
 
 # optimizer
 optimizer = AdaBound(
diff --git a/bob/ip/binseg/configs/models/resunet.py b/bob/ip/binseg/configs/models/resunet.py
index ff7e26e5..a5402130 100644
--- a/bob/ip/binseg/configs/models/resunet.py
+++ b/bob/ip/binseg/configs/models/resunet.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-"""Residual U-Net for Vessel Segmentation
+"""Residual U-Net for image segmentation
 
 A semantic segmentation neural network which combines the strengths of residual
 learning and U-Net is proposed for road area extraction.  The network is built
@@ -15,9 +15,8 @@ Reference: [ZHANG-2017]_
 """
 
 from torch.optim.lr_scheduler import MultiStepLR
-from bob.ip.binseg.modeling.resunet import build_res50unet
-from bob.ip.binseg.utils.model_zoo import modelurls
-from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
+from bob.ip.binseg.models.resunet import resunet50
+from bob.ip.binseg.models.losses import SoftJaccardBCELogitsLoss
 from bob.ip.binseg.engine.adabound import AdaBound
 
 ##### Config #####
@@ -33,11 +32,7 @@ amsbound = False
 scheduler_milestones = [900]
 scheduler_gamma = 0.1
 
-# model
-model = build_res50unet()
-
-# pretrained backbone
-pretrained_backbone = modelurls["resnet50"]
+model = resunet50()
 
 # optimizer
 optimizer = AdaBound(
diff --git a/bob/ip/binseg/configs/models/unet.py b/bob/ip/binseg/configs/models/unet.py
index ee1eddb7..9bcf2a58 100644
--- a/bob/ip/binseg/configs/models/unet.py
+++ b/bob/ip/binseg/configs/models/unet.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-"""U-Net for Vessel Segmentation
+"""U-Net for image segmentation
 
 U-Net is a convolutional neural network that was developed for biomedical image
 segmentation at the Computer Science Department of the University of Freiburg,
@@ -13,9 +13,8 @@ Reference: [RONNEBERGER-2015]_
 """
 
 from torch.optim.lr_scheduler import MultiStepLR
-from bob.ip.binseg.modeling.unet import build_unet
-from bob.ip.binseg.utils.model_zoo import modelurls
-from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
+from bob.ip.binseg.models.unet import unet
+from bob.ip.binseg.models.losses import SoftJaccardBCELogitsLoss
 from bob.ip.binseg.engine.adabound import AdaBound
 
 ##### Config #####
@@ -31,11 +30,7 @@ amsbound = False
 scheduler_milestones = [900]
 scheduler_gamma = 0.1
 
-# model
-model = build_unet()
-
-# pretrained backbone
-pretrained_backbone = modelurls["vgg16"]
+model = unet()
 
 # optimizer
 optimizer = AdaBound(
diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 4a8cca16..f642c1f5 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -207,7 +207,7 @@ def run(
     scheduler : :py:mod:`torch.optim`
         learning rate scheduler
 
-    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
+    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.Checkpointer`
         checkpointer implementation
 
     checkpoint_period : int
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index e0299cc1..1c587de7 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -93,7 +93,7 @@ def run(
     scheduler : :py:mod:`torch.optim`
         learning rate scheduler
 
-    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
+    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.Checkpointer`
         checkpointer implementation
 
     checkpoint_period : int
diff --git a/bob/ip/binseg/modeling/backbones/mobilenetv2.py b/bob/ip/binseg/modeling/backbones/mobilenetv2.py
deleted file mode 100644
index 9e6cd245..00000000
--- a/bob/ip/binseg/modeling/backbones/mobilenetv2.py
+++ /dev/null
@@ -1,155 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-
-# Adopted from https://github.com/tonylins/pytorch-mobilenet-v2/ by @tonylins
-# Ji Lin under Apache License 2.0
-
-import torch.nn
-import math
-
-
-def conv_bn(inp, oup, stride):
-    return torch.nn.Sequential(
-        torch.nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
-        torch.nn.BatchNorm2d(oup),
-        torch.nn.ReLU6(inplace=True),
-    )
-
-
-def conv_1x1_bn(inp, oup):
-    return torch.nn.Sequential(
-        torch.nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
-        torch.nn.BatchNorm2d(oup),
-        torch.nn.ReLU6(inplace=True),
-    )
-
-
-class InvertedResidual(torch.nn.Module):
-    def __init__(self, inp, oup, stride, expand_ratio):
-        super(InvertedResidual, self).__init__()
-        self.stride = stride
-        assert stride in [1, 2]
-
-        hidden_dim = round(inp * expand_ratio)
-        self.use_res_connect = self.stride == 1 and inp == oup
-
-        if expand_ratio == 1:
-            self.conv = torch.nn.Sequential(
-                # dw
-                torch.nn.Conv2d(
-                    hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False
-                ),
-                torch.nn.BatchNorm2d(hidden_dim),
-                torch.nn.ReLU6(inplace=True),
-                # pw-linear
-                torch.nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
-                torch.nn.BatchNorm2d(oup),
-            )
-        else:
-            self.conv = torch.nn.Sequential(
-                # pw
-                torch.nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
-                torch.nn.BatchNorm2d(hidden_dim),
-                torch.nn.ReLU6(inplace=True),
-                # dw
-                torch.nn.Conv2d(
-                    hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False
-                ),
-                torch.nn.BatchNorm2d(hidden_dim),
-                torch.nn.ReLU6(inplace=True),
-                # pw-linear
-                torch.nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
-                torch.nn.BatchNorm2d(oup),
-            )
-
-    def forward(self, x):
-        if self.use_res_connect:
-            return x + self.conv(x)
-        else:
-            return self.conv(x)
-
-
-class MobileNetV2(torch.nn.Module):
-    def __init__(
-        self,
-        n_class=1000,
-        input_size=224,
-        width_mult=1.0,
-        return_features=None,
-        m2u=True,
-    ):
-        super(MobileNetV2, self).__init__()
-        self.return_features = return_features
-        self.m2u = m2u
-        block = InvertedResidual
-        input_channel = 32
-        #last_channel = 1280
-        interverted_residual_setting = [
-            # t, c, n, s
-            [1, 16, 1, 1],
-            [6, 24, 2, 2],
-            [6, 32, 3, 2],
-            [6, 64, 4, 2],
-            [6, 96, 3, 1],
-            # [6, 160, 3, 2],
-            # [6, 320, 1, 1],
-        ]
-
-        # building first layer
-        assert input_size % 32 == 0
-        input_channel = int(input_channel * width_mult)
-        # self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
-        self.features = [conv_bn(3, input_channel, 2)]
-        # building inverted residual blocks
-        for t, c, n, s in interverted_residual_setting:
-            output_channel = int(c * width_mult)
-            for i in range(n):
-                if i == 0:
-                    self.features.append(
-                        block(input_channel, output_channel, s, expand_ratio=t)
-                    )
-                else:
-                    self.features.append(
-                        block(input_channel, output_channel, 1, expand_ratio=t)
-                    )
-                input_channel = output_channel
-        # building last several layers
-        # self.features.append(conv_1x1_bn(input_channel, self.last_channel))
-        # make it torch.nn.Sequential
-        self.features = torch.nn.Sequential(*self.features)
-
-        # building classifier
-        # self.classifier = torch.nn.Sequential(
-        #    torch.nn.Dropout(0.2),
-        #    torch.nn.Linear(self.last_channel, n_class),
-        # )
-
-        self._initialize_weights()
-
-    def forward(self, x):
-        outputs = []
-        # hw of input, needed for DRIU and HED
-        outputs.append(x.shape[2:4])
-        if self.m2u:
-            outputs.append(x)
-        for index, m in enumerate(self.features):
-            x = m(x)
-            # extract layers
-            if index in self.return_features:
-                outputs.append(x)
-        return outputs
-
-    def _initialize_weights(self):
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2.0 / n))
-                if m.bias is not None:
-                    m.bias.data.zero_()
-            elif isinstance(m, torch.nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
-            elif isinstance(m, torch.nn.Linear):
-                n = m.weight.size(1)
-                m.weight.data.normal_(0, 0.01)
-                m.bias.data.zero_()
diff --git a/bob/ip/binseg/modeling/backbones/resnet.py b/bob/ip/binseg/modeling/backbones/resnet.py
deleted file mode 100644
index 445c4ba7..00000000
--- a/bob/ip/binseg/modeling/backbones/resnet.py
+++ /dev/null
@@ -1,244 +0,0 @@
-# Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
-# resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN : https://github.com/rgeirhos/texture-vs-shap
-
-import torch.nn as nn
-import torch.utils.model_zoo as model_zoo
-
-
-model_urls = {
-    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
-    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
-    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
-    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
-    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
-    "resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN": "https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar",
-}
-
-
-def conv3x3(in_planes, out_planes, stride=1):
-    """3x3 convolution with padding"""
-    return nn.Conv2d(
-        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False,
-    )
-
-
-def conv1x1(in_planes, out_planes, stride=1):
-    """1x1 convolution"""
-    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-
-class BasicBlock(nn.Module):
-    expansion = 1
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
-        super(BasicBlock, self).__init__()
-        self.conv1 = conv3x3(inplanes, planes, stride)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.relu = nn.ReLU(inplace=True)
-        self.conv2 = conv3x3(planes, planes)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.downsample = downsample
-        self.stride = stride
-
-    def forward(self, x):
-        identity = x
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.relu(out)
-
-        out = self.conv2(out)
-        out = self.bn2(out)
-
-        if self.downsample is not None:
-            identity = self.downsample(x)
-
-        out += identity
-        out = self.relu(out)
-
-        return out
-
-
-class Bottleneck(nn.Module):
-    expansion = 4
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
-        super(Bottleneck, self).__init__()
-        self.conv1 = conv1x1(inplanes, planes)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.conv2 = conv3x3(planes, planes, stride)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.conv3 = conv1x1(planes, planes * self.expansion)
-        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
-        self.relu = nn.ReLU(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-
-    def forward(self, x):
-        identity = x
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.relu(out)
-
-        out = self.conv2(out)
-        out = self.bn2(out)
-        out = self.relu(out)
-
-        out = self.conv3(out)
-        out = self.bn3(out)
-
-        if self.downsample is not None:
-            identity = self.downsample(x)
-
-        out += identity
-        out = self.relu(out)
-
-        return out
-
-
-class ResNet(nn.Module):
-    def __init__(self, block, layers, return_features, zero_init_residual=False):
-        """
-        Generic ResNet network with layer return.
-        Attributes
-        ----------
-        return_features: list of length 5
-            layers to return.
-        """
-        super(ResNet, self).__init__()
-        self.inplanes = 64
-        self.return_features = return_features
-        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
-        self.bn1 = nn.BatchNorm2d(64)
-        self.relu = nn.ReLU(inplace=True)
-        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, 64, layers[0])
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
-
-        self.features = [
-            self.conv1,
-            self.bn1,
-            self.relu,
-            self.maxpool,
-            self.layer1,
-            self.layer2,
-            self.layer3,
-            self.layer4,
-        ]
-
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-
-        # Zero-initialize the last BN in each residual branch,
-        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
-        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
-        if zero_init_residual:
-            for m in self.modules():
-                if isinstance(m, Bottleneck):
-                    nn.init.constant_(m.bn3.weight, 0)
-                elif isinstance(m, BasicBlock):
-                    nn.init.constant_(m.bn2.weight, 0)
-
-    def _make_layer(self, block, planes, blocks, stride=1):
-        downsample = None
-        if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample = nn.Sequential(
-                conv1x1(self.inplanes, planes * block.expansion, stride),
-                nn.BatchNorm2d(planes * block.expansion),
-            )
-
-        layers = []
-        layers.append(block(self.inplanes, planes, stride, downsample))
-        self.inplanes = planes * block.expansion
-        for _ in range(1, blocks):
-            layers.append(block(self.inplanes, planes))
-
-        return nn.Sequential(*layers)
-
-    def forward(self, x):
-        outputs = []
-        # hw of input, needed for DRIU and HED
-        outputs.append(x.shape[2:4])
-        for index, m in enumerate(self.features):
-            x = m(x)
-            # extract layers
-            if index in self.return_features:
-                outputs.append(x)
-        return outputs
-
-
-def resnet18(pretrained=False, **kwargs):
-    """Constructs a ResNet-18 model.
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls["resnet18"]))
-    return model
-
-
-def resnet34(pretrained=False, **kwargs):
-    """Constructs a ResNet-34 model.
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls["resnet34"]))
-    return model
-
-
-def resnet50(pretrained=False, **kwargs):
-    """Constructs a ResNet-50 model.
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls["resnet50"]))
-    return model
-
-
-def shaperesnet50(pretrained=False, **kwargs):
-    """Constructs a ResNet-50 model, pretrained on Stylized-ImageNe and ImageNet and fine-tuned on ImageNet.
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
-    if pretrained:
-        model.load_state_dict(
-            model_zoo.load_url(
-                model_urls["resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN"]
-            )
-        )
-    return model
-
-
-def resnet101(pretrained=False, **kwargs):
-    """Constructs a ResNet-101 model.
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls["resnet101"]))
-    return model
-
-
-def resnet152(pretrained=False, **kwargs):
-    """Constructs a ResNet-152 model.
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls["resnet152"]))
-    return model
diff --git a/bob/ip/binseg/modeling/backbones/vgg.py b/bob/ip/binseg/modeling/backbones/vgg.py
deleted file mode 100644
index e3909fcb..00000000
--- a/bob/ip/binseg/modeling/backbones/vgg.py
+++ /dev/null
@@ -1,221 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-# Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
-
-import torch.nn as nn
-import torch.utils.model_zoo as model_zoo
-
-
-model_urls = {
-    "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
-    "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
-    "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
-    "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
-    "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
-    "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
-    "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
-    "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
-}
-
-
-class VGG(nn.Module):
-    def __init__(self, features, return_features, init_weights=True):
-        super(VGG, self).__init__()
-        self.features = features
-        self.return_features = return_features
-        if init_weights:
-            self._initialize_weights()
-
-    def forward(self, x):
-        outputs = []
-        # hw of input, needed for DRIU and HED
-        outputs.append(x.shape[2:4])
-        for index, m in enumerate(self.features):
-            x = m(x)
-            # extract layers
-            if index in self.return_features:
-                outputs.append(x)
-        return outputs
-
-    def _initialize_weights(self):
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
-                if m.bias is not None:
-                    nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.Linear):
-                nn.init.normal_(m.weight, 0, 0.01)
-                nn.init.constant_(m.bias, 0)
-
-
-def make_layers(cfg, batch_norm=False):
-    layers = []
-    in_channels = 3
-    for v in cfg:
-        if v == "M":
-            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
-        else:
-            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
-            if batch_norm:
-                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
-            else:
-                layers += [conv2d, nn.ReLU(inplace=True)]
-            in_channels = v
-    return nn.Sequential(*layers)
-
-
-_cfg = {
-    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
-    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
-    "D": [
-        64,
-        64,
-        "M",
-        128,
-        128,
-        "M",
-        256,
-        256,
-        256,
-        "M",
-        512,
-        512,
-        512,
-        "M",
-        512,
-        512,
-        512,
-        "M",
-    ],
-    "E": [
-        64,
-        64,
-        "M",
-        128,
-        128,
-        "M",
-        256,
-        256,
-        256,
-        256,
-        "M",
-        512,
-        512,
-        512,
-        512,
-        "M",
-        512,
-        512,
-        512,
-        512,
-        "M",
-    ],
-}
-
-
-def vgg11(pretrained=False, **kwargs):
-    """VGG 11-layer model (configuration "A")
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    if pretrained:
-        kwargs["init_weights"] = False
-    model = VGG(make_layers(_cfg["A"]), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls["vgg11"]))
-    return model
-
-
-def vgg11_bn(pretrained=False, **kwargs):
-    """VGG 11-layer model (configuration "A") with batch normalization
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    if pretrained:
-        kwargs["init_weights"] = False
-    model = VGG(make_layers(_cfg["A"], batch_norm=True), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls["vgg11_bn"]))
-    return model
-
-
-def vgg13(pretrained=False, **kwargs):
-    """VGG 13-layer model (configuration "B")
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    if pretrained:
-        kwargs["init_weights"] = False
-    model = VGG(make_layers(_cfg["B"]), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls["vgg13"]))
-    return model
-
-
-def vgg13_bn(pretrained=False, **kwargs):
-    """VGG 13-layer model (configuration "B") with batch normalization
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    if pretrained:
-        kwargs["init_weights"] = False
-    model = VGG(make_layers(_cfg["B"], batch_norm=True), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls["vgg13_bn"]))
-    return model
-
-
-def vgg16(pretrained=False, **kwargs):
-    """VGG 16-layer model (configuration "D")
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    if pretrained:
-        kwargs["init_weights"] = False
-    model = VGG(make_layers(_cfg["D"]), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls["vgg16"]), strict=False)
-    return model
-
-
-def vgg16_bn(pretrained=False, **kwargs):
-    """VGG 16-layer model (configuration "D") with batch normalization
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    if pretrained:
-        kwargs["init_weights"] = False
-    model = VGG(make_layers(_cfg["D"], batch_norm=True), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls["vgg16_bn"]))
-    return model
-
-
-def vgg19(pretrained=False, **kwargs):
-    """VGG 19-layer model (configuration "E")
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    if pretrained:
-        kwargs["init_weights"] = False
-    model = VGG(make_layers(_cfg["E"]), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls["vgg19"]))
-    return model
-
-
-def vgg19_bn(pretrained=False, **kwargs):
-    """VGG 19-layer model (configuration 'E') with batch normalization
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    if pretrained:
-        kwargs["init_weights"] = False
-    model = VGG(make_layers(_cfg["E"], batch_norm=True), **kwargs)
-    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls["vgg19_bn"]))
-    return model
diff --git a/bob/ip/binseg/modeling/__init__.py b/bob/ip/binseg/models/__init__.py
similarity index 100%
rename from bob/ip/binseg/modeling/__init__.py
rename to bob/ip/binseg/models/__init__.py
diff --git a/bob/ip/binseg/modeling/backbones/__init__.py b/bob/ip/binseg/models/backbones/__init__.py
similarity index 100%
rename from bob/ip/binseg/modeling/backbones/__init__.py
rename to bob/ip/binseg/models/backbones/__init__.py
diff --git a/bob/ip/binseg/models/backbones/mobilenetv2.py b/bob/ip/binseg/models/backbones/mobilenetv2.py
new file mode 100644
index 00000000..6b8f555f
--- /dev/null
+++ b/bob/ip/binseg/models/backbones/mobilenetv2.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+
+import torchvision.models.mobilenet
+
+
+class MobileNetV24Segmentation(torchvision.models.mobilenet.MobileNetV2):
+    """Adaptation of base MobileNetV2 functionality to U-Net style segmentation
+
+    This version of MobileNetV2 is slightly modified so it can be used through
+    torchvision's API.  It outputs intermediate features which are normally not
+    output by the base MobileNetV2 implementation, but are required for
+    segmentation operations.
+
+
+    Parameters
+    ==========
+
+    return_features : :py:class:`list`, Optional
+        A list of integers indicating the feature layers to be returned from
+        the original module.
+
+    """
+
+    def __init__(self, *args, **kwargs):
+        self._return_features = kwargs.pop("return_features")
+        super(MobileNetV24Segmentation, self).__init__(*args, **kwargs)
+
+    def forward(self, x):
+        outputs = []
+        # input height/width, needed by the segmentation heads
+        outputs.append(x.shape[2:4])
+        outputs.append(x)
+        for index, m in enumerate(self.features):
+            x = m(x)
+            # extract layers
+            if index in self._return_features:
+                outputs.append(x)
+        return outputs
+
+
+def mobilenet_v2_for_segmentation(pretrained=False, progress=True, **kwargs):
+    model = MobileNetV24Segmentation(**kwargs)
+
+    if pretrained:
+        state_dict = torchvision.models.mobilenet.load_state_dict_from_url(
+            torchvision.models.mobilenet.model_urls["mobilenet_v2"],
+            progress=progress,
+        )
+        model.load_state_dict(state_dict)
+
+    # erase MobileNetV2 head (for classification), not used for segmentation
+    delattr(model, 'classifier')
+
+    return_features = kwargs.get("return_features")
+    if return_features is not None:
+        model.features = model.features[:(max(return_features)+1)]
+
+    return model
+
+
+mobilenet_v2_for_segmentation.__doc__ = (
+    torchvision.models.mobilenet.mobilenet_v2.__doc__
+)
diff --git a/bob/ip/binseg/models/backbones/resnet.py b/bob/ip/binseg/models/backbones/resnet.py
new file mode 100644
index 00000000..1f362efc
--- /dev/null
+++ b/bob/ip/binseg/models/backbones/resnet.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import torchvision.models.resnet
+
+
+class ResNet4Segmentation(torchvision.models.resnet.ResNet):
+    """Adaptation of base ResNet functionality to U-Net style segmentation
+
+    This version of ResNet is slightly modified so it can be used through
+    torchvision's API.  It outputs intermediate features which are normally not
+    output by the base ResNet implementation, but are required for segmentation
+    operations.
+
+
+    Parameters
+    ==========
+
+    return_features : :py:class:`list`, Optional
+        A list of integers indicating the feature layers to be returned from
+        the original module.
+
+    """
+
+    def __init__(self, *args, **kwargs):
+        self._return_features = kwargs.pop("return_features")
+        super(ResNet4Segmentation, self).__init__(*args, **kwargs)
+
+    def forward(self, x):
+        outputs = []
+        # input height/width, needed by the segmentation heads
+        outputs.append(x.shape[2:4])
+        for index, m in enumerate(self.children()):
+            x = m(x)
+            # extract layers
+            if index in self._return_features:
+                outputs.append(x)
+        return outputs
+
+
+def _resnet_for_segmentation(
+    arch, block, layers, pretrained, progress, **kwargs
+):
+    model = ResNet4Segmentation(block, layers, **kwargs)
+    if pretrained:
+        state_dict = torchvision.models.resnet.load_state_dict_from_url(
+            torchvision.models.resnet.model_urls[arch], progress=progress
+        )
+        model.load_state_dict(state_dict)
+
+    # erase ResNet head (for classification), not used for segmentation
+    delattr(model, 'avgpool')
+    delattr(model, 'fc')
+
+    return model
+
+
+def resnet50_for_segmentation(pretrained=False, progress=True, **kwargs):
+    return _resnet_for_segmentation(
+        "resnet50",
+        torchvision.models.resnet.Bottleneck,
+        [3, 4, 6, 3],
+        pretrained,
+        progress,
+        **kwargs
+    )
+
+
+resnet50_for_segmentation.__doc__ = torchvision.models.resnet.resnet50.__doc__
diff --git a/bob/ip/binseg/models/backbones/vgg.py b/bob/ip/binseg/models/backbones/vgg.py
new file mode 100644
index 00000000..afe275b2
--- /dev/null
+++ b/bob/ip/binseg/models/backbones/vgg.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import torchvision.models.vgg
+
+
+class VGG4Segmentation(torchvision.models.vgg.VGG):
+    """Adaptation of base VGG functionality to U-Net style segmentation
+
+    This version of VGG is slightly modified so it can be used through
+    torchvision's API.  It outputs intermediate features which are normally not
+    output by the base VGG implementation, but are required for segmentation
+    operations.
+
+
+    Parameters
+    ==========
+
+    return_features : :py:class:`list`, Optional
+        A list of integers indicating the feature layers to be returned from
+        the original module.
+
+    """
+
+    def __init__(self, *args, **kwargs):
+        self._return_features = kwargs.pop("return_features")
+        super(VGG4Segmentation, self).__init__(*args, **kwargs)
+
+    def forward(self, x):
+        outputs = []
+        # input height/width, needed by the segmentation heads
+        outputs.append(x.shape[2:4])
+        for index, m in enumerate(self.features):
+            x = m(x)
+            # extract layers
+            if index in self._return_features:
+                outputs.append(x)
+        return outputs
+
+
+def _vgg_for_segmentation(
+    arch, cfg, batch_norm, pretrained, progress, **kwargs
+):
+
+    if pretrained:
+        kwargs["init_weights"] = False
+
+    model = VGG4Segmentation(
+        torchvision.models.vgg.make_layers(
+            torchvision.models.vgg.cfgs[cfg], batch_norm=batch_norm
+        ),
+        **kwargs
+    )
+
+    if pretrained:
+        state_dict = torchvision.models.vgg.load_state_dict_from_url(
+            torchvision.models.vgg.model_urls[arch], progress=progress
+        )
+        model.load_state_dict(state_dict)
+
+    # erase VGG head (for classification), not used for segmentation
+    delattr(model, 'classifier')
+    delattr(model, 'avgpool')
+
+    return model
+
+
+def vgg16_for_segmentation(pretrained=False, progress=True, **kwargs):
+    return _vgg_for_segmentation(
+        "vgg16", "D", False, pretrained, progress, **kwargs
+    )
+
+
+vgg16_for_segmentation.__doc__ = torchvision.models.vgg.vgg16.__doc__
+
+
+def vgg16_bn_for_segmentation(pretrained=False, progress=True, **kwargs):
+    return _vgg_for_segmentation(
+        "vgg16_bn", "D", True, pretrained, progress, **kwargs
+    )
+
+
+vgg16_bn_for_segmentation.__doc__ = torchvision.models.vgg.vgg16_bn.__doc__
diff --git a/bob/ip/binseg/modeling/driu.py b/bob/ip/binseg/models/driu.py
similarity index 60%
rename from bob/ip/binseg/modeling/driu.py
rename to bob/ip/binseg/models/driu.py
index 94cf77b3..00d9ea36 100644
--- a/bob/ip/binseg/modeling/driu.py
+++ b/bob/ip/binseg/models/driu.py
@@ -1,15 +1,14 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+from collections import OrderedDict
+
 import torch
 import torch.nn
-from collections import OrderedDict
-from .backbones.vgg import vgg16
-from .make_layers import (
-    conv_with_kaiming_uniform,
-    convtrans_with_kaiming_uniform,
-    UpsampleCropBlock,
-)
+
+from .backbones.vgg import vgg16_for_segmentation
+
+from .make_layers import conv_with_kaiming_uniform, UpsampleCropBlock
 
 
 class ConcatFuseBlock(torch.nn.Module):
@@ -39,11 +38,17 @@ class DRIU(torch.nn.Module):
     ----------
     in_channels_list : list
         number of channels for each feature map that is returned from backbone
+
     """
 
     def __init__(self, in_channels_list=None):
         super(DRIU, self).__init__()
-        (in_conv_1_2_16, in_upsample2, in_upsample_4, in_upsample_8,) = in_channels_list
+        (
+            in_conv_1_2_16,
+            in_upsample2,
+            in_upsample_4,
+            in_upsample_8,
+        ) = in_channels_list
 
         self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
         # Upsample layers
@@ -80,21 +85,43 @@ class DRIU(torch.nn.Module):
         return out
 
 
-def build_driu():
-    """
-    Adds backbone and head together
+def driu(pretrained_backbone=True, progress=True):
+    """Builds DRIU for vessel segmentation by adding backbone and head together
+
+
+    Parameters
+    ----------
+
+    pretrained_backbone : :py:class:`bool`, Optional
+        If set to ``True``, then loads a pre-trained version of the backbone
+        (not the head) for the DRIU network using VGG-16 trained for ImageNet
+        classification.
+
+    progress : :py:class:`bool`, Optional
+        If set to ``True``, and you decided to use a ``pretrained_backbone``,
+        then shows a progress bar for the backbone model download, if a
+        download is necessary.
+
 
     Returns
     -------
 
     module : :py:class:`torch.nn.Module`
+        Network model for DRIU (vessel segmentation)
 
     """
-    backbone = vgg16(pretrained=False, return_features=[3, 8, 14, 22])
-    driu_head = DRIU([64, 128, 256, 512])
 
-    model = torch.nn.Sequential(
-        OrderedDict([("backbone", backbone), ("head", driu_head)])
+    backbone = vgg16_for_segmentation(
+        pretrained=pretrained_backbone, progress=progress,
+        return_features=[3, 8, 14, 22],
     )
+    head = DRIU([64, 128, 256, 512])
+
+    order = [("backbone", backbone), ("head", head)]
+    if pretrained_backbone:
+        from .normalizer import TorchVisionNormalizer
+        order = [("normalizer", TorchVisionNormalizer())] + order
+
+    model = torch.nn.Sequential(OrderedDict(order))
     model.name = "driu"
     return model
diff --git a/bob/ip/binseg/modeling/driubn.py b/bob/ip/binseg/models/driu_bn.py
similarity index 56%
rename from bob/ip/binseg/modeling/driubn.py
rename to bob/ip/binseg/models/driu_bn.py
index 24055dec..09d7c570 100644
--- a/bob/ip/binseg/modeling/driubn.py
+++ b/bob/ip/binseg/models/driu_bn.py
@@ -1,15 +1,13 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+from collections import OrderedDict
+
 import torch
 import torch.nn
-from collections import OrderedDict
-from .backbones.vgg import vgg16_bn
-from .make_layers import (
-    conv_with_kaiming_uniform,
-    convtrans_with_kaiming_uniform,
-    UpsampleCropBlock,
-)
+from .backbones.vgg import vgg16_bn_for_segmentation
+
+from .make_layers import conv_with_kaiming_uniform, UpsampleCropBlock
 
 
 class ConcatFuseBlock(torch.nn.Module):
@@ -21,7 +19,8 @@ class ConcatFuseBlock(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.conv = torch.nn.Sequential(
-            conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0), torch.nn.BatchNorm2d(1)
+            conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0),
+            torch.nn.BatchNorm2d(1),
         )
 
     def forward(self, x1, x2, x3, x4):
@@ -30,9 +29,9 @@ class ConcatFuseBlock(torch.nn.Module):
         return x
 
 
-class DRIU(torch.nn.Module):
+class DRIUBN(torch.nn.Module):
     """
-    DRIU head module
+    DRIU head module with batch normalization
 
     Based on paper by [MANINIS-2016]_.
 
@@ -43,8 +42,13 @@ class DRIU(torch.nn.Module):
     """
 
     def __init__(self, in_channels_list=None):
-        super(DRIU, self).__init__()
-        in_conv_1_2_16, in_upsample2, in_upsample_4, in_upsample_8 = in_channels_list
+        super(DRIUBN, self).__init__()
+        (
+            in_conv_1_2_16,
+            in_upsample2,
+            in_upsample_4,
+            in_upsample_8,
+        ) = in_channels_list
 
         self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
         # Upsample layers
@@ -77,21 +81,41 @@ class DRIU(torch.nn.Module):
         return out
 
 
-def build_driu():
-    """
-    Adds backbone and head together
+def driu_bn(pretrained_backbone=True, progress=True):
+    """Builds DRIU with batch-normalization by adding backbone and head together
+
+    Parameters
+    ----------
+
+    pretrained_backbone : :py:class:`bool`, Optional
+        If set to ``True``, then loads a pre-trained version of the backbone
+        (not the head) for the DRIU network using VGG-16 trained for ImageNet
+        classification.
+
+    progress : :py:class:`bool`, Optional
+        If set to ``True``, and you decided to use a ``pretrained_backbone``,
+        then shows a progress bar for the backbone model download, if a
+        download is necessary.
+
 
     Returns
     -------
 
     module : :py:class:`torch.nn.Module`
+        Network model for DRIU (vessel segmentation) using batch normalization
 
     """
-    backbone = vgg16_bn(pretrained=False, return_features=[5, 12, 19, 29])
-    driu_head = DRIU([64, 128, 256, 512])
 
-    model = torch.nn.Sequential(
-        OrderedDict([("backbone", backbone), ("head", driu_head)])
+    backbone = vgg16_bn_for_segmentation(
+        pretrained=pretrained_backbone, progress=progress, return_features=[5, 12, 19, 29]
     )
+    head = DRIUBN([64, 128, 256, 512])
+
+    order = [("backbone", backbone), ("head", head)]
+    if pretrained_backbone:
+        from .normalizer import TorchVisionNormalizer
+        order = [("normalizer", TorchVisionNormalizer())] + order
+
+    model = torch.nn.Sequential(OrderedDict(order))
     model.name = "driu-bn"
     return model
diff --git a/bob/ip/binseg/modeling/driuod.py b/bob/ip/binseg/models/driu_od.py
similarity index 57%
rename from bob/ip/binseg/modeling/driuod.py
rename to bob/ip/binseg/models/driu_od.py
index 80336c6f..e45755af 100644
--- a/bob/ip/binseg/modeling/driuod.py
+++ b/bob/ip/binseg/models/driu_od.py
@@ -1,36 +1,20 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import torch
-import torch.nn
 from collections import OrderedDict
-from .backbones.vgg import vgg16
-from .make_layers import (
-    conv_with_kaiming_uniform,
-    convtrans_with_kaiming_uniform,
-    UpsampleCropBlock,
-)
-
 
-class ConcatFuseBlock(torch.nn.Module):
-    """
-    Takes in four feature maps with 16 channels each, concatenates them
-    and applies a 1x1 convolution with 1 output channel.
-    """
+import torch
+import torch.nn
 
-    def __init__(self):
-        super().__init__()
-        self.conv = conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0)
+from .backbones.vgg import vgg16_for_segmentation
 
-    def forward(self, x1, x2, x3, x4):
-        x_cat = torch.cat([x1, x2, x3, x4], dim=1)
-        x = self.conv(x_cat)
-        return x
+from .make_layers import UpsampleCropBlock
+from .driu import ConcatFuseBlock
 
 
 class DRIUOD(torch.nn.Module):
     """
-    DRIU head module
+    DRIU head module for optic disc segmentation
 
     Parameters
     ----------
@@ -73,20 +57,42 @@ class DRIUOD(torch.nn.Module):
         return out
 
 
-def build_driuod():
-    """
-    Adds backbone and head together
+def driu_od(pretrained_backbone=True, progress=True):
+    """Builds DRIU for optic disc segmentation by adding backbone and head together
+
+    Parameters
+    ----------
+
+    pretrained_backbone : :py:class:`bool`, Optional
+        If set to ``True``, then loads a pre-trained version of the backbone
+        (not the head) for the DRIU network using VGG-16 trained for ImageNet
+        classification.
+
+    progress : :py:class:`bool`, Optional
+        If set to ``True``, and you decided to use a ``pretrained_backbone``,
+        then shows a progress bar for the backbone model download, if a
+        download is necessary.
+
 
     Returns
     -------
+
     module : :py:class:`torch.nn.Module`
+        Network model for DRIU (optic disc segmentation)
 
     """
-    backbone = vgg16(pretrained=False, return_features=[8, 14, 22, 29])
-    driu_head = DRIUOD([128, 256, 512, 512])
 
-    model = torch.nn.Sequential(
-        OrderedDict([("backbone", backbone), ("head", driu_head)])
+    backbone = vgg16_for_segmentation(
+        pretrained=pretrained_backbone, progress=progress,
+        return_features=[8, 14, 22, 29],
     )
+    head = DRIUOD([128, 256, 512, 512])
+
+    order = [("backbone", backbone), ("head", head)]
+    if pretrained_backbone:
+        from .normalizer import TorchVisionNormalizer
+        order = [("normalizer", TorchVisionNormalizer())] + order
+
+    model = torch.nn.Sequential(OrderedDict(order))
     model.name = "driu-od"
     return model
diff --git a/bob/ip/binseg/modeling/driupix.py b/bob/ip/binseg/models/driu_pix.py
similarity index 60%
rename from bob/ip/binseg/modeling/driupix.py
rename to bob/ip/binseg/models/driu_pix.py
index 5d4a6ce7..76b8c2b4 100644
--- a/bob/ip/binseg/modeling/driupix.py
+++ b/bob/ip/binseg/models/driu_pix.py
@@ -1,31 +1,15 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import torch
-import torch.nn
 from collections import OrderedDict
-from .backbones.vgg import vgg16
-from .make_layers import (
-    conv_with_kaiming_uniform,
-    convtrans_with_kaiming_uniform,
-    UpsampleCropBlock,
-)
-
 
-class ConcatFuseBlock(torch.nn.Module):
-    """
-    Takes in four feature maps with 16 channels each, concatenates them
-    and applies a 1x1 convolution with 1 output channel.
-    """
+import torch
+import torch.nn
 
-    def __init__(self):
-        super().__init__()
-        self.conv = conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0)
+from .backbones.vgg import vgg16_for_segmentation
 
-    def forward(self, x1, x2, x3, x4):
-        x_cat = torch.cat([x1, x2, x3, x4], dim=1)
-        x = self.conv(x_cat)
-        return x
+from .make_layers import UpsampleCropBlock
+from .driu import ConcatFuseBlock
 
 
 class DRIUPIX(torch.nn.Module):
@@ -77,20 +61,42 @@ class DRIUPIX(torch.nn.Module):
         return out
 
 
-def build_driupix():
-    """
-    Adds backbone and head together
+def driu_pix(pretrained_backbone=True, progress=True):
+    """Builds DRIU with pixelshuffle by adding backbone and head together
+
+    Parameters
+    ----------
+
+    pretrained_backbone : :py:class:`bool`, Optional
+        If set to ``True``, then loads a pre-trained version of the backbone
+        (not the head) for the DRIU network using VGG-16 trained for ImageNet
+        classification.
+
+    progress : :py:class:`bool`, Optional
+        If set to ``True``, and you decided to use a ``pretrained_backbone``,
+        then shows a progress bar for the backbone model download, if a
+        download is necessary.
+
 
     Returns
     -------
+
     module : :py:class:`torch.nn.Module`
+        Network model for DRIU (vessel segmentation) with pixelshuffle
 
     """
-    backbone = vgg16(pretrained=False, return_features=[3, 8, 14, 22])
-    driu_head = DRIUPIX([64, 128, 256, 512])
 
-    model = torch.nn.Sequential(
-        OrderedDict([("backbone", backbone), ("head", driu_head)])
+    backbone = vgg16_for_segmentation(
+        pretrained=pretrained_backbone, progress=progress,
+        return_features=[3, 8, 14, 22],
     )
+    head = DRIUPIX([64, 128, 256, 512])
+
+    order = [("backbone", backbone), ("head", head)]
+    if pretrained_backbone:
+        from .normalizer import TorchVisionNormalizer
+        order = [("normalizer", TorchVisionNormalizer())] + order
+
+    model = torch.nn.Sequential(OrderedDict(order))
     model.name = "driu-pix"
     return model
diff --git a/bob/ip/binseg/modeling/hed.py b/bob/ip/binseg/models/hed.py
similarity index 66%
rename from bob/ip/binseg/modeling/hed.py
rename to bob/ip/binseg/models/hed.py
index 5c059b3d..f1bc13ed 100644
--- a/bob/ip/binseg/modeling/hed.py
+++ b/bob/ip/binseg/models/hed.py
@@ -1,15 +1,14 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+from collections import OrderedDict
+
 import torch
 import torch.nn
-from collections import OrderedDict
-from .backbones.vgg import vgg16
-from .make_layers import (
-    conv_with_kaiming_uniform,
-    convtrans_with_kaiming_uniform,
-    UpsampleCropBlock,
-)
+
+from .backbones.vgg import vgg16_for_segmentation
+
+from .make_layers import conv_with_kaiming_uniform, UpsampleCropBlock
 
 
 class ConcatFuseBlock(torch.nn.Module):
@@ -84,19 +83,44 @@ class HED(torch.nn.Module):
         return out
 
 
-def build_hed():
-    """
-    Adds backbone and head together
+def hed(pretrained_backbone=True, progress=True):
+    """Builds HED by adding backbone and head together
+
+    Parameters
+    ----------
+
+    pretrained_backbone : :py:class:`bool`, Optional
+        If set to ``True``, then loads a pre-trained version of the backbone
+        (not the head) for the HED network using VGG-16 trained for ImageNet
+        classification.
+
+    progress : :py:class:`bool`, Optional
+        If set to ``True``, and a ``pretrained_backbone`` is in use, shows a
+        progress bar for the backbone model download, if a download is
+        necessary.
+
 
     Returns
     -------
+
     module : :py:class:`torch.nn.Module`
+        Network model for HED
+
     """
-    backbone = vgg16(pretrained=False, return_features=[3, 8, 14, 22, 29])
-    hed_head = HED([64, 128, 256, 512, 512])
 
-    model = torch.nn.Sequential(
-        OrderedDict([("backbone", backbone), ("head", hed_head)])
+    backbone = vgg16_for_segmentation(
+        pretrained=pretrained_backbone,
+        progress=progress,
+        return_features=[3, 8, 14, 22, 29],
     )
+    head = HED([64, 128, 256, 512, 512])
+
+    order = [("backbone", backbone), ("head", head)]
+    if pretrained_backbone:
+        from .normalizer import TorchVisionNormalizer
+
+        order = [("normalizer", TorchVisionNormalizer())] + order
+
+    model = torch.nn.Sequential(OrderedDict(order))
     model.name = "hed"
     return model
diff --git a/bob/ip/binseg/modeling/losses.py b/bob/ip/binseg/models/losses.py
similarity index 100%
rename from bob/ip/binseg/modeling/losses.py
rename to bob/ip/binseg/models/losses.py
diff --git a/bob/ip/binseg/modeling/m2u.py b/bob/ip/binseg/models/m2unet.py
similarity index 66%
rename from bob/ip/binseg/modeling/m2u.py
rename to bob/ip/binseg/models/m2unet.py
index c4ac69c3..a5fa80f1 100644
--- a/bob/ip/binseg/modeling/m2u.py
+++ b/bob/ip/binseg/models/m2unet.py
@@ -1,12 +1,13 @@
 #!/usr/bin/env python
 # vim: set fileencoding=utf-8 :
 
-# https://github.com/laibe/M2U-Net
-
 from collections import OrderedDict
+
 import torch
 import torch.nn
-from .backbones.mobilenetv2 import MobileNetV2, InvertedResidual
+from torchvision.models.mobilenet import InvertedResidual
+
+from .backbones.mobilenetv2 import mobilenet_v2_for_segmentation
 
 
 class DecoderBlock(torch.nn.Module):
@@ -14,7 +15,9 @@ class DecoderBlock(torch.nn.Module):
     Decoder block: upsample and concatenate with features maps from the encoder part
     """
 
-    def __init__(self, up_in_c, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
+    def __init__(
+        self, up_in_c, x_in_c, upsamplemode="bilinear", expand_ratio=0.15
+    ):
         super().__init__()
         self.upsample = torch.nn.Upsample(
             scale_factor=2, mode=upsamplemode, align_corners=False
@@ -39,7 +42,9 @@ class LastDecoderBlock(torch.nn.Module):
         self.upsample = torch.nn.Upsample(
             scale_factor=2, mode=upsamplemode, align_corners=False
         )  # H, W -> 2H, 2W
-        self.ir1 = InvertedResidual(x_in_c, 1, stride=1, expand_ratio=expand_ratio)
+        self.ir1 = InvertedResidual(
+            x_in_c, 1, stride=1, expand_ratio=expand_ratio
+        )
 
     def forward(self, up_in, x_in):
         up_out = self.upsample(up_in)
@@ -48,7 +53,7 @@ class LastDecoderBlock(torch.nn.Module):
         return x
 
 
-class M2U(torch.nn.Module):
+class M2UNet(torch.nn.Module):
     """
     M2U-Net head module
 
@@ -61,7 +66,7 @@ class M2U(torch.nn.Module):
     def __init__(
         self, in_channels_list=None, upsamplemode="bilinear", expand_ratio=0.15
     ):
-        super(M2U, self).__init__()
+        super(M2UNet, self).__init__()
 
         # Decoder
         self.decode4 = DecoderBlock(96, 32, upsamplemode, expand_ratio)
@@ -102,19 +107,44 @@ class M2U(torch.nn.Module):
         return decode1
 
 
-def build_m2unet():
-    """
-    Adds backbone and head together
+def m2unet(pretrained_backbone=True, progress=True):
+    """Builds M2U-Net for segmentation by adding backbone and head together
+
+    Parameters
+    ----------
+
+    pretrained_backbone : :py:class:`bool`, Optional
+        If set to ``True``, then loads a pre-trained version of the backbone
+        (not the head) for the M2U-Net network using MobileNetV2 trained for
+        ImageNet classification.
+
+    progress : :py:class:`bool`, Optional
+        If set to ``True``, and a ``pretrained_backbone`` is in use, shows a
+        progress bar for the backbone model download, if a download is
+        necessary.
+
 
     Returns
     -------
+
     module : :py:class:`torch.nn.Module`
+        Network model for M2U-Net (segmentation)
+
     """
-    backbone = MobileNetV2(return_features=[1, 3, 6, 13], m2u=True)
-    m2u_head = M2U(in_channels_list=[16, 24, 32, 96])
 
-    model = torch.nn.Sequential(
-        OrderedDict([("backbone", backbone), ("head", m2u_head)])
+    backbone = mobilenet_v2_for_segmentation(
+        pretrained=pretrained_backbone,
+        progress=progress,
+        return_features=[1, 3, 6, 13],
     )
+    head = M2UNet(in_channels_list=[16, 24, 32, 96])
+
+    order = [("backbone", backbone), ("head", head)]
+    if pretrained_backbone:
+        from .normalizer import TorchVisionNormalizer
+        order = [("normalizer", TorchVisionNormalizer())] + order
+
+    model = torch.nn.Sequential(OrderedDict(order))
     model.name = "m2unet"
     return model
diff --git a/bob/ip/binseg/modeling/make_layers.py b/bob/ip/binseg/models/make_layers.py
similarity index 100%
rename from bob/ip/binseg/modeling/make_layers.py
rename to bob/ip/binseg/models/make_layers.py
diff --git a/bob/ip/binseg/models/normalizer.py b/bob/ip/binseg/models/normalizer.py
new file mode 100755
index 00000000..c62d3e48
--- /dev/null
+++ b/bob/ip/binseg/models/normalizer.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+"""A network model that prefixes a z-normalization step to any other module"""
+
+
+import torch
+import torch.nn
+
+
+class TorchVisionNormalizer(torch.nn.Module):
+    """A simple normalizer that applies the standard torchvision normalization
+
+    This module does not learn.
+
+    The values applied in this "prefix" operator are defined at
+    https://pytorch.org/docs/stable/torchvision/models.html, and are as
+    follows:
+
+    * ``mean``: ``[0.485, 0.456, 0.406]``,
+    * ``std``: ``[0.229, 0.224, 0.225]``
+    """
+
+    def __init__(self):
+        super(TorchVisionNormalizer, self).__init__()
+        mean = torch.as_tensor([0.485, 0.456, 0.406])[None, :, None, None]
+        std = torch.as_tensor([0.229, 0.224, 0.225])[None, :, None, None]
+        self.register_buffer('mean', mean)
+        self.register_buffer('std', std)
+        self.name = "torchvision-normalizer"
+
+    def forward(self, inputs):
+        return inputs.sub(self.mean).div(self.std)
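Because ``mean`` and ``std`` are registered as buffers rather than parameters, they move with the module across devices and are serialized with the ``state_dict``, yet receive no gradient updates.  A quick sanity check, assuming NCHW float inputs in ``[0, 1]``:

    import torch

    from bob.ip.binseg.models.normalizer import TorchVisionNormalizer

    norm = TorchVisionNormalizer()
    batch = torch.rand(2, 3, 64, 64)  # NCHW, values in [0, 1]
    out = norm(batch)
    assert out.shape == batch.shape
    # buffers are serialized and device-aware, but not trainable
    assert len(list(norm.parameters())) == 0
    assert {"mean", "std"} <= set(dict(norm.named_buffers()).keys())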
diff --git a/bob/ip/binseg/modeling/resunet.py b/bob/ip/binseg/models/resunet.py
similarity index 50%
rename from bob/ip/binseg/modeling/resunet.py
rename to bob/ip/binseg/models/resunet.py
index f53ff73d..8b21bd5e 100644
--- a/bob/ip/binseg/modeling/resunet.py
+++ b/bob/ip/binseg/models/resunet.py
@@ -1,25 +1,32 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import torch.nn as nn
 from collections import OrderedDict
+
+import torch.nn
+
 from .make_layers import (
     conv_with_kaiming_uniform,
     convtrans_with_kaiming_uniform,
     PixelShuffle_ICNR,
     UnetBlock,
 )
-from .backbones.resnet import resnet50
 
+from .backbones.resnet import resnet50_for_segmentation
 
-class ResUNet(nn.Module):
-    """
-    UNet head module for ResNet backbones
+
+class ResUNet(torch.nn.Module):
+    """UNet head module for ResNet backbones
 
     Parameters
     ----------
-    in_channels_list : list
-                        number of channels for each feature map that is returned from backbone
+
+    in_channels_list : :py:class:`list`, Optional
+        number of channels for each feature map that is returned from backbone
+
+    pixel_shuffle : :py:class:`bool`, Optional
+        if should use pixel shuffling instead of pooling
+
     """
 
     def __init__(self, in_channels_list=None, pixel_shuffle=False):
@@ -37,7 +44,9 @@ class ResUNet(nn.Module):
         if pixel_shuffle:
             self.decode0 = PixelShuffle_ICNR(c_decode0, c_decode0)
         else:
-            self.decode0 = convtrans_with_kaiming_uniform(c_decode0, c_decode0, 2, 2)
+            self.decode0 = convtrans_with_kaiming_uniform(
+                c_decode0, c_decode0, 2, 2
+            )
         self.final = conv_with_kaiming_uniform(c_decode0, 1, 1)
 
     def forward(self, x):
@@ -49,7 +58,8 @@ class ResUNet(nn.Module):
                 First element: height and width of input image.
                 Remaining elements: feature maps for each feature level.
         """
-        # NOTE: x[0]: height and width of input image not needed in U-Net architecture
+        # NOTE: x[0]: height and width of input image not needed in U-Net
+        # architecture
         decode4 = self.decode4(x[5], x[4])
         decode3 = self.decode3(decode4, x[3])
         decode2 = self.decode2(decode3, x[2])
@@ -59,16 +69,44 @@ class ResUNet(nn.Module):
         return out
 
 
-def build_res50unet():
-    """
-    Adds backbone and head together
+def resunet50(pretrained_backbone=True, progress=True):
+    """Builds Residual-U-Net-50 by adding backbone and head together
+
+    Parameters
+    ----------
+
+    pretrained_backbone : :py:class:`bool`, Optional
+        If set to ``True``, then loads a pre-trained version of the backbone
+        (not the head) for the Residual U-Net using ResNet-50 trained for
+        ImageNet classification.
+
+    progress : :py:class:`bool`, Optional
+        If set to ``True``, and a ``pretrained_backbone`` is in use, shows a
+        progress bar for the backbone model download, if a download is
+        necessary.
+
 
     Returns
     -------
-    model : :py:class:`torch.nn.Module`
+
+    module : :py:class:`torch.nn.Module`
+        Network model for Residual U-Net 50
+
     """
-    backbone = resnet50(pretrained=False, return_features=[2, 4, 5, 6, 7])
-    unet_head = ResUNet([64, 256, 512, 1024, 2048], pixel_shuffle=False)
-    model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", unet_head)]))
-    model.name = "resunet"
+
+    backbone = resnet50_for_segmentation(
+        pretrained=pretrained_backbone,
+        progress=progress,
+        return_features=[2, 4, 5, 6, 7],
+    )
+    head = ResUNet([64, 256, 512, 1024, 2048], pixel_shuffle=False)
+
+    order = [("backbone", backbone), ("head", head)]
+    if pretrained_backbone:
+        from .normalizer import TorchVisionNormalizer
+
+        order = [("normalizer", TorchVisionNormalizer())] + order
+
+    model = torch.nn.Sequential(OrderedDict(order))
+    model.name = "resunet50"
     return model
diff --git a/bob/ip/binseg/modeling/unet.py b/bob/ip/binseg/models/unet.py
similarity index 55%
rename from bob/ip/binseg/modeling/unet.py
rename to bob/ip/binseg/models/unet.py
index 37b6de6a..f0388be3 100644
--- a/bob/ip/binseg/modeling/unet.py
+++ b/bob/ip/binseg/models/unet.py
@@ -1,18 +1,15 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import torch.nn as nn
 from collections import OrderedDict
-from .make_layers import (
-    conv_with_kaiming_uniform,
-    convtrans_with_kaiming_uniform,
-    PixelShuffle_ICNR,
-    UnetBlock,
-)
-from .backbones.vgg import vgg16
 
+import torch.nn
 
-class UNet(nn.Module):
+from .backbones.vgg import vgg16_for_segmentation
+from .make_layers import conv_with_kaiming_uniform, UnetBlock
+
+
+class UNet(torch.nn.Module):
     """
     UNet head module
 
@@ -52,18 +49,42 @@ class UNet(nn.Module):
         return out
 
 
-def build_unet():
-    """
-    Adds backbone and head together
+def unet(pretrained_backbone=True, progress=True):
+    """Builds U-Net segmentation network by adding backbone and head together
+
+    Parameters
+    ----------
+
+    pretrained_backbone : :py:class:`bool`, Optional
+        If set to ``True``, then loads a pre-trained version of the backbone
+        (not the head) for the U-Net using VGG-16 trained for ImageNet
+        classification.
+
+    progress : :py:class:`bool`, Optional
+        If set to ``True``, and a ``pretrained_backbone`` is in use, shows a
+        progress bar for the backbone model download, if a download is
+        necessary.
+
 
     Returns
     -------
+
     module : :py:class:`torch.nn.Module`
+        Network model for U-Net
+
     """
 
-    backbone = vgg16(pretrained=False, return_features=[3, 8, 14, 22, 29])
-    unet_head = UNet([64, 128, 256, 512, 512], pixel_shuffle=False)
+    backbone = vgg16_for_segmentation(
+        pretrained=pretrained_backbone, progress=progress,
+        return_features=[3, 8, 14, 22, 29],
+    )
+    head = UNet([64, 128, 256, 512, 512], pixel_shuffle=False)
+
+    order = [("backbone", backbone), ("head", head)]
+    if pretrained_backbone:
+        from .normalizer import TorchVisionNormalizer
+        order = [("normalizer", TorchVisionNormalizer())] + order
 
-    model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", unet_head)]))
-    model.name = "UNet"
+    model = torch.nn.Sequential(OrderedDict(order))
+    model.name = "unet"
     return model
diff --git a/bob/ip/binseg/script/experiment.py b/bob/ip/binseg/script/experiment.py
index f43fe5fc..0dc7153f 100644
--- a/bob/ip/binseg/script/experiment.py
+++ b/bob/ip/binseg/script/experiment.py
@@ -93,15 +93,6 @@ logger = logging.getLogger(__name__)
     required=True,
     cls=ResourceOption,
 )
-@click.option(
-    "--pretrained-backbone",
-    "-t",
-    help="URL of a pre-trained model file that will be used to preset "
-    "FCN weights (where relevant) before training starts "
-    "(e.g. vgg16, mobilenetv2)",
-    required=True,
-    cls=ResourceOption,
-)
 @click.option(
     "--batch-size",
     "-b",
@@ -224,7 +215,6 @@ def experiment(
     scheduler,
     output_folder,
     epochs,
-    pretrained_backbone,
     batch_size,
     drop_incomplete_batch,
     criterion,
@@ -310,7 +300,6 @@ def experiment(
         scheduler=scheduler,
         output_folder=train_output_folder,
         epochs=epochs,
-        pretrained_backbone=pretrained_backbone,
         batch_size=batch_size,
         drop_incomplete_batch=drop_incomplete_batch,
         criterion=criterion,
diff --git a/bob/ip/binseg/script/predict.py b/bob/ip/binseg/script/predict.py
index 14c9cd74..e096b60d 100644
--- a/bob/ip/binseg/script/predict.py
+++ b/bob/ip/binseg/script/predict.py
@@ -15,7 +15,7 @@ from bob.extension.scripts.click_helper import (
 )
 
 from ..engine.predictor import run
-from ..utils.checkpointer import DetectronCheckpointer
+from ..utils.checkpointer import Checkpointer
 
 from .binseg import download_to_tempfile
 
@@ -124,11 +124,8 @@ def predict(output_folder, model, dataset, batch_size, device, weight,
     else:
         weight_fullpath = os.path.abspath(weight)
 
-    weight_path = os.path.dirname(weight_fullpath)
-    weight_name = os.path.basename(weight_fullpath)
-    checkpointer = DetectronCheckpointer(model, save_dir=weight_path,
-            save_to_disk=False)
-    checkpointer.load(weight_name)
+    checkpointer = Checkpointer(model)
+    checkpointer.load(weight_fullpath)
 
     # clean-up the overlayed path
     if overlayed is not None:
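With the simplified ``Checkpointer``, prediction no longer splits the weight path into directory and basename: ``load()`` accepts the checkpoint path directly, and no ``save_dir``/``save_to_disk`` flags are needed when nothing is written back.  A sketch of the resulting call sequence (``model`` and the weight path are placeholders):

    from bob.ip.binseg.utils.checkpointer import Checkpointer

    checkpointer = Checkpointer(model)  # no optimizer/scheduler at inference
    checkpointer.load("/abs/path/to/model_final.pth")  # hypothetical path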
diff --git a/bob/ip/binseg/script/train.py b/bob/ip/binseg/script/train.py
index da56a3bd..5edcd7e7 100644
--- a/bob/ip/binseg/script/train.py
+++ b/bob/ip/binseg/script/train.py
@@ -13,7 +13,7 @@ from bob.extension.scripts.click_helper import (
     ResourceOption,
 )
 
-from ..utils.checkpointer import DetectronCheckpointer
+from ..utils.checkpointer import Checkpointer
 
 import logging
 logger = logging.getLogger(__name__)
@@ -97,15 +97,6 @@ logger = logging.getLogger(__name__)
     required=True,
     cls=ResourceOption,
 )
-@click.option(
-    "--pretrained-backbone",
-    "-t",
-    help="URL of a pre-trained model file that will be used to preset "
-    "FCN weights (where relevant) before training starts "
-    "(e.g. vgg16, mobilenetv2)",
-    required=True,
-    cls=ResourceOption,
-)
 @click.option(
     "--batch-size",
     "-b",
@@ -204,7 +195,6 @@ def train(
     scheduler,
     output_folder,
     epochs,
-    pretrained_backbone,
     batch_size,
     drop_incomplete_batch,
     criterion,
@@ -261,14 +251,11 @@ def train(
                 pin_memory=torch.cuda.is_available(),
                 )
 
-    # Checkpointer
-    checkpointer = DetectronCheckpointer(
-        model, optimizer, scheduler, save_dir=output_folder, save_to_disk=True
-    )
+    checkpointer = Checkpointer(model, optimizer, scheduler, path=output_folder)
 
     arguments = {}
     arguments["epoch"] = 0
-    extra_checkpoint_data = checkpointer.load(pretrained_backbone)
+    extra_checkpoint_data = checkpointer.load()
     arguments.update(extra_checkpoint_data)
     arguments["max_epoch"] = epochs
 
diff --git a/bob/ip/binseg/test/test_checkpointer.py b/bob/ip/binseg/test/test_checkpointer.py
index 7f0ca50d..16df40af 100644
--- a/bob/ip/binseg/test/test_checkpointer.py
+++ b/bob/ip/binseg/test/test_checkpointer.py
@@ -1,13 +1,14 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+import os
+import unittest
 from collections import OrderedDict
 from tempfile import TemporaryDirectory
-import unittest
 
 import torch
+import nose.tools
 from torch import nn
-import os
 
 from ..utils.checkpointer import Checkpointer
 
@@ -39,45 +40,43 @@ class TestCheckpointer(unittest.TestCase):
         trained_model = self.create_model()
         fresh_model = self.create_model()
         with TemporaryDirectory() as f:
-            checkpointer = Checkpointer(trained_model, save_dir=f, save_to_disk=True)
+            checkpointer = Checkpointer(trained_model, path=f)
             checkpointer.save("checkpoint_file")
 
             # in the same folder
-            fresh_checkpointer = Checkpointer(fresh_model, save_dir=f)
-            self.assertTrue(fresh_checkpointer.has_checkpoint())
-            self.assertEqual(
-                fresh_checkpointer.get_checkpoint_file(),
-                "checkpoint_file.pth",
-            )
+            fresh_checkpointer = Checkpointer(fresh_model, path=f)
+            assert fresh_checkpointer.has_checkpoint()
+            nose.tools.eq_(fresh_checkpointer.last_checkpoint(),
+                    os.path.realpath(os.path.join(f, "checkpoint_file.pth")))
             _ = fresh_checkpointer.load()
 
         for trained_p, loaded_p in zip(
             trained_model.parameters(), fresh_model.parameters()
         ):
             # different tensor references
-            self.assertFalse(id(trained_p) == id(loaded_p))
+            nose.tools.assert_not_equal(id(trained_p), id(loaded_p))
             # same content
-            self.assertTrue(trained_p.equal(loaded_p))
+            assert trained_p.equal(loaded_p)
 
     def test_from_name_file_model(self):
         # test that loading works even if they differ by a prefix
         trained_model = self.create_model()
         fresh_model = self.create_model()
         with TemporaryDirectory() as f:
-            checkpointer = Checkpointer(trained_model, save_dir=f, save_to_disk=True)
+            checkpointer = Checkpointer(trained_model, path=f)
             checkpointer.save("checkpoint_file")
 
             # on different folders
             with TemporaryDirectory() as g:
-                fresh_checkpointer = Checkpointer(fresh_model, save_dir=g)
-                self.assertFalse(fresh_checkpointer.has_checkpoint())
-                self.assertEqual(fresh_checkpointer.get_checkpoint_file(), "")
+                fresh_checkpointer = Checkpointer(fresh_model, path=g)
+                assert not fresh_checkpointer.has_checkpoint()
+                nose.tools.eq_(fresh_checkpointer.last_checkpoint(), None)
                 _ = fresh_checkpointer.load(os.path.join(f, "checkpoint_file.pth"))
 
         for trained_p, loaded_p in zip(
             trained_model.parameters(), fresh_model.parameters()
         ):
             # different tensor references
-            self.assertFalse(id(trained_p) == id(loaded_p))
+            nose.tools.assert_not_equal(id(trained_p), id(loaded_p))
             # same content
-            self.assertTrue(trained_p.equal(loaded_p))
+            assert trained_p.equal(loaded_p)
diff --git a/bob/ip/binseg/test/test_cli.py b/bob/ip/binseg/test/test_cli.py
index 9a7a9847..2b188270 100644
--- a/bob/ip/binseg/test/test_cli.py
+++ b/bob/ip/binseg/test/test_cli.py
@@ -184,7 +184,7 @@ def _check_experiment_stare(overlay):
             r"^Saving checkpoint": 2,
             r"^Ended training$": 1,
             r"^Started prediction$": 1,
-            r"^Loading checkpoint from": 2,
+            r"^Loading checkpoint from": 1,
             r"^Ended prediction$": 1,
             r"^Started evaluation$": 1,
             r"^Maximum F1-score of.*\(chosen \*a posteriori\*\)$": 3,
@@ -268,8 +268,8 @@ def _check_train(runner):
             r"^Continuing from epoch 0$": 1,
             r"^Saving model summary at.*$": 1,
             r"^Model has.*$": 1,
-            rf"^Saving checkpoint to {output_folder}/model_lowest_valid_loss.pth$": 1,
-            rf"^Saving checkpoint to {output_folder}/model_final.pth$": 1,
+            r"^Saving checkpoint to .*/model_lowest_valid_loss.pth$": 1,
+            r"^Saving checkpoint to .*/model_final.pth$": 1,
             r"^Total training time:": 1,
         }
         buf.seek(0)
diff --git a/bob/ip/binseg/test/test_models.py b/bob/ip/binseg/test/test_models.py
new file mode 100755
index 00000000..f078a99e
--- /dev/null
+++ b/bob/ip/binseg/test/test_models.py
@@ -0,0 +1,140 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+"""Tests model loading"""
+
+
+import nose.tools
+from ..models.normalizer import TorchVisionNormalizer
+from ..models.backbones.vgg import VGG4Segmentation
+
+
+def test_driu():
+
+    from ..models.driu import driu, DRIU
+
+    model = driu(pretrained_backbone=True, progress=True)
+    nose.tools.eq_(len(model), 3)
+    nose.tools.eq_(type(model[0]), TorchVisionNormalizer)
+    nose.tools.eq_(type(model[1]), VGG4Segmentation)  #backbone
+    nose.tools.eq_(type(model[2]), DRIU)  #head
+
+    model = driu(pretrained_backbone=False)
+    nose.tools.eq_(len(model), 2)
+    nose.tools.eq_(type(model[0]), VGG4Segmentation)  #backbone
+    nose.tools.eq_(type(model[1]), DRIU)  #head
+
+
+def test_driu_bn():
+
+    from ..models.driu_bn import driu_bn, DRIUBN
+
+    model = driu_bn(pretrained_backbone=True, progress=True)
+    nose.tools.eq_(len(model), 3)
+    nose.tools.eq_(type(model[0]), TorchVisionNormalizer)
+    nose.tools.eq_(type(model[1]), VGG4Segmentation)  #backbone
+    nose.tools.eq_(type(model[2]), DRIUBN)  #head
+
+    model = driu_bn(pretrained_backbone=False)
+    nose.tools.eq_(len(model), 2)
+    nose.tools.eq_(type(model[0]), VGG4Segmentation)  #backbone
+    nose.tools.eq_(type(model[1]), DRIUBN)  #head
+
+
+def test_driu_od():
+
+    from ..models.driu_od import driu_od, DRIUOD
+
+    model = driu_od(pretrained_backbone=True, progress=True)
+    nose.tools.eq_(len(model), 3)
+    nose.tools.eq_(type(model[0]), TorchVisionNormalizer)
+    nose.tools.eq_(type(model[1]), VGG4Segmentation)  #backbone
+    nose.tools.eq_(type(model[2]), DRIUOD)  #head
+
+    model = driu_od(pretrained_backbone=False)
+    nose.tools.eq_(len(model), 2)
+    nose.tools.eq_(type(model[0]), VGG4Segmentation)  #backbone
+    nose.tools.eq_(type(model[1]), DRIUOD)  #head
+
+
+def test_driu_pix():
+
+    from ..models.driu_pix import driu_pix, DRIUPIX
+
+    model = driu_pix(pretrained_backbone=True, progress=True)
+    nose.tools.eq_(len(model), 3)
+    nose.tools.eq_(type(model[0]), TorchVisionNormalizer)
+    nose.tools.eq_(type(model[1]), VGG4Segmentation)  #backbone
+    nose.tools.eq_(type(model[2]), DRIUPIX)  #head
+
+    model = driu_pix(pretrained_backbone=False)
+    nose.tools.eq_(len(model), 2)
+    nose.tools.eq_(type(model[0]), VGG4Segmentation)  #backbone
+    nose.tools.eq_(type(model[1]), DRIUPIX)  #head
+
+
+def test_unet():
+
+    from ..models.unet import unet, UNet
+
+    model = unet(pretrained_backbone=True, progress=True)
+    nose.tools.eq_(len(model), 3)
+    nose.tools.eq_(type(model[0]), TorchVisionNormalizer)
+    nose.tools.eq_(type(model[1]), VGG4Segmentation)  #backbone
+    nose.tools.eq_(type(model[2]), UNet)  #head
+
+    model = unet(pretrained_backbone=False)
+    nose.tools.eq_(len(model), 2)
+    nose.tools.eq_(type(model[0]), VGG4Segmentation)  #backbone
+    nose.tools.eq_(type(model[1]), UNet)  #head
+
+
+def test_hed():
+
+    from ..models.hed import hed, HED
+
+    model = hed(pretrained_backbone=True, progress=True)
+    nose.tools.eq_(len(model), 3)
+    nose.tools.eq_(type(model[0]), TorchVisionNormalizer)
+    nose.tools.eq_(type(model[1]), VGG4Segmentation)  #backbone
+    nose.tools.eq_(type(model[2]), HED)  #head
+
+    model = hed(pretrained_backbone=False)
+    nose.tools.eq_(len(model), 2)
+    nose.tools.eq_(type(model[0]), VGG4Segmentation)  #backbone
+    nose.tools.eq_(type(model[1]), HED)  #head
+
+
+def test_m2unet():
+
+    from ..models.m2unet import m2unet, M2UNet
+    from ..models.backbones.mobilenetv2 import MobileNetV24Segmentation
+
+    model = m2unet(pretrained_backbone=True, progress=True)
+    nose.tools.eq_(len(model), 3)
+    nose.tools.eq_(type(model[0]), TorchVisionNormalizer)
+    nose.tools.eq_(type(model[1]), MobileNetV24Segmentation)  #backbone
+    nose.tools.eq_(type(model[2]), M2UNet)  #head
+
+    model = m2unet(pretrained_backbone=False)
+    nose.tools.eq_(len(model), 2)
+    nose.tools.eq_(type(model[0]), MobileNetV24Segmentation)  #backbone
+    nose.tools.eq_(type(model[1]), M2UNet)  #head
+
+
+def test_resunet50():
+
+    from ..models.resunet import resunet50, ResUNet
+    from ..models.backbones.resnet import ResNet4Segmentation
+
+    model = resunet50(pretrained_backbone=True, progress=True)
+    nose.tools.eq_(len(model), 3)
+    nose.tools.eq_(type(model[0]), TorchVisionNormalizer)
+    nose.tools.eq_(type(model[1]), ResNet4Segmentation)  #backbone
+    nose.tools.eq_(type(model[2]), ResUNet)  #head
+
+    model = resunet50(pretrained_backbone=False)
+    nose.tools.eq_(len(model), 2)
+    nose.tools.eq_(type(model[0]), ResNet4Segmentation)  #backbone
+    nose.tools.eq_(type(model[1]), ResUNet)  #head
diff --git a/bob/ip/binseg/utils/checkpointer.py b/bob/ip/binseg/utils/checkpointer.py
index 33c200ad..19090db7 100644
--- a/bob/ip/binseg/utils/checkpointer.py
+++ b/bob/ip/binseg/utils/checkpointer.py
@@ -1,43 +1,45 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import os
 
 import torch
-import os
 
 from .model_serialization import load_state_dict
-from .model_zoo import cache_url
 
 import logging
+
 logger = logging.getLogger(__name__)
 
 
 class Checkpointer:
-    """Adapted from `maskrcnn-benchmark
-    <https://github.com/facebookresearch/maskrcnn-benchmark>`_ under MIT license
+    """A simple pytorch checkpointer
+
+    Parameters
+    ----------
+
+    model : torch.nn.Module
+        Network model, eventually loaded from a checkpointed file
+
+    optimizer : :py:mod:`torch.optim`, Optional
+        Optimizer
+
+    scheduler : :py:mod:`torch.optim.lr_scheduler`, Optional
+        Learning rate scheduler
+
+    path : :py:class:`str`, Optional
+        Directory where to save checkpoints.
+
     """
 
-    def __init__(
-        self,
-        model,
-        optimizer=None,
-        scheduler=None,
-        save_dir="",
-        save_to_disk=None,
-    ):
+    def __init__(self, model, optimizer=None, scheduler=None, path="."):
+
         self.model = model
         self.optimizer = optimizer
         self.scheduler = scheduler
-        self.save_dir = save_dir
-        self.save_to_disk = save_to_disk
+        self.path = os.path.realpath(path)
 
     def save(self, name, **kwargs):
-        if not self.save_dir:
-            return
-
-        if not self.save_to_disk:
-            return
 
         data = {}
         data["model"] = self.model.state_dict()
@@ -47,85 +49,59 @@ class Checkpointer:
             data["scheduler"] = self.scheduler.state_dict()
         data.update(kwargs)
 
-        dest_filename = f"{name}.pth"
-        save_file = os.path.join(self.save_dir, dest_filename)
-        logger.info(f"Saving checkpoint to {save_file}")
-        torch.save(data, save_file)
-        self.tag_last_checkpoint(dest_filename)
+        name = f"{name}.pth"
+        outf = os.path.join(self.path, name)
+        logger.info(f"Saving checkpoint to {outf}")
+        torch.save(data, outf)
+        with open(self._last_checkpoint_filename, "w") as f:
+            f.write(name)
 
     def load(self, f=None):
-        if self.has_checkpoint():
-            # override argument with existing checkpoint
-            f = self.get_checkpoint_file()
-        if not f:
+        """Loads model, optimizer and scheduler from file
+
+
+        Parameters
+        ----------
+
+        f : :py:class:`str`, Optional
+            Name of a file (absolute or relative to ``self.path``) that
+            contains the checkpoint data to load into the model, and
+            optionally into the optimizer and the scheduler.  If not
+            specified, loads the last checkpoint saved in ``self.path``.
+
+        """
+
+        if f is None:
+            f = self.last_checkpoint()
+
+        if f is None:
             # no checkpoint could be found
-            logger.warn("No checkpoint found. Initializing model from scratch")
+            logger.warning("No checkpoint found (and none passed)")
             return {}
-        checkpoint = self._load_file(f)
-        self._load_model(checkpoint)
-        actual_file = os.path.join(self.save_dir, f)
-        if "optimizer" in checkpoint and self.optimizer:
-            logger.info(f"Loading optimizer from {actual_file}")
+
+        # loads file data into memory
+        logger.info(f"Loading checkpoint from {f}...")
+        checkpoint = torch.load(f, map_location=torch.device("cpu"))
+
+        # converts model entry to model parameters
+        load_state_dict(self.model, checkpoint.pop("model"))
+
+        if self.optimizer is not None:
             self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
-        if "scheduler" in checkpoint and self.scheduler:
-            logger.info(f"Loading scheduler from {actual_file}")
+        if self.scheduler is not None:
             self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
 
-        # return any further checkpoint data
         return checkpoint
 
-    def has_checkpoint(self):
-        save_file = os.path.join(self.save_dir, "last_checkpoint")
-        return os.path.exists(save_file)
-
-    def get_checkpoint_file(self):
-        save_file = os.path.join(self.save_dir, "last_checkpoint")
-        try:
-            with open(save_file, "r") as f:
-                last_saved = f.read()
-                last_saved = last_saved.strip()
-        except IOError:
-            # if file doesn't exist, maybe because it has just been
-            # deleted by a separate process
-            last_saved = ""
-        return last_saved
-
-    def tag_last_checkpoint(self, last_filename):
-        save_file = os.path.join(self.save_dir, "last_checkpoint")
-        with open(save_file, "w") as f:
-            f.write(last_filename)
-
-    def _load_file(self, f):
-        actual_file = os.path.join(self.save_dir, f)
-        logger.info(f"Loading checkpoint from {actual_file}")
-        return torch.load(actual_file, map_location=torch.device("cpu"))
-
-    def _load_model(self, checkpoint):
-        load_state_dict(self.model, checkpoint.pop("model"))
+    @property
+    def _last_checkpoint_filename(self):
+        return os.path.join(self.path, "last_checkpoint")
 
+    def has_checkpoint(self):
+        return os.path.exists(self._last_checkpoint_filename)
 
-class DetectronCheckpointer(Checkpointer):
-    def __init__(
-        self,
-        model,
-        optimizer=None,
-        scheduler=None,
-        save_dir="",
-        save_to_disk=None,
-    ):
-        super(DetectronCheckpointer, self).__init__(
-            model, optimizer, scheduler, save_dir, save_to_disk
-        )
-
-    def _load_file(self, f):
-        # download url files
-        if f.startswith("http"):
-            # if the file is a url path, download it and cache it
-            cached_f = cache_url(f)
-            logger.info(f"url {f} cached in {cached_f}")
-            f = cached_f
-        # load checkpoint
-        loaded = super(DetectronCheckpointer, self)._load_file(f)
-        if "model" not in loaded:
-            loaded = dict(model=loaded)
-        return loaded
+    def last_checkpoint(self):
+        if self.has_checkpoint():
+            with open(self._last_checkpoint_filename, "r") as fobj:
+                return os.path.join(self.path, fobj.read().strip())
+        return None
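The new class keeps the same on-disk protocol as before: ``save()`` writes ``<name>.pth`` plus a plain-text ``last_checkpoint`` file naming it, which ``load()`` follows when called with no argument.  A round-trip sketch (the directory name is illustrative):

    import os

    import torch

    from bob.ip.binseg.utils.checkpointer import Checkpointer

    path = "/tmp/ckpt-demo"  # hypothetical directory
    os.makedirs(path, exist_ok=True)

    net = torch.nn.Linear(4, 2)
    Checkpointer(net, path=path).save("model_final", epoch=10)
    # the directory now holds "model_final.pth" and a text file
    # "last_checkpoint" whose contents are "model_final.pth"

    extra = Checkpointer(torch.nn.Linear(4, 2), path=path).load()
    assert extra["epoch"] == 10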
diff --git a/bob/ip/binseg/utils/model_serialization.py b/bob/ip/binseg/utils/model_serialization.py
index 4c84e84f..d629eae1 100644
--- a/bob/ip/binseg/utils/model_serialization.py
+++ b/bob/ip/binseg/utils/model_serialization.py
@@ -4,6 +4,7 @@
 from collections import OrderedDict
 
 import logging
+
 logger = logging.getLogger(__name__)
 
 import torch
@@ -11,25 +12,31 @@ import torch
 
 def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
     """
-    Strategy: suppose that the models that we will create will have prefixes appended
-    to each of its keys, for example due to an extra level of nesting that the original
-    pre-trained weights from ImageNet won't contain. For example, model.state_dict()
-    might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
-    res2.conv1.weight. We thus want to match both parameters together.
-    For that, we look for each model weight, look among all loaded keys if there is one
-    that is a suffix of the current weight name, and use it if that's the case.
-    If multiple matches exist, take the one with longest size
-    of the corresponding name. For example, for the same model as before, the pretrained
-    weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
-    we want to match backbone[0].body.conv1.weight to conv1.weight, and
-    backbone[0].body.res2.conv1.weight to res2.conv1.weight.
+
+    Strategy: suppose that the models we create have prefixes appended to
+    each of their keys, for example due to an extra level of nesting that
+    the original pre-trained weights from ImageNet do not contain. For
+    example, model.state_dict() might return
+    backbone[0].body.res2.conv1.weight, while the pre-trained model contains
+    res2.conv1.weight. We thus want to match both parameters together.  For
+    that, we look for each model weight, look among all loaded keys if there is
+    one that is a suffix of the current weight name, and use it if that's the
+    case.  If multiple matches exist, take the one with longest size of the
+    corresponding name. For example, for the same model as before, the
+    pretrained weight file can contain both res2.conv1.weight, as well as
+    conv1.weight. In this case, we want to match backbone[0].body.conv1.weight
+    to conv1.weight, and backbone[0].body.res2.conv1.weight to
+    res2.conv1.weight.
     """
+
     current_keys = sorted(list(model_state_dict.keys()))
     loaded_keys = sorted(list(loaded_state_dict.keys()))
     # get a matrix of string matches, where each (i, j) entry correspond to the size of the
     # loaded_key string, if it matches
     match_matrix = [
-        len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys
+        len(j) if i.endswith(j) else 0
+        for i in current_keys
+        for j in loaded_keys
     ]
     match_matrix = torch.as_tensor(match_matrix).view(
         len(current_keys), len(loaded_keys)
@@ -40,7 +47,9 @@ def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
 
     # used for logging
     max_size = max([len(key) for key in current_keys]) if current_keys else 1
-    max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1
+    max_size_loaded = (
+        max([len(key) for key in loaded_keys]) if loaded_keys else 1
+    )
     log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
     for idx_new, idx_old in enumerate(idxs.tolist()):
         if idx_old == -1:
@@ -74,7 +83,9 @@ def load_state_dict(model, loaded_state_dict):
     # if the state_dict comes from a model that was wrapped in a
     # DataParallel or DistributedDataParallel during serialization,
     # remove the "module" prefix before performing the matching
-    loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.")
+    loaded_state_dict = strip_prefix_if_present(
+        loaded_state_dict, prefix="module."
+    )
     align_and_update_state_dicts(model_state_dict, loaded_state_dict)
 
     # use strict loading
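The suffix-matching rule in ``align_and_update_state_dicts`` can be illustrated in isolation: for each model key, among all loaded keys that are suffixes of it, the longest wins.  A self-contained sketch (key names are made up for illustration):

    def best_match(key, candidates):
        # keep only loaded keys that are suffixes of the model key,
        # then prefer the longest (most specific) match
        matches = [c for c in candidates if key.endswith(c)]
        return max(matches, key=len) if matches else None

    loaded = ["conv1.weight", "res2.conv1.weight"]
    assert best_match("backbone.body.conv1.weight", loaded) == "conv1.weight"
    assert (best_match("backbone.body.res2.conv1.weight", loaded)
            == "res2.conv1.weight")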
diff --git a/bob/ip/binseg/utils/model_zoo.py b/bob/ip/binseg/utils/model_zoo.py
deleted file mode 100644
index 2eb98f55..00000000
--- a/bob/ip/binseg/utils/model_zoo.py
+++ /dev/null
@@ -1,116 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-
-# Adpated from:
-# https://github.com/pytorch/pytorch/blob/master/torch/hub.py
-# https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/checkpoint.py
-
-import hashlib
-import os
-import re
-import shutil
-import sys
-import tempfile
-from urllib.request import urlopen
-from urllib.parse import urlparse
-from tqdm import tqdm
-
-modelurls = {
-    "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
-    "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
-    "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
-    "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
-    "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
-    "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
-    "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
-    "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
-    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
-    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
-    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
-    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
-    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
-    #"resnet50_SIN_IN": "https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar",
-    "resnet50_SIN_IN": "http://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar",
-    #"mobilenetv2": "https://dl.dropboxusercontent.com/s/4nie4ygivq04p8y/mobilenet_v2.pth.tar",
-    "mobilenetv2": "http://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/mobilenet_v2.pth.tar",
-}
-"""URLs of pre-trained models (backbones)"""
-
-
-def download_url_to_file(url, dst, hash_prefix, progress):
-    file_size = None
-    u = urlopen(url)
-    meta = u.info()
-    if hasattr(meta, "getheaders"):
-        content_length = meta.getheaders("Content-Length")
-    else:
-        content_length = meta.get_all("Content-Length")
-    if content_length is not None and len(content_length) > 0:
-        file_size = int(content_length[0])
-
-    f = tempfile.NamedTemporaryFile(delete=False)
-    try:
-        if hash_prefix is not None:
-            sha256 = hashlib.sha256()
-        with tqdm(total=file_size, disable=not progress) as pbar:
-            while True:
-                buffer = u.read(8192)
-                if len(buffer) == 0:
-                    break
-                f.write(buffer)
-                if hash_prefix is not None:
-                    sha256.update(buffer)
-                pbar.update(len(buffer))
-
-        f.close()
-        if hash_prefix is not None:
-            digest = sha256.hexdigest()
-            if digest[: len(hash_prefix)] != hash_prefix:
-                raise RuntimeError(
-                    'invalid hash value (expected "{}", got "{}")'.format(
-                        hash_prefix, digest
-                    )
-                )
-        shutil.move(f.name, dst)
-    finally:
-        f.close()
-        if os.path.exists(f.name):
-            os.remove(f.name)
-
-
-HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
-
-
-def cache_url(url, model_dir=None, progress=True):
-    r"""Loads the Torch serialized object at the given URL.
-    If the object is already present in `model_dir`, it's deserialized and
-    returned. The filename part of the URL should follow the naming convention
-    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
-    digits of the SHA256 hash of the contents of the file. The hash is used to
-    ensure unique names and to verify the contents of the file.
-    The default value of `model_dir` is ``$TORCH_HOME/models`` where
-    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
-    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.
-    Args:
-        url (string): URL of the object to download
-        model_dir (string, optional): directory in which to save the object
-        progress (bool, optional): whether or not to display a progress bar to stderr
-
-    """
-    if model_dir is None:
-        torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch"))
-        model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models"))
-    if not os.path.exists(model_dir):
-        os.makedirs(model_dir)
-    parts = urlparse(url)
-    filename = os.path.basename(parts.path)
-
-    cached_file = os.path.join(model_dir, filename)
-    if not os.path.exists(cached_file):
-        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
-        hash_prefix = HASH_REGEX.search(filename)
-        if hash_prefix is not None:
-            hash_prefix = hash_prefix.group(1)
-        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
-
-    return cached_file
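The hand-rolled downloader above becomes redundant because the torchvision constructors used by the new backbones fetch and cache their own ImageNet weights through ``torch.hub``.  A sketch of the equivalent primitive, assuming a network connection (the cache location defaults to ``$TORCH_HOME``; the URL is taken from the deleted table above):

    from torch.hub import load_state_dict_from_url

    url = "https://download.pytorch.org/models/vgg16-397923af.pth"
    state_dict = load_state_dict_from_url(url, progress=True)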
diff --git a/conda/meta.yaml b/conda/meta.yaml
index 60430acf..5756384d 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -27,14 +27,14 @@ requirements:
     - setuptools {{ setuptools }}
     - numpy {{ numpy }}
     - h5py {{ h5py }}
-    - pytorch {{ pytorch }} # [linux]
+    - pytorch {{ pytorch }}
     - torchvision  {{ torchvision }} # [linux]
     - bob.extension
   run:
     - python
     - setuptools
     - {{ pin_compatible('numpy') }}
-    - {{ pin_compatible('pytorch') }} # [linux]
+    - {{ pin_compatible('pytorch') }}
     - {{ pin_compatible('torchvision') }} # [linux]
     - matplotlib
     - pandas
diff --git a/doc/api.rst b/doc/api.rst
index f1e5cd3c..b6a288f0 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -57,23 +57,24 @@ Neural Network Models
 ---------------------
 
 .. autosummary::
-   :toctree: api/modeling
-
-   bob.ip.binseg.modeling
-   bob.ip.binseg.modeling.backbones
-   bob.ip.binseg.modeling.backbones.mobilenetv2
-   bob.ip.binseg.modeling.backbones.resnet
-   bob.ip.binseg.modeling.backbones.vgg
-   bob.ip.binseg.modeling.driu
-   bob.ip.binseg.modeling.driubn
-   bob.ip.binseg.modeling.driuod
-   bob.ip.binseg.modeling.driupix
-   bob.ip.binseg.modeling.hed
-   bob.ip.binseg.modeling.losses
-   bob.ip.binseg.modeling.m2u
-   bob.ip.binseg.modeling.make_layers
-   bob.ip.binseg.modeling.resunet
-   bob.ip.binseg.modeling.unet
+   :toctree: api/models
+
+   bob.ip.binseg.models
+   bob.ip.binseg.models.backbones
+   bob.ip.binseg.models.backbones.mobilenetv2
+   bob.ip.binseg.models.backbones.resnet
+   bob.ip.binseg.models.backbones.vgg
+   bob.ip.binseg.models.normalizer
+   bob.ip.binseg.models.driu
+   bob.ip.binseg.models.driu_bn
+   bob.ip.binseg.models.driu_od
+   bob.ip.binseg.models.driu_pix
+   bob.ip.binseg.models.hed
+   bob.ip.binseg.models.m2unet
+   bob.ip.binseg.models.resunet
+   bob.ip.binseg.models.unet
+   bob.ip.binseg.models.losses
+   bob.ip.binseg.models.make_layers
 
 
 Toolbox
@@ -86,7 +87,6 @@ Toolbox
    bob.ip.binseg.utils.checkpointer
    bob.ip.binseg.utils.measure
    bob.ip.binseg.utils.model_serialization
-   bob.ip.binseg.utils.model_zoo
    bob.ip.binseg.utils.plot
    bob.ip.binseg.utils.table
    bob.ip.binseg.utils.summary
diff --git a/doc/extras.inv b/doc/extras.inv
index 55baaba61aa2392fa690178e23d9dbea964b8fc4..f02c5b9fd164f673297c60f44cb6f75434a83d82 100644
GIT binary patch
delta 461
zcmZo<IlwZZwm$#1g23_npEZ?!T5c%gpRDCuePWB><XgU>Q_7px8%jSDDDRwB_p7lm
z_Pno3`s&XK7Ro>V|2uH=xI~HSh0QCAv*#&vygz^Oj(pyhP?yiEE04Z;>E>1cV2?&l
zYUeDKjkh8uL@s%BOYBZ??Zv1wpLP^K;yXRJwfM^8MJj(+8r?p1tiIj;TCbbol{bN(
zc(Q(bJCx_>bOvqAl425@UtG_(Ez+K0d(lztbDRGqZ7Xh*>}c&{P7+<5%G1{<?Bi0^
zx~_5m^lLdyT<kY9c|K(7oD{p;;#Y5Y`Pz$kr{sob+b+r`Kl6-><o2<9;D5oVFi`j4
zN&ZLML++hWI-eDKEBLN-SAKm{bJ3TMh0E%8xO6{^{!|jS>03*x|6A4Sc<0|s4A03;
z(cW0+5VpMU3GZ3ngvWo^DsH%Vs&3=4_#=<w4}7%SpJd?w;zyGo`-#Gv5!EU)Y=n(d
zdlrkXJ0R%yx3FYJpgETa_n|B0GZd5m{0Vq^BuywW<X=Pg{^_X)RvPbNw|c5=>2~7j
zyZ<^o41q#9t4>NMA3VxbdEMpvf!D9=@0P5-nX^MU?#=rC#moVZj<M{tU#~acY1XG#
ZcMGr7?G}A~`}X$6C2#*QDr>Bo1pvf2?v4Ne

delta 398
zcmX@W(#SHQw%+@&fxz+m-!<84<C=ab@EmFH>`*u*u`fL{W^tz8le|aYUi!{nxB4!d
z<Skp1oo&1SzxOZ7<=CSmVR0!{=4)8xw@uy0@{FgtDfF#B#yat^=JPz~jVIPu@CjK?
zijX?G*u!z@$rX!tU7ywSXxG*^lApE)K3~zk$m`d-3#$*cwCA0Us#kuUU8Z5kJo~56
z1=GxC9dVUcjG`NVTg-1>`?~spo1XQ9H(%5jZ@ucm)p>Hm0Wrl{tq&x2G;MUyKXC0s
zcHWmu-1{G{b>oj;8@efS_JvpeZkIBCo>Zv_oo{v{`OJ(bErwkeDt`n_G7y;>6Ipk7
z-_C+r>4vKVGt0huCs!mtX+2tA|3S6vu%yHX|EbG2&)TYd#yacwJKL5rPt9}B4m}Gv
z*LY;5?oan(+lXEJ&7XRDC(8AdiGTdTJ>%VH!9DkymU3IYRlK}wV@X5S2fv;vBJSJ~
z8r=q8j$cfg<QI~-L!iy>Fo);I&sv`v&xW)t`4?#RCoWB{?cE2y7gN@sa@_OoUz#F2
LLmh(;kD&knyv@vj

diff --git a/doc/extras.txt b/doc/extras.txt
index 4bd227b7..ee07f144 100644
--- a/doc/extras.txt
+++ b/doc/extras.txt
@@ -18,3 +18,6 @@ torchvision.transforms.transforms.Resize py:class 1 https://pytorch.org/docs/sta
 torchvision.transforms.transforms.Pad py:class 1 https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Pad -
 torchvision.transforms.transforms.CenterCrop py:class 1 https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.CenterCrop -
 torchvision.transforms py:module 1 https://pytorch.org/docs/stable/torchvision/transforms.html -
+torchvision.models.mobilenet.MobileNetV2 py:class 1 https://pytorch.org/docs/stable/torchvision/models.html#mobilenet-v2 -
+torchvision.models.resnet.ResNet py:class 1 https://pytorch.org/docs/stable/torchvision/models.html#id10 -
+torchvision.models.vgg.VGG py:class 1 https://pytorch.org/docs/stable/torchvision/models.html#id2 -
-- 
GitLab