diff --git a/bob/ip/binseg/configs/models/driu.py b/bob/ip/binseg/configs/models/driu.py
index 37911ed7837dd05f8e61906878782f07ad6c325c..deb789d00dd7a4c44124713ac6f8dc59f3b5bbba 100644
--- a/bob/ip/binseg/configs/models/driu.py
+++ b/bob/ip/binseg/configs/models/driu.py
@@ -19,7 +19,7 @@ gamma = 1e-3
 eps = 1e-8
 amsbound = False
 
-scheduler_milestones = [150]
+scheduler_milestones = [200]
 scheduler_gamma = 0.1
 
 # model
diff --git a/bob/ip/binseg/configs/models/driuj01.py b/bob/ip/binseg/configs/models/driuj01.py
new file mode 100644
index 0000000000000000000000000000000000000000..6da8443f729c594186c120d4508255043a1b762e
--- /dev/null
+++ b/bob/ip/binseg/configs/models/driuj01.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from torch.optim.lr_scheduler import MultiStepLR
+from bob.ip.binseg.modeling.driu import build_driu
+import torch.optim as optim
+from torch.nn import BCEWithLogitsLoss
+from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
+from bob.ip.binseg.engine.adabound import AdaBound
+
+##### Config #####
+lr = 0.001
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+eps = 1e-8
+amsbound = False
+
+scheduler_milestones = [200]
+scheduler_gamma = 0.1
+
+# model
+model = build_driu()
+
+# pretrained backbone
+pretrained_backbone = modelurls['vgg16']
+
+# optimizer
+optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
+                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
+    
+# criterion
+criterion = SoftJaccardBCELogitsLoss(alpha=0.1)
+
+# scheduler
+scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
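
The new ``*j01`` configurations differ from their base configurations only in the criterion (``SoftJaccardBCELogitsLoss(alpha=0.1)``) and the later scheduler milestone. Below is a minimal sketch of how the objects exported by such a config module (``model``, ``optimizer``, ``criterion``, ``scheduler``) are wired into a training loop, assuming a DataLoader that yields ``(images, ground_truths, masks, names)`` batches as in ``bob/ip/binseg/engine/trainer.py``; ``my_dataset`` and ``num_epochs`` are placeholders:

    import torch
    from torch.utils.data import DataLoader
    from bob.ip.binseg.configs.models import driuj01 as config

    device = "cuda" if torch.cuda.is_available() else "cpu"
    config.model.to(device)
    data_loader = DataLoader(my_dataset, batch_size=2, shuffle=True)  # my_dataset: placeholder
    num_epochs = 600  # illustrative

    for epoch in range(num_epochs):
        for images, ground_truths, masks, names in data_loader:
            images = images.to(device)
            ground_truths = ground_truths.to(device)
            outputs = config.model(images)
            loss = config.criterion(outputs, ground_truths)
            config.optimizer.zero_grad()
            loss.backward()
            config.optimizer.step()
        config.scheduler.step()  # MultiStepLR: lr is multiplied by 0.1 at epoch 200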
diff --git a/bob/ip/binseg/configs/models/hed.py b/bob/ip/binseg/configs/models/hed.py
index ed79474505d02f453e789c0263d0cde2cebae396..a586e86e44f0048c7ce28df00f882eadcd98a2c9 100644
--- a/bob/ip/binseg/configs/models/hed.py
+++ b/bob/ip/binseg/configs/models/hed.py
@@ -19,7 +19,7 @@ gamma = 1e-3
 eps = 1e-8
 amsbound = False
 
-scheduler_milestones = [150]
+scheduler_milestones = [200]
 scheduler_gamma = 0.1
 
 
diff --git a/bob/ip/binseg/configs/models/m2unet.py b/bob/ip/binseg/configs/models/m2unet.py
index 471ce372bf93b8b7302af0370731e75ccea18c14..e97ae62dba55805bae1d2130075d754098b9cb9d 100644
--- a/bob/ip/binseg/configs/models/m2unet.py
+++ b/bob/ip/binseg/configs/models/m2unet.py
@@ -19,7 +19,7 @@ gamma = 1e-3
 eps = 1e-8
 amsbound = False
 
-scheduler_milestones = [150]
+scheduler_milestones = [200]
 scheduler_gamma = 0.1
 
 # model
diff --git a/bob/ip/binseg/configs/models/m2unetj01.py b/bob/ip/binseg/configs/models/m2unetj01.py
new file mode 100644
index 0000000000000000000000000000000000000000..d17fc4b1fdf359026d47f2284d2d411e948ada15
--- /dev/null
+++ b/bob/ip/binseg/configs/models/m2unetj01.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from torch.optim.lr_scheduler import MultiStepLR
+from bob.ip.binseg.modeling.m2u import build_m2unet
+import torch.optim as optim
+from torch.nn import BCEWithLogitsLoss
+from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
+from bob.ip.binseg.engine.adabound import AdaBound
+
+##### Config #####
+lr = 0.001
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+eps = 1e-8
+amsbound = False
+
+scheduler_milestones = [200]
+scheduler_gamma = 0.1
+
+# model
+model = build_m2unet()
+
+# pretrained backbone
+pretrained_backbone = modelurls['mobilenetv2']
+
+# optimizer
+optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
+                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
+    
+# criterion
+criterion = SoftJaccardBCELogitsLoss(alpha=0.1)
+
+# scheduler
+scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
diff --git a/bob/ip/binseg/configs/models/resunet.py b/bob/ip/binseg/configs/models/resunet.py
index 0725443096d9970a5c8b835bcf051214cec99ecc..17184497d375a8750b933c1ff5ebbc7dba3ec3bf 100644
--- a/bob/ip/binseg/configs/models/resunet.py
+++ b/bob/ip/binseg/configs/models/resunet.py
@@ -19,7 +19,7 @@ gamma = 1e-3
 eps = 1e-8
 amsbound = False
 
-scheduler_milestones = [150]
+scheduler_milestones = [200]
 scheduler_gamma = 0.1
 
 # model
diff --git a/bob/ip/binseg/configs/models/resunetj01.py b/bob/ip/binseg/configs/models/resunetj01.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9c5093ff80a49bb9540d4c41d668acdceb3f83c
--- /dev/null
+++ b/bob/ip/binseg/configs/models/resunetj01.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from torch.optim.lr_scheduler import MultiStepLR
+from bob.ip.binseg.modeling.resunet import build_res50unet
+import torch.optim as optim
+from torch.nn import BCEWithLogitsLoss
+from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
+from bob.ip.binseg.engine.adabound import AdaBound
+
+##### Config #####
+lr = 0.001
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+eps = 1e-8
+amsbound = False
+
+scheduler_milestones = [200]
+scheduler_gamma = 0.1
+
+# model
+model = build_res50unet()
+
+# pretrained backbone
+pretrained_backbone = modelurls['resnet50']
+
+# optimizer
+optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
+                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
+    
+# criterion
+criterion = SoftJaccardBCELogitsLoss(alpha=0.1)
+
+# scheduler
+scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
diff --git a/bob/ip/binseg/configs/models/unet.py b/bob/ip/binseg/configs/models/unet.py
index ccc0f2c554b72eaf4f16b0a137998fbb0467c6c1..f034d94db6420a33a29a54c78293c3e9f8eff7cf 100644
--- a/bob/ip/binseg/configs/models/unet.py
+++ b/bob/ip/binseg/configs/models/unet.py
@@ -19,7 +19,7 @@ gamma = 1e-3
 eps = 1e-8
 amsbound = False
 
-scheduler_milestones = [150]
+scheduler_milestones = [200]
 scheduler_gamma = 0.1
 
 # model
diff --git a/bob/ip/binseg/configs/models/unetj01.py b/bob/ip/binseg/configs/models/unetj01.py
new file mode 100644
index 0000000000000000000000000000000000000000..da1096cd34bf8f3319d90f3bf8992b4235ae9919
--- /dev/null
+++ b/bob/ip/binseg/configs/models/unetj01.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from torch.optim.lr_scheduler import MultiStepLR
+from bob.ip.binseg.modeling.unet import build_unet
+import torch.optim as optim
+from torch.nn import BCEWithLogitsLoss
+from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
+from bob.ip.binseg.engine.adabound import AdaBound
+
+##### Config #####
+lr = 0.001
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1
+gamma = 1e-3
+eps = 1e-8
+amsbound = False
+
+scheduler_milestones = [200]
+scheduler_gamma = 0.1
+
+# model
+model = build_unet()
+
+# pretrained backbone
+pretrained_backbone = modelurls['vgg16']
+
+# optimizer
+optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
+                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
+    
+# criterion
+criterion = SoftJaccardBCELogitsLoss(alpha=0.1)
+
+# scheduler
+scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
diff --git a/bob/ip/binseg/data/__init__.py b/bob/ip/binseg/data/__init__.py
index ecae81711bc94c42131a78d4298cc508ca07c78d..d776f7534f77d642a336adab27172f4096e3b023 100644
--- a/bob/ip/binseg/data/__init__.py
+++ b/bob/ip/binseg/data/__init__.py
@@ -1 +1,4 @@
+# see https://docs.python.org/3/library/pkgutil.html
+from pkgutil import extend_path
+__path__ = extend_path(__path__, __name__)
 from .binsegdataset import BinSegDataset
\ No newline at end of file
diff --git a/bob/ip/binseg/data/binsegdataset.py b/bob/ip/binseg/data/binsegdataset.py
index 27853d14cb4b3630933106cf781d16f62087cc90..b75a7dde55eaec9a2d94b5ba712e5154ba29b6d9 100644
--- a/bob/ip/binseg/data/binsegdataset.py
+++ b/bob/ip/binseg/data/binsegdataset.py
@@ -1,38 +1,46 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-
 from torch.utils.data import Dataset
 
 class BinSegDataset(Dataset):
-    """
-    PyTorch dataset wrapper around bob.db binary segmentation datasets. 
+    """PyTorch dataset wrapper around bob.db binary segmentation datasets. 
     A transform object can be passed that will be applied to the image, ground truth and mask (if present). 
+    It supports indexing such that dataset[i] can be used to get the i-th sample.
     
-    It supports indexing such that dataset[i] can be used to get ith sample, e.g.: 
-    img, gt, mask, name = db[0]
-    
-    Parameters
-    ----------
-    database  : binary segmentation `bob.db.database`
-               
-    split     : str
-                    train' or 'test'
-
-    transform : :py:class:`bob.ip.binseg.data.transforms.Compose`
- 
+    Attributes
+    ---------- 
+    bobdb : :py:mod:`bob.db.base`
+        Binary segmentation bob database (e.g. bob.db.drive) 
+    split : str 
+        ``'train'`` or ``'test'``. Defaults to ``'train'``
+    transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
+        A transform or composition of transforms. Defaults to ``None``.
     """
-    def __init__(self, bobdb, split = None, transform = None):
+    def __init__(self, bobdb, split = 'train', transform = None):
         self.database = bobdb.samples(split)
         self.transform = transform
         self.split = split
     
     def __len__(self):
         """
-        Returns the size of the dataset
+        Returns
+        -------
+        int
+            size of the dataset
         """
         return len(self.database)
     
     def __getitem__(self,index):
+        """
+        Parameters
+        ----------
+        index : int
+        
+        Returns
+        -------
+        list
+            dataitem [img, gt, mask, img_name]
+        """
         img = self.database[index].img.pil_image()
         gt = self.database[index].gt.pil_image()
         mask = self.database[index].mask.pil_image() if hasattr(self.database[index], 'mask') else None
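diff note: a minimal usage sketch for ``BinSegDataset``; the wrapped database must expose the ``samples(split)`` interface used in ``__init__`` (``bob.db.drive`` is used here purely as an illustration, matching the docstring example), and the transform is optional:

    import bob.db.drive
    from bob.ip.binseg.data.binsegdataset import BinSegDataset
    from bob.ip.binseg.data.transforms import Compose, RandomHFlip, ToTensor

    transform = Compose([RandomHFlip(prob=0.5), ToTensor()])
    dataset = BinSegDataset(bob.db.drive.Database(), split='train', transform=transform)

    print(len(dataset))               # number of samples in the training split
    img, gt, mask, name = dataset[0]  # dataitem layout as documented in __getitem__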
diff --git a/bob/ip/binseg/data/transforms.py b/bob/ip/binseg/data/transforms.py
index aa20d5d4e4cf7154f59ee13c63b6f730e653b653..34b95ca7b8900cffd9fd6b45807c51f1dda06861 100644
--- a/bob/ip/binseg/data/transforms.py
+++ b/bob/ip/binseg/data/transforms.py
@@ -3,21 +3,33 @@
 
 import torchvision.transforms.functional as VF
 import random
+import PIL
 from PIL import Image
 from torchvision.transforms.transforms import Lambda
 from torchvision.transforms.transforms import Compose as TorchVisionCompose
+import math
+from math import floor
+import warnings
+
+
+_pil_interpolation_to_str = {
+    Image.NEAREST: 'PIL.Image.NEAREST',
+    Image.BILINEAR: 'PIL.Image.BILINEAR',
+    Image.BICUBIC: 'PIL.Image.BICUBIC',
+    Image.LANCZOS: 'PIL.Image.LANCZOS',
+    Image.HAMMING: 'PIL.Image.HAMMING',
+    Image.BOX: 'PIL.Image.BOX',
+}
 
 # Compose 
 
-class Compose(object):
-    """Composes several transforms together.
-    Args:
-        transforms (list of ``Transform`` objects): list of transforms to compose.
-    Example:
-        >>> transforms.Compose([
-        >>>     transforms.CenterCrop(10),
-        >>>     transforms.ToTensor(),
-        >>> ])
+class Compose:
+    """Composes several transforms.
+
+    Attributes
+    ----------
+    transforms : list
+        list of transforms to compose.
     """
 
     def __init__(self, transforms):
@@ -40,13 +52,12 @@ class Compose(object):
 
 class CenterCrop:
     """
-    Crops the given PIL images the center.
-    
+    Crop at the center.
+
     Attributes
     ----------
-    size: (sequence or int)
-        Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is made.
-    
+    size : int
+        target size
     """
     def __init__(self, size):
         self.size = size
@@ -57,14 +68,18 @@ class CenterCrop:
 
 class Crop:
     """
-    Crop the given PIL Image ground_truth at the given coordinates.
+    Crop at the given coordinates.
+    
     Attributes
     ----------
-        img (PIL Image): Image to be cropped.
-        i: Upper pixel coordinate.
-        j: Left pixel coordinate.
-        h: Height of the cropped image.
-        w: Width of the cropped image.
+    i : int 
+        upper pixel coordinate.
+    j : int 
+        left pixel coordinate.
+    h : int 
+        height of the cropped image.
+    w : int 
+        width of the cropped image.
     """
     def __init__(self, i, j, h, w):
         self.i = i
@@ -77,18 +92,18 @@ class Crop:
 
 class Pad:
     """
+    Constant padding
+
     Attributes
     ----------
-    
     padding : int or tuple 
-        Padding on each border. If a single int is provided this is used to pad all borders. 
+        padding on each border. If a single int is provided this is used to pad all borders. 
         If tuple of length 2 is provided this is the padding on left/right and top/bottom respectively.
         If a tuple of length 4 is provided this is the padding for the left, top, right and bottom borders respectively.
     
     fill : int
-        Pixel fill value for constant fill. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels 
-        respectively. This value is only used when the padding_mode is constant
-        
+        pixel fill value for constant fill. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. 
+        This value is only used when the padding_mode is constant   
     """
     def __init__(self, padding, fill=0):
         self.padding = padding
@@ -98,6 +113,7 @@ class Pad:
         return [VF.pad(img, self.padding, self.fill, padding_mode='constant') for img in args]
     
 class ToTensor:
+    """Converts PIL.Image to torch.tensor """
     def __call__(self, *args):
         return [VF.to_tensor(img) for img in args]
 
@@ -106,12 +122,12 @@ class ToTensor:
 
 class RandomHFlip:
     """
-    Flips the given PIL image and ground truth horizontally
+    Flips horizontally
+    
     Attributes
     ----------
-    
-    prob : float 
-        probability at which imgage is flipped. Default: 0.5
+    prob : float
+        probability at which image is flipped. Defaults to ``0.5``
     """
     def __init__(self, prob = 0.5):
         self.prob = prob
@@ -126,12 +142,12 @@ class RandomHFlip:
     
 class RandomVFlip:
     """
-    Flips the given PIL image and ground truth vertically
+    Flips vertically
+    
     Attributes
     ----------
-    
     prob : float 
-        probability at which imgage is flipped. Default: 0.5
+        probability at which image is flipped. Defaults to ``0.5``
     """
     def __init__(self, prob = 0.5):
         self.prob = prob
@@ -146,14 +162,14 @@ class RandomVFlip:
 
 class RandomRotation:
     """
-    Rotates the given PIL image and ground truth vertically
+    Rotates by a random angle within the given degree range
+    
     Attributes
     ----------
-    
-    prob : float 
-        probability at which imgage is rotated. Default: 0.5
     degree_range : tuple
-        range of degrees in which image and ground truth are rotated. Default: (-15, +15) 
+        range of degrees in which image and ground truth are rotated. Defaults to ``(-15, +15)``
+    prob : float 
+        probability at which image is rotated. Defaults to ``0.5``
     """
     def __init__(self, degree_range = (-15, +15), prob = 0.5):
         self.prob = prob
@@ -168,26 +184,26 @@ class RandomRotation:
 
 class ColorJitter(object):
     """ 
-    Randomly change the brightness, contrast and saturation of an image.
+    Randomly change the brightness, contrast, saturation and hue
     
-    Parameters
+    Attributes
     ----------
-
-        brightness : float
-                        How much to jitter brightness. brightness_factor
-                        is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
-        contrast : float
-                        How much to jitter contrast. contrast_factor
-                        is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
-        saturation : float
-                        How much to jitter saturation. saturation_factor
-                        is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
-        hue : float
-                How much to jitter hue. hue_factor is chosen uniformly from
-                [-hue, hue]. Should be >=0 and <= 0.5.
-
-    """
-    def __init__(self,prob=0.5, brightness=0.3, contrast=0.3, saturation=0.02, hue=0.02):
+    brightness : float 
+        how much to jitter brightness. brightness_factor
+        is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
+    contrast : float
+        how much to jitter contrast. contrast_factor
+        is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
+    saturation : float 
+        how much to jitter saturation. saturation_factor
+        is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
+    hue : float 
+        how much to jitter hue. hue_factor is chosen uniformly from
+        [-hue, hue]. Should be >=0 and <= 0.5
+    prob : float
+        probability at which the operation is applied
+    """
+    def __init__(self, brightness=0.3, contrast=0.3, saturation=0.02, hue=0.02, prob=0.5):
         self.brightness = brightness
         self.contrast = contrast
         self.saturation = saturation
@@ -224,5 +240,204 @@ class ColorJitter(object):
                                         self.saturation, self.hue)
             trans_img = transform(args[0])
             return [trans_img, *args[1:]]
+        else:
+            return args
+
+
+class RandomResizedCrop:
+    """Crop to random size and aspect ratio.
+    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+    is finally resized to given size.
+    This is popularly used to train the Inception networks.
+    
+    Attributes
+    ----------
+    size : int 
+        expected output size of each edge
+    scale : tuple 
+        range of size of the origin size cropped. Defaults to ``(0.08, 1.0)``
+    ratio : tuple
+        range of aspect ratio of the origin aspect ratio cropped. Defaults to ``(3. / 4., 4. / 3.)``
+    interpolation :
+        Defaults to ``PIL.Image.BILINEAR``
+    prob : float 
+        probability at which the operation is applied. Defaults to ``0.5``
+    """
+
+    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR, prob = 0.5):
+        if isinstance(size, tuple):
+            self.size = size
+        else:
+            self.size = (size, size)
+        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+            warnings.warn("range should be of kind (min, max)")
+
+        self.interpolation = interpolation
+        self.scale = scale
+        self.ratio = ratio
+        self.prob = prob
+
+    @staticmethod
+    def get_params(img, scale, ratio):
+        area = img.size[0] * img.size[1]
+
+        for attempt in range(10):
+            target_area = random.uniform(*scale) * area
+            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+            aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if w <= img.size[0] and h <= img.size[1]:
+                i = random.randint(0, img.size[1] - h)
+                j = random.randint(0, img.size[0] - w)
+                return i, j, h, w
+
+        # Fallback to central crop
+        in_ratio = img.size[0] / img.size[1]
+        if (in_ratio < min(ratio)):
+            w = img.size[0]
+            h = w / min(ratio)
+        elif (in_ratio > max(ratio)):
+            h = img.size[1]
+            w = h * max(ratio)
+        else:  # whole image
+            w = img.size[0]
+            h = img.size[1]
+        i = (img.size[1] - h) // 2
+        j = (img.size[0] - w) // 2
+        return i, j, h, w
+
+    def __call__(self, *args):
+        if random.random() < self.prob:
+            imgs = []
+            i, j, h, w = self.get_params(args[0], self.scale, self.ratio)
+            for img in args:
+                img = VF.resized_crop(img, i, j, h, w, self.size, self.interpolation)
+                imgs.append(img)
+            return imgs
+        else:
+            return args
+
+    def __repr__(self):
+        interpolate_str = _pil_interpolation_to_str[self.interpolation]
+        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
+        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
+        format_string += ', interpolation={0})'.format(interpolate_str)
+        return format_string
+
+
+class Distortion:
+    """ 
+    Applies random elastic distortion to a PIL Image, adapted from https://github.com/mdbloice/Augmentor/blob/master/Augmentor/Operations.py : 
+    As well as the probability, the granularity of the distortions
+    produced by this class can be controlled using the width and
+    height of the overlaying distortion grid. The larger the height
+    and width of the grid, the smaller the distortions. This means
+    that larger grid sizes can result in finer, less severe distortions.
+    As well as this, the magnitude of the distortion vectors can
+    also be adjusted.
+
+    Attributes
+    ----------
+    grid_width : int 
+        the width of the grid overlay, which is used by the class to apply the transformations to the image. Defaults to ``8``
+    grid_height : int 
+        the height of the grid overlay, which is used by the class to apply the transformations to the image. Defaults to ``8``
+    magnitude : int 
+        controls the degree to which each distortion is applied to the overlaying distortion grid. Defaults to ``1``
+
+    prob : float
+        probability that the operation is performed. Defaults to ``0.5``
+    """
+    def __init__(self,grid_width=8, grid_height=8, magnitude=1, prob=0.5):
+        self.grid_width = grid_width
+        self.grid_height = grid_height
+        self.magnitude = magnitude
+        self.prob = prob 
+        
+    def _generatemesh(self, image):
+        w, h = image.size
+        horizontal_tiles = self.grid_width
+        vertical_tiles = self.grid_height
+        width_of_square = int(floor(w / float(horizontal_tiles)))
+        height_of_square = int(floor(h / float(vertical_tiles)))
+        width_of_last_square = w - (width_of_square * (horizontal_tiles - 1))
+        height_of_last_square = h - (height_of_square * (vertical_tiles - 1))
+        dimensions = []
+        for vertical_tile in range(vertical_tiles):
+            for horizontal_tile in range(horizontal_tiles):
+                if vertical_tile == (vertical_tiles - 1) and horizontal_tile == (horizontal_tiles - 1):
+                    dimensions.append([horizontal_tile * width_of_square,
+                                       vertical_tile * height_of_square,
+                                       width_of_last_square + (horizontal_tile * width_of_square),
+                                       height_of_last_square + (height_of_square * vertical_tile)])
+                elif vertical_tile == (vertical_tiles - 1):
+                    dimensions.append([horizontal_tile * width_of_square,
+                                       vertical_tile * height_of_square,
+                                       width_of_square + (horizontal_tile * width_of_square),
+                                       height_of_last_square + (height_of_square * vertical_tile)])
+                elif horizontal_tile == (horizontal_tiles - 1):
+                    dimensions.append([horizontal_tile * width_of_square,
+                                       vertical_tile * height_of_square,
+                                       width_of_last_square + (horizontal_tile * width_of_square),
+                                       height_of_square + (height_of_square * vertical_tile)])
+                else:
+                    dimensions.append([horizontal_tile * width_of_square,
+                                       vertical_tile * height_of_square,
+                                       width_of_square + (horizontal_tile * width_of_square),
+                                       height_of_square + (height_of_square * vertical_tile)])
+        last_column = []
+        for i in range(vertical_tiles):
+            last_column.append((horizontal_tiles-1)+horizontal_tiles*i)
+        last_row = range((horizontal_tiles * vertical_tiles) - horizontal_tiles, horizontal_tiles * vertical_tiles)
+        polygons = []
+        for x1, y1, x2, y2 in dimensions:
+            polygons.append([x1, y1, x1, y2, x2, y2, x2, y1])
+        polygon_indices = []
+        for i in range((vertical_tiles * horizontal_tiles) - 1):
+            if i not in last_row and i not in last_column:
+                polygon_indices.append([i, i + 1, i + horizontal_tiles, i + 1 + horizontal_tiles])
+        for a, b, c, d in polygon_indices:
+            dx = random.randint(-self.magnitude, self.magnitude)
+            dy = random.randint(-self.magnitude, self.magnitude)
+            x1, y1, x2, y2, x3, y3, x4, y4 = polygons[a]
+            polygons[a] = [x1, y1,
+                           x2, y2,
+                           x3 + dx, y3 + dy,
+                           x4, y4]
+            x1, y1, x2, y2, x3, y3, x4, y4 = polygons[b]
+            polygons[b] = [x1, y1,
+                           x2 + dx, y2 + dy,
+                           x3, y3,
+                           x4, y4]
+            x1, y1, x2, y2, x3, y3, x4, y4 = polygons[c]
+            polygons[c] = [x1, y1,
+                           x2, y2,
+                           x3, y3,
+                           x4 + dx, y4 + dy]
+            x1, y1, x2, y2, x3, y3, x4, y4 = polygons[d]
+            polygons[d] = [x1 + dx, y1 + dy,
+                           x2, y2,
+                           x3, y3,
+                           x4, y4]
+        generated_mesh = []
+        for i in range(len(dimensions)):
+            generated_mesh.append([dimensions[i], polygons[i]])
+
+        return generated_mesh
+    
+    def __call__(self,*args): 
+        if random.random() < self.prob:
+            # img, gt and mask have same resolution, we only generate mesh once:
+            mesh = self._generatemesh(args[0])
+            imgs = []
+            for img in args:
+                img = img.transform(img.size, Image.MESH, mesh, resample=Image.BICUBIC)
+                imgs.append(img)
+            return imgs
         else:
             return args
\ No newline at end of file
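The transforms above all accept a variable number of PIL images and return a list, so image, ground truth and (optionally) mask can be augmented in a single call; ``Distortion``, for instance, generates its mesh only once per call. A small sketch with placeholder images:

    from PIL import Image
    from bob.ip.binseg.data.transforms import Compose, RandomHFlip, RandomRotation, Distortion, ToTensor

    img = Image.new('RGB', (565, 584))  # placeholder fundus image
    gt = Image.new('1', (565, 584))     # placeholder binary ground truth

    transform = Compose([
        RandomHFlip(prob=0.5),
        RandomRotation(degree_range=(-15, 15), prob=0.5),
        Distortion(grid_width=8, grid_height=8, magnitude=1, prob=0.5),
        ToTensor(),
    ])

    img_t, gt_t = transform(img, gt)  # list of two torch tensors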
diff --git a/bob/ip/binseg/engine/adabound.py b/bob/ip/binseg/engine/adabound.py
index 735d0399f7a4b7b44178f57fc7824c7e4631edb7..e220db5809f96d482d93059f4ee2f2bae1aec8bd 100644
--- a/bob/ip/binseg/engine/adabound.py
+++ b/bob/ip/binseg/engine/adabound.py
@@ -21,18 +21,20 @@ from torch.optim import Optimizer
 class AdaBound(Optimizer):
     """Implements AdaBound algorithm.
     It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_.
-    Arguments:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float, optional): Adam learning rate (default: 1e-3)
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
-        final_lr (float, optional): final (SGD) learning rate (default: 0.1)
-        gamma (float, optional): convergence speed of the bound functions (default: 1e-3)
-        eps (float, optional): term added to the denominator to improve
-            numerical stability (default: 1e-8)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm
+    
+    Parameters
+    ----------
+    params (iterable): iterable of parameters to optimize or dicts defining
+        parameter groups
+    lr (float, optional): Adam learning rate (default: 1e-3)
+    betas (Tuple[float, float], optional): coefficients used for computing
+        running averages of gradient and its square (default: (0.9, 0.999))
+    final_lr (float, optional): final (SGD) learning rate (default: 0.1)
+    gamma (float, optional): convergence speed of the bound functions (default: 1e-3)
+    eps (float, optional): term added to the denominator to improve
+        numerical stability (default: 1e-8)
+    weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+    amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm
     .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate:
         https://openreview.net/forum?id=Bkg3g2R9FX
     """
@@ -64,9 +66,10 @@ class AdaBound(Optimizer):
 
     def step(self, closure=None):
         """Performs a single optimization step.
-        Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
+        
+        Parameters
+        ----------
+        closure (callable, optional): A closure that reevaluates the model and returns the loss.
         """
         loss = None
         if closure is not None:
@@ -135,18 +138,20 @@ class AdaBound(Optimizer):
 class AdaBoundW(Optimizer):
     """Implements AdaBound algorithm with Decoupled Weight Decay (arxiv.org/abs/1711.05101)
     It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_.
-    Arguments:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float, optional): Adam learning rate (default: 1e-3)
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
-        final_lr (float, optional): final (SGD) learning rate (default: 0.1)
-        gamma (float, optional): convergence speed of the bound functions (default: 1e-3)
-        eps (float, optional): term added to the denominator to improve
-            numerical stability (default: 1e-8)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm
+    
+    Parameters
+    ----------
+    params (iterable): iterable of parameters to optimize or dicts defining
+        parameter groups
+    lr (float, optional): Adam learning rate (default: 1e-3)
+    betas (Tuple[float, float], optional): coefficients used for computing
+        running averages of gradient and its square (default: (0.9, 0.999))
+    final_lr (float, optional): final (SGD) learning rate (default: 0.1)
+    gamma (float, optional): convergence speed of the bound functions (default: 1e-3)
+    eps (float, optional): term added to the denominator to improve
+        numerical stability (default: 1e-8)
+    weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+    amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm
     .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate:
         https://openreview.net/forum?id=Bkg3g2R9FX
     """
@@ -178,9 +183,10 @@ class AdaBoundW(Optimizer):
 
     def step(self, closure=None):
         """Performs a single optimization step.
-        Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
+        
+        Parameters
+        ----------
+        closure (callable, optional): A closure that reevaluates the model and returns the loss.
         """
         loss = None
         if closure is not None:
diff --git a/bob/ip/binseg/engine/inferencer.py b/bob/ip/binseg/engine/inferencer.py
index b57a2de0351727117c2d105066bae22a0d6acf62..3dfb77f4e9e3705aa7cae17413d1ed6e5e446307 100644
--- a/bob/ip/binseg/engine/inferencer.py
+++ b/bob/ip/binseg/engine/inferencer.py
@@ -22,18 +22,23 @@ def batch_metrics(predictions, ground_truths, masks, names, output_folder, logge
 
     Parameters
     ----------
-    predictions: :py:class:torch.Tensor
-    ground_truths : :py:class:torch.Tensor
-    mask : :py:class:torch.Tensor
+    predictions : :py:class:`torch.Tensor`
+        tensor with pixel-wise probabilities
+    ground_truths : :py:class:`torch.Tensor`
+        tensor with binary ground-truth
+    mask : :py:class:`torch.Tensor`
+        tensor with mask
     names : list
+        list of file names 
     output_folder : str
+        output path
     logger : :py:class:logging
+        python logger
 
     Returns
     -------
-
-    batch_metrics : list
-
+    list 
+        list containing batch metrics (name, threshold, precision, recall, specificity, accuracy, jaccard, f1_score)
     """
     step_size = 0.01
     batch_metrics = []
@@ -88,20 +93,24 @@ def batch_metrics(predictions, ground_truths, masks, names, output_folder, logge
 
 def save_probability_images(predictions, names, output_folder, logger):
     """
-    Saves probability maps as tif image
+    Saves probability maps as images, using the same file format as the corresponding test images
 
     Parameters
     ----------
-    predictions : :py:class:torch.Tensor
+    predictions : :py:class:`torch.Tensor`
+        tensor with pixel-wise probabilities
     names : list
+        list of file names 
     output_folder : str
-    logger :  :py:class:logging
+        output path
+    logger : :py:class:`logging.Logger`
+        python logger
     """
     images_subfolder = os.path.join(output_folder,'images') 
     if not os.path.exists(images_subfolder): os.makedirs(images_subfolder)
     for j in range(predictions.size()[0]):
         img = VF.to_pil_image(predictions.cpu().data[j])
-        filename = '{}.tif'.format(names[j])
+        filename = '{}'.format(names[j])
         logger.info("saving {}".format(filename))
         img.save(os.path.join(images_subfolder, filename))
 
@@ -116,12 +125,14 @@ def do_inference(
     """
     Run inference and calculate metrics
     
-    Paramters
+    Parameters
     ---------
     model : :py:class:torch.nn.Module
+        neural network model (e.g. DRIU, HED, UNet)
     data_loader : py:class:torch.torch.utils.data.DataLoader
+        PyTorch DataLoader
     device : str
-                'cpu' or 'cuda'
+        device to use ('cpu' or 'cuda')
     output_folder : str
     """
     logger = logging.getLogger("bob.ip.binseg.engine.inference")
@@ -155,6 +166,9 @@ def do_inference(
                 outputs = outputs[-1]
             
             probabilities = sigmoid(outputs)
+            if hasattr(masks,'dtype'):
+                masks = masks.to(device)
+                probabilities = probabilities * masks
             
             batch_time = time.perf_counter() - start_time
             times.append(batch_time)
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index b5a60bb292c01fa62ecf9bfbc1a59d82e8cdfdba..b508cda40c250d59ae46b02464f9ecb6c7737d30 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -30,18 +30,24 @@ def do_train(
     
     Parameters
     ----------
-    model : :py:class:torch.nn.Module
-    data_loader : py:class:torch.torch.utils.data.DataLoader
-    optimizer : py:class.torch.torch.optim.Optimizer
-    criterion : py:class.torch.nn.modules.loss._Loss
-    scheduler : py:class.torch.torch.optim._LRScheduler
-    checkpointer : bob.ip.binseg.utils.checkpointer.DetectronCheckpointer
+    model : :py:class:`torch.nn.Module` 
+        Network (e.g. DRIU, HED, UNet)
+    data_loader : :py:class:`torch.utils.data.DataLoader`
+    optimizer : :py:class:`torch.optim.Optimizer`
+    criterion : :py:class:`torch.nn.modules.loss._Loss`
+        loss function
+    scheduler : :py:class:`torch.optim.lr_scheduler._LRScheduler`
+        learning rate scheduler
+    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
+        checkpointer
     checkpoint_period : int
-    device : str
-                'cpu' or 'cuda'
+        save a checkpoint every n epochs
+    device : str  
+        device to use. 'cpu' or 'cuda'.
     arguments : dict
-    output_folder : str
-
+        start and end epochs
+    output_folder : str 
+        output path
     """
     logger = logging.getLogger("bob.ip.binseg.engine.trainer")
     logger.info("Start training")
@@ -68,10 +74,12 @@ def do_train(
 
                 images = images.to(device)
                 ground_truths = ground_truths.to(device)
-
+                if hasattr(masks,'dtype'):
+                    masks = masks.to(device)
+                
                 outputs = model(images)
-
-                loss = criterion(outputs, ground_truths)
+                
+                loss = criterion(outputs, ground_truths, masks)
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
diff --git a/bob/ip/binseg/modeling/driu.py b/bob/ip/binseg/modeling/driu.py
index fa367e61f992060a56b90950bb8bab6d8f84ec1b..81ad9c4ff47c4162ef7ed372f5cf505a9fa7cb64 100644
--- a/bob/ip/binseg/modeling/driu.py
+++ b/bob/ip/binseg/modeling/driu.py
@@ -28,7 +28,7 @@ class DRIU(nn.Module):
     Parameters
     ----------
     in_channels_list : list
-                        number of channels for each feature map that is returned from backbone
+        number of channels for each feature map that is returned from backbone
     """
     def __init__(self, in_channels_list=None):
         super(DRIU, self).__init__()
@@ -48,9 +48,13 @@ class DRIU(nn.Module):
         Parameters
         ----------
         x : list
-                list of tensors as returned from the backbone network.
-                First element: height and width of input image. 
-                Remaining elements: feature maps for each feature level.
+            list of tensors as returned from the backbone network.
+            First element: height and width of input image. 
+            Remaining elements: feature maps for each feature level.
+
+        Returns
+        -------
+        :py:class:`torch.Tensor`
         """
         hw = x[0]
         conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16   
@@ -66,7 +70,7 @@ def build_driu():
 
     Returns
     -------
-    model : :py:class:torch.nn.Module
+    :py:class:`torch.nn.Module`
     """
     backbone = vgg16(pretrained=False, return_features = [3, 8, 14, 22])
     driu_head = DRIU([64, 128, 256, 512])
diff --git a/bob/ip/binseg/modeling/hed.py b/bob/ip/binseg/modeling/hed.py
index 6a8349fd441ad37edf80628a849d11e2f0af56cd..fa44366e4e75dce1059a130e83f78e7884922501 100644
--- a/bob/ip/binseg/modeling/hed.py
+++ b/bob/ip/binseg/modeling/hed.py
@@ -28,7 +28,7 @@ class HED(nn.Module):
     Parameters
     ----------
     in_channels_list : list
-                        number of channels for each feature map that is returned from backbone
+        number of channels for each feature map that is returned from backbone
     """
     def __init__(self, in_channels_list=None):
         super(HED, self).__init__()
@@ -48,9 +48,13 @@ class HED(nn.Module):
         Parameters
         ----------
         x : list
-                list of tensors as returned from the backbone network.
-                First element: height and width of input image. 
-                Remaining elements: feature maps for each feature level.
+            list of tensors as returned from the backbone network.
+            First element: height and width of input image. 
+            Remaining elements: feature maps for each feature level.
+        
+        Returns
+        -------
+        :py:class:`torch.Tensor`
         """
         hw = x[0]
         conv1_2_16 = self.conv1_2_16(x[1])  
@@ -69,7 +73,7 @@ def build_hed():
 
     Returns
     -------
-    model : :py:class:torch.nn.Module
+    :py:class:`torch.nn.Module`
     """
     backbone = vgg16(pretrained=False, return_features = [3, 8, 14, 22, 29])
     hed_head = HED([64, 128, 256, 512, 512])
diff --git a/bob/ip/binseg/modeling/losses.py b/bob/ip/binseg/modeling/losses.py
index e094dbe6a01d58ea645103d9a78c9553f506eae7..809f19728fd31ba8c57ebc0c6ceb8672e17bc96c 100644
--- a/bob/ip/binseg/modeling/losses.py
+++ b/bob/ip/binseg/modeling/losses.py
@@ -2,9 +2,34 @@ import torch
 from torch.nn.modules.loss import _Loss
 from torch._jit_internal import weak_script_method
 
+
+
+
 class WeightedBCELogitsLoss(_Loss):
     """ 
-    Calculate sum of weighted cross entropy loss. Use for binary classification.
+    Implements Equation 1 in [DRIU16]_. Based on :py:class:`torch.nn.modules.loss.BCEWithLogitsLoss`.
+    Calculate sum of weighted cross entropy loss.
+
+    Attributes
+    ----------
+    size_average : bool, optional
+        Deprecated (see :attr:`reduction`). By default, the losses are averaged over each loss element in the batch. Note that for 
+        some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are 
+        instead summed for each minibatch. Ignored when reduce is ``False``. Default: ``True``
+    reduce : bool, optional 
+        Deprecated (see :attr:`reduction`). By default, the
+        losses are averaged or summed over observations for each minibatch depending
+        on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+        batch element instead and ignores :attr:`size_average`. Default: ``True``
+    reduction : string, optional
+        Specifies the reduction to apply to the output:
+        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+        ``'mean'``: the sum of the output will be divided by the number of
+        elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+        and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+        specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+    pos_weight : :py:class:`torch.Tensor`, optional
+        a weight of positive examples. Must be a vector with length equal to the number of classes.
     """
     def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None):
         super(WeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction)
@@ -12,10 +37,14 @@ class WeightedBCELogitsLoss(_Loss):
         self.register_buffer('pos_weight', pos_weight)
 
     @weak_script_method
-    def forward(self, input, target):
+    def forward(self, input, target, masks=None):
         n, c, h, w = target.shape
         num_pos = torch.sum(target, dim=[1, 2, 3]).float().reshape(n,1) # torch.Size([n, 1])
-        num_neg = c * h * w - num_pos  # torch.Size([n, 1])
+        if hasattr(masks,'dtype'):
+            num_mask_neg = c * h * w - torch.sum(masks, dim=[1, 2, 3]).float().reshape(n,1) # torch.Size([n, 1])
+            num_neg =  c * h * w - num_pos - num_mask_neg
+        else:
+            num_neg = c * h * w - num_pos 
         numposnumtotal = torch.ones_like(target) * (num_pos / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2)
         numnegnumtotal = torch.ones_like(target) * (num_neg / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2)
         weight = torch.where((target <= 0.5) , numposnumtotal, numnegnumtotal)
@@ -23,9 +52,76 @@ class WeightedBCELogitsLoss(_Loss):
         loss = torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, reduction=self.reduction)
         return loss 
 
+class SoftJaccardBCELogitsLoss(_Loss):
+    """ 
+    Implements Equation 6 in [SAT17]_. Based on :py:class:`torch.nn.modules.loss.BCEWithLogitsLoss`.
+
+    Attributes
+    ----------
+    alpha : float
+        determines the weighting of SoftJaccard and BCE. Default: ``0.3``
+    size_average : bool, optional
+        Deprecated (see :attr:`reduction`). By default, the losses are averaged over each loss element in the batch. Note that for 
+        some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are 
+        instead summed for each minibatch. Ignored when reduce is ``False``. Default: ``True``
+    reduce : bool, optional 
+        Deprecated (see :attr:`reduction`). By default, the
+        losses are averaged or summed over observations for each minibatch depending
+        on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+        batch element instead and ignores :attr:`size_average`. Default: ``True``
+    reduction : string, optional
+        Specifies the reduction to apply to the output:
+        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+        ``'mean'``: the sum of the output will be divided by the number of
+        elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+        and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+        specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+    pos_weight : :py:class:`torch.Tensor`, optional
+        a weight of positive examples. Must be a vector with length equal to the number of classes.
+    """
+    def __init__(self, alpha=0.3, size_average=None, reduce=None, reduction='mean', pos_weight=None):
+        super(SoftJaccardBCELogitsLoss, self).__init__(size_average, reduce, reduction) 
+        self.alpha = alpha   
+
+    @weak_script_method
+    def forward(self, input, target, masks=None):
+        eps = 1e-8
+        probabilities = torch.sigmoid(input)
+        intersection = (probabilities * target).sum()
+        sums = probabilities.sum() + target.sum()
+        
+        softjaccard = intersection/(sums - intersection + eps)
+
+        bceloss = torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=None, reduction=self.reduction)
+        loss = self.alpha * bceloss + (1 - self.alpha) * (1-softjaccard)
+        return loss
+
+
 class HEDWeightedBCELogitsLoss(_Loss):
     """ 
-    Calculate sum of weighted cross entropy loss. Use for binary classification.
+    Implements Equation 2 in [HED15]_. Based on :py:class:`torch.nn.modules.loss.BCEWithLogitsLoss`.
+    Calculate sum of weighted cross entropy loss.
+
+    Attributes
+    ----------
+    size_average : bool, optional
+        Deprecated (see :attr:`reduction`). By default, the losses are averaged over each loss element in the batch. Note that for 
+        some losses, there are multiple elements per sample. If the field :attr:`size_average` is set to ``False``, the losses are 
+        instead summed for each minibatch. Ignored when reduce is ``False``. Default: ``True``
+    reduce : bool, optional 
+        Deprecated (see :attr:`reduction`). By default, the
+        losses are averaged or summed over observations for each minibatch depending
+        on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+        batch element instead and ignores :attr:`size_average`. Default: ``True``
+    reduction : string, optional
+        Specifies the reduction to apply to the output:
+        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+        ``'mean'``: the sum of the output will be divided by the number of
+        elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+        and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+        specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+    pos_weight : :py:class:`torch.Tensor`, optional
+        a weight of positive examples. Must be a vector with length equal to the number of classes.
     """
     def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None):
         super(HEDWeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction)
@@ -33,16 +129,32 @@ class HEDWeightedBCELogitsLoss(_Loss):
         self.register_buffer('pos_weight', pos_weight)
 
     @weak_script_method
-    def forward(self, inputlist, target):
+    def forward(self, inputlist, target, masks=None):
+        """[summary]
+        
+        Parameters
+        ----------
+        inputlist : list of :py:class:`torch.Tensor`
+            HED uses multiple side-output feature maps for the loss calculation
+        target : :py:class:`torch.Tensor`
+        
+        Returns
+        -------
+        :py:class:`torch.Tensor`
+            
+        """
         loss_over_all_inputs = []
         for input in inputlist:
             n, c, h, w = target.shape
             num_pos = torch.sum(target, dim=[1, 2, 3]).float().reshape(n,1) # torch.Size([n, 1])
-            num_neg = c * h * w - num_pos  # torch.Size([n, 1])
+            if hasattr(masks,'dtype'):
+                num_mask_neg = c * h * w - torch.sum(masks, dim=[1, 2, 3]).float().reshape(n,1) # torch.Size([n, 1])
+                num_neg =  c * h * w - num_pos - num_mask_neg
+            else: 
+                num_neg = c * h * w - num_pos  # torch.Size([n, 1])
             numposnumtotal = torch.ones_like(target) * (num_pos / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2)
             numnegnumtotal = torch.ones_like(target) * (num_neg / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2)
             weight = torch.where((target <= 0.5) , numposnumtotal, numnegnumtotal)
-
             loss = torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, reduction=self.reduction)
             loss_over_all_inputs.append(loss.unsqueeze(0))
         final_loss = torch.cat(loss_over_all_inputs).mean()
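The new loss interface mirrors the trainer, which now always calls ``criterion(outputs, ground_truths, masks)``; ``WeightedBCELogitsLoss`` uses the mask to discount pixels outside the field of view when counting negatives, while ``SoftJaccardBCELogitsLoss`` combines both terms as ``alpha * BCE + (1 - alpha) * (1 - soft Jaccard)``. A small numerical sketch on random tensors:

    import torch
    from bob.ip.binseg.modeling.losses import WeightedBCELogitsLoss, SoftJaccardBCELogitsLoss

    logits = torch.randn(2, 1, 64, 64)                  # raw model outputs (before sigmoid)
    target = (torch.rand(2, 1, 64, 64) > 0.5).float()   # binary ground truth
    mask = torch.ones(2, 1, 64, 64)                     # field-of-view mask (all valid here)

    wbce = WeightedBCELogitsLoss()
    print(wbce(logits, target, mask).item())            # masked, weighted BCE

    sjbce = SoftJaccardBCELogitsLoss(alpha=0.1)
    print(sjbce(logits, target).item())                 # 0.1 * BCE + 0.9 * (1 - soft Jaccard)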
diff --git a/bob/ip/binseg/modeling/m2u.py b/bob/ip/binseg/modeling/m2u.py
index 13602eb3d23abb04905b2e75d32791802ac608e5..7db86168c0b6f703546de4dca2e22539e73adeb4 100644
--- a/bob/ip/binseg/modeling/m2u.py
+++ b/bob/ip/binseg/modeling/m2u.py
@@ -44,7 +44,7 @@ class M2U(nn.Module):
     Parameters
     ----------
     in_channels_list : list
-                        number of channels for each feature map that is returned from backbone
+        number of channels for each feature map that is returned from backbone
     """
     def __init__(self, in_channels_list=None,upsamplemode='bilinear',expand_ratio=0.15):
         super(M2U, self).__init__()
@@ -73,9 +73,12 @@ class M2U(nn.Module):
         Parameters
         ----------
         x : list
-                list of tensors as returned from the backbone network.
-                First element: height and width of input image. 
-                Remaining elements: feature maps for each feature level.
+            list of tensors as returned from the backbone network.
+            First element: height and width of input image. 
+            Remaining elements: feature maps for each feature level.
+        Returns
+        -------
+        :py:class:`torch.Tensor`
         """
         decode4 = self.decode4(x[5],x[4])    # 96, 32
         decode3 = self.decode3(decode4,x[3]) # 64, 24
@@ -90,9 +93,9 @@ def build_m2unet():
 
     Returns
     -------
-    model : :py:class:torch.nn.Module
+    :py:class:`torch.nn.Module`
     """
-    backbone = MobileNetV2(return_features = [1,3,6,13], m2u=True)
+    backbone = MobileNetV2(return_features = [1, 3, 6, 13], m2u=True)
     m2u_head = M2U(in_channels_list=[16, 24, 32, 96])
 
     model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", m2u_head)]))
diff --git a/bob/ip/binseg/modeling/make_layers.py b/bob/ip/binseg/modeling/make_layers.py
index fbe40fd3a0eb3d2ae024cc51f46848f587133d65..7e3984433273eaa0d7f86b3e720682c9460552f3 100644
--- a/bob/ip/binseg/modeling/make_layers.py
+++ b/bob/ip/binseg/modeling/make_layers.py
@@ -41,7 +41,7 @@ def convtrans_with_kaiming_uniform(in_channels, out_channels, kernel_size, strid
 
 
 class UpsampleCropBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, up_kernel_size, up_stride, up_padding):
+    def __init__(self, in_channels, out_channels, up_kernel_size, up_stride, up_padding, pixelshuffle=False):
         """
         Combines Conv2d, ConvTransposed2d and Cropping. Simulates the caffe2 crop layer in the forward function.
         Used for DRIU and HED. 
@@ -57,7 +57,10 @@ class UpsampleCropBlock(nn.Module):
         super().__init__()
         # NOTE: Kaiming init, replace with nn.Conv2d and nn.ConvTranspose2d to get original DRIU impl.
         self.conv = conv_with_kaiming_uniform(in_channels, out_channels, 3, 1, 1)
-        self.upconv = convtrans_with_kaiming_uniform(out_channels, out_channels, up_kernel_size, up_stride, up_padding)        
+        if pixelshuffle:
+            self.upconv = PixelShuffle_ICNR(out_channels, out_channels, scale=up_stride)
+        else:
+            self.upconv = convtrans_with_kaiming_uniform(out_channels, out_channels, up_kernel_size, up_stride, up_padding)        
         
         
     def forward(self, x, input_res):
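``UpsampleCropBlock`` can now swap its transposed-convolution upsampling for sub-pixel convolution via the new ``pixelshuffle`` flag (``PixelShuffle_ICNR`` is assumed to be available in ``make_layers.py``, since the patch calls it there). A construction sketch with illustrative channel and upsampling parameters:

    from bob.ip.binseg.modeling.make_layers import UpsampleCropBlock

    # default behaviour: Conv2d followed by ConvTranspose2d upsampling
    up_deconv = UpsampleCropBlock(512, 16, up_kernel_size=8, up_stride=4, up_padding=0)

    # alternative: Conv2d followed by PixelShuffle_ICNR upsampling at the same scale
    up_shuffle = UpsampleCropBlock(512, 16, up_kernel_size=8, up_stride=4, up_padding=0,
                                   pixelshuffle=True)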
diff --git a/bob/ip/binseg/modeling/unet.py b/bob/ip/binseg/modeling/unet.py
index d0db666bfa12d479ada6aa0f52e601b085a3a484..d1102592b74d2ea2c8af8ea2657ac6f1775a92d7 100644
--- a/bob/ip/binseg/modeling/unet.py
+++ b/bob/ip/binseg/modeling/unet.py
@@ -35,9 +35,9 @@ class UNet(nn.Module):
         Parameters
         ----------
         x : list
-                list of tensors as returned from the backbone network.
-                First element: height and width of input image. 
-                Remaining elements: feature maps for each feature level.
+            list of tensors as returned from the backbone network.
+            First element: height and width of input image. 
+            Remaining elements: feature maps for each feature level.
         """
         # NOTE: x[0]: height and width of input image not needed in U-Net architecture
         decode4 = self.decode4(x[5], x[4])  
diff --git a/bob/ip/binseg/utils/click.py b/bob/ip/binseg/utils/click.py
index fa3582878cf1cc3d1634fcd66f416fa468448030..8b8294d97f869167f0908f22c4521c6fcda1a243 100644
--- a/bob/ip/binseg/utils/click.py
+++ b/bob/ip/binseg/utils/click.py
@@ -1,11 +1,15 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-# https://stackoverflow.com/questions/48391777/nargs-equivalent-for-options-in-click 
+
 
 import click
 
 class OptionEatAll(click.Option):
+    """
+    Allows for *args and **kwargs to be passed to click 
+    https://stackoverflow.com/questions/48391777/nargs-equivalent-for-options-in-click 
+    """
 
     def __init__(self, *args, **kwargs):
         self.save_other_options = kwargs.pop('save_other_options', True)
diff --git a/bob/ip/binseg/utils/model_serialization.py b/bob/ip/binseg/utils/model_serialization.py
index 3e7b565ad3cf19e699c48f403ddb744f74bda072..84ff2491ea85751cc0b0910d278f035e9571f0eb 100644
--- a/bob/ip/binseg/utils/model_serialization.py
+++ b/bob/ip/binseg/utils/model_serialization.py
@@ -1,4 +1,5 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# https://github.com/facebookresearch/maskrcnn-benchmark
 from collections import OrderedDict
 import logging
 
diff --git a/bob/ip/binseg/utils/plot.py b/bob/ip/binseg/utils/plot.py
index 98c9adef42af9dc3555864923cfec8a55063d126..5a2677d5b055a9ed9581d08ee50bdc45c329d221 100644
--- a/bob/ip/binseg/utils/plot.py
+++ b/bob/ip/binseg/utils/plot.py
@@ -124,7 +124,7 @@ def read_metricscsv(file):
     """
     Read precision and recall from csv file
     
-    Arguments
+    Parameters
     ---------
     file: str
            path to file
diff --git a/conda/meta.yaml b/conda/meta.yaml
index 1e569ac4d7897703dd3a32cb0e2ceb885c478172..beff5e34bec1e41de2916f9901eee7975445ed7a 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -26,16 +26,16 @@ requirements:
   host:
     - python {{ python }}
     - setuptools {{ setuptools }}
-    - torchvision  {{ torchvision }}
-    - pytorch {{ pytorch }}
+    - torchvision  {{ torchvision }} # [linux]
+    - pytorch {{ pytorch }} # [linux]
     - numpy {{ numpy }}
     - bob.extension
     # place your other host dependencies here
   run:
     - python
     - setuptools
-    - {{ pin_compatible('pytorch') }}
-    - {{ pin_compatible('torchvision') }}
+    - {{ pin_compatible('pytorch') }} # [linux]
+    - {{ pin_compatible('torchvision') }} # [linux]
     - {{ pin_compatible('numpy') }}
     - pandas
     - matplotlib
diff --git a/doc/api.rst b/doc/api.rst
new file mode 100644
index 0000000000000000000000000000000000000000..9cfacb8d322ff9b5e9491f6e9d5231b031d17685
--- /dev/null
+++ b/doc/api.rst
@@ -0,0 +1,24 @@
+.. -*- coding: utf-8 -*-
+.. _bob.ip.binseg.api:
+
+============
+ Python API
+============
+
+This section lists all the functionality available in this library for
+running binary segmentation experiments.
+
+
+PyTorch Dataset
+---------------
+.. automodule:: bob.ip.binseg.data.binsegdataset
+
+Transforms
+----------
+.. note::
+    All transforms work with :py:class:`PIL.Image.Image` objects.
+
+.. automodule:: bob.ip.binseg.data.transforms
+
+
+.. include:: links.rst
diff --git a/doc/conf.py b/doc/conf.py
index 72f4305249b6a0801d28edb667df89a556c00b99..4c2c5a759f8c3144f7a9a7bde6f92337222eb9f8 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -234,6 +234,7 @@ else:
     intersphinx_mapping = link_documentation()
 
 intersphinx_mapping['torch'] = ('https://pytorch.org/docs/stable/', None)
+intersphinx_mapping['PIL'] = ('http://pillow.readthedocs.io/en/stable', None)
 # We want to remove all private (i.e. _. or __.__) members
 # that are not in the list of accepted functions
 accepted_private_functions = ['__array__']
@@ -254,4 +255,5 @@ def member_function_test(app, what, name, obj, skip, options):
 
 
 def setup(app):
-    app.connect('autodoc-skip-member', member_function_test)
\ No newline at end of file
+    app.connect('autodoc-skip-member', member_function_test)
+    
\ No newline at end of file
diff --git a/doc/datasets.rst b/doc/datasets.rst
index 1e99ff2a3399208df376dd2e12326aa44fb6716f..67939144b975ed25b1b01d996b656d9aa4c2fe6e 100644
--- a/doc/datasets.rst
+++ b/doc/datasets.rst
@@ -1,27 +1,31 @@
+.. -*- coding: utf-8 -*-
+.. _bob.ip.binseg.datasets:
+
+
 ==================
 Supported Datasets 
 ==================
 
-+-----------------+--------------------+--------+-------+------+------+--------+----+-----+-----------+
-| # | Name        | H x W              | # imgs | Train | Test | Mask | Vessel | OD | Cup | Ethnicity | 
-+===+=============+====================+========+=======+======+======+========+====+=====+===========+
-| 1 | DRIVE       | 584 x 565          | 40     |  20   | 20   |   x  |    x   |    |     |           |
-+---+-------------+--------------------+--------+-------+------+------+--------+----+-----+-----------+
-| 2 | STARE       | 605 x 700          | 20     |  10   | 10   |      |    x   |    |     |           |
-+---+-------------+--------------------+--------+-------+------+------+--------+----+-----+-----------+
-| 3 | CHASE_DB1   | 960 x 999          | 28     |   8   | 20   |      |    x   |    |     |           |
-+---+-------------+--------------------+--------+-------+------+------+--------+----+-----+-----------+
-| 4 | HRF         | 2336 x 3504        | 45     |  15   | 30   |   x  |    x   |    |     |           |
-+---+-------------+--------------------+--------+-------+------+------+--------+----+-----+-----------+
-| 5 | IOSTAR      | 1024 x 1024        | 30     |   20  | 10   |   x  |    x   |  x |     |           |
-+---+-------------+--------------------+--------+-------+------+------+--------+----+-----+-----------+
-| 6 | DRIONS-DB   | 400 x 600          | 110    |   60  | 50   |      |        |  x |     |           |
-+---+-------------+--------------------+--------+-------+------+------+--------+----+-----+-----------+
-| 7 | RIM-ONE r3  | 1424 x 1072        | 159    |   99  | 60   |      |        |  x |  x  |           |
-+---+-------------+--------------------+--------+-------+------+------+--------+----+-----+-----------+
-| 8 | Drishti-GS1 | varying            | 101    |  50   |   51 |      |        |  x |  x  |           |
-+---+-------------+--------------------+--------+-------+------+------+--------+----+-----+-----------+
-| 9 | REFUGE train| 2056 x 2124        | 400    | 400   |      |      |        |  x |  x  |           |
-+---+-------------+--------------------+--------+-------+------+------+--------+----+-----+-----------+
-| 9 | REFUGE val  | 1634 x 1634        | 400    |       | 400  |      |        |  x |  x  |           |
-+---+-------------+--------------------+--------+-------+------+------+--------+----+-----+-----------+
++---+-------------+--------------------+--------+-------+------+------+--------+----+-----+---------------------------+
+| # | Name        | H x W              | # imgs | Train | Test | Mask | Vessel | OD | Cup | Ethnicity                 | 
++===+=============+====================+========+=======+======+======+========+====+=====+===========================+
+| 1 | DRIVE       | 584 x 565          | 40     |  20   | 20   |   x  |    x   |    |     |   Dutch (adult)           |
++---+-------------+--------------------+--------+-------+------+------+--------+----+-----+---------------------------+
+| 2 | STARE       | 605 x 700          | 20     |  10   | 10   |      |    x   |    |     |   White American (adult)  |
++---+-------------+--------------------+--------+-------+------+------+--------+----+-----+---------------------------+
+| 3 | CHASE_DB1   | 960 x 999          | 28     |   8   | 20   |      |    x   |    |     |   British (child)         |
++---+-------------+--------------------+--------+-------+------+------+--------+----+-----+---------------------------+
+| 4 | HRF         | 2336 x 3504        | 45     |  15   | 30   |   x  |    x   |    |     |                           |
++---+-------------+--------------------+--------+-------+------+------+--------+----+-----+---------------------------+
+| 5 | IOSTAR      | 1024 x 1024        | 30     |   20  | 10   |   x  |    x   |  x |     |                           |
++---+-------------+--------------------+--------+-------+------+------+--------+----+-----+---------------------------+
+| 6 | DRIONS-DB   | 400 x 600          | 110    |   60  | 50   |      |        |  x |     |                           |
++---+-------------+--------------------+--------+-------+------+------+--------+----+-----+---------------------------+
+| 7 | RIM-ONE r3  | 1424 x 1072        | 159    |   99  | 60   |      |        |  x |  x  |                           |
++---+-------------+--------------------+--------+-------+------+------+--------+----+-----+---------------------------+
+| 8 | Drishti-GS1 | varying            | 101    |  50   |   51 |      |        |  x |  x  |                           |
++---+-------------+--------------------+--------+-------+------+------+--------+----+-----+---------------------------+
+| 9 | REFUGE train| 2056 x 2124        | 400    | 400   |      |      |        |  x |  x  |                           |
++---+-------------+--------------------+--------+-------+------+------+--------+----+-----+---------------------------+
+| 9 | REFUGE val  | 1634 x 1634        | 400    |       | 400  |      |        |  x |  x  |                           |
++---+-------------+--------------------+--------+-------+------+------+--------+----+-----+---------------------------+
diff --git a/doc/extra-intersphinx.txt b/doc/extra-intersphinx.txt
new file mode 100644
index 0000000000000000000000000000000000000000..37f700a78ecad7bea22172702e640797d7136a17
--- /dev/null
+++ b/doc/extra-intersphinx.txt
@@ -0,0 +1,2 @@
+torch
+torchvision
\ No newline at end of file
diff --git a/doc/index.rst b/doc/index.rst
index 6a1ed53194cca38e018aec676724d493cc31da33..fe6cf753734f6167476572120efa5ad3283073f7 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -6,19 +6,18 @@
  Binary Segmentation Benchmark Package for Bob
 ===============================================
 
-.. todo ::
-   Write here a small (1 paragraph) introduction explaining this project. See
-   other projects for examples.
+A package to benchmark and evaluate a range of neural network architectures for binary segmentation tasks on 2D Eye Fundus Images (2DFI).
 
-Datasets
-========
+Users Guide
+===========
 
 .. toctree::
    :maxdepth: 2
-
+ 
    datasets
+   api
+   references
 
-Users Guide
-===========
+.. todolist::
 
 .. include:: links.rst
\ No newline at end of file
diff --git a/doc/nitpick-exceptions.txt b/doc/nitpick-exceptions.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8d734225d8ec201efaa64dbaa1f16445fa70e22c
--- /dev/null
+++ b/doc/nitpick-exceptions.txt
@@ -0,0 +1,5 @@
+py:class torch.nn.modules.module.Module
+py:class torch.nn.modules.loss._Loss
+py:class torch.utils.data.dataset.Dataset
+py:mod bob.db.base
+py:obj list
diff --git a/doc/references.rst b/doc/references.rst
new file mode 100644
index 0000000000000000000000000000000000000000..3255b0a9bd3418faf8d5115c0417f101062ac49d
--- /dev/null
+++ b/doc/references.rst
@@ -0,0 +1,5 @@
+.. vim: set fileencoding=utf-8 :
+
+===========
+References
+===========
\ No newline at end of file
diff --git a/setup.py b/setup.py
index eb8429e7503e244f61d92e8f9d693ee38db4df2f..19649eff3d3edac46c13899151499d4592aaa632 100644
--- a/setup.py
+++ b/setup.py
@@ -55,12 +55,16 @@ setup(
          #bob hed train configurations
         'bob.ip.binseg.config': [
           'DRIU = bob.ip.binseg.configs.models.driu',
+          'DRIUJ01 = bob.ip.binseg.configs.models.driuj01',
+          'DRIUPAPER = bob.ip.binseg.configs.models.driupaper',
           'HED = bob.ip.binseg.configs.models.hed',
           'M2UNet = bob.ip.binseg.configs.models.m2unet',
+          'M2UNetJ01 = bob.ip.binseg.configs.models.m2unetj01',
           'UNet = bob.ip.binseg.configs.models.unet',
+          'UNetJ01 = bob.ip.binseg.configs.models.unetj01',
           'ResUNet = bob.ip.binseg.configs.models.resunet',
+          'ResUNetJ01 = bob.ip.binseg.configs.models.resunetj01',
           'ShapeResUNet = bob.ip.binseg.configs.models.shaperesunet',
-          'DRIUADABOUND = bob.ip.binseg.configs.models.driuadabound',
           'DRIVETRAIN = bob.ip.binseg.configs.datasets.drivetrain',
           'DRIVECROPTRAIN = bob.ip.binseg.configs.datasets.drivecroptrain',
           'DRIVECROPTEST = bob.ip.binseg.configs.datasets.drivecroptest',