diff --git a/bob/ip/binseg/configs/datasets/drivetest.py b/bob/ip/binseg/configs/datasets/drivetest.py
index 67d21cac20df4095716c2735665ee418eb671a88..230598dce92a39276e05dd4b4f842643428546b4 100644
--- a/bob/ip/binseg/configs/datasets/drivetest.py
+++ b/bob/ip/binseg/configs/datasets/drivetest.py
@@ -7,8 +7,9 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([
-                        ToTensor()
+transforms = Compose([
+                        CenterCrop((544,544)),
+                        ToTensor()
                     ])
 
 # bob.db.dataset init
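
A minimal usage sketch of the new pipeline, assuming `Compose`, `CenterCrop` and `ToTensor` are the project's own paired transforms from `bob/ip/binseg/data/transforms.py` (which apply each transform jointly to all positional arguments); sizes are illustrative:

```python
# Sketch: CenterCrop is applied to image and ground truth together,
# so the pair stays pixel-aligned before conversion to tensors.
from PIL import Image
from bob.ip.binseg.data.transforms import CenterCrop, Compose, ToTensor

img = Image.new('RGB', (565, 584))  # DRIVE-sized test image (w x h)
gt = Image.new('1', (565, 584))     # binary vessel ground truth

transforms = Compose([CenterCrop((544, 544)), ToTensor()])
img_t, gt_t = transforms(img, gt)   # both cropped to 544 x 544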
diff --git a/bob/ip/binseg/data/binsegdataset.py b/bob/ip/binseg/data/binsegdataset.py
index b75a7dde55eaec9a2d94b5ba712e5154ba29b6d9..f212c4fe680969aaef27a300d08ca6f6ec35a50c 100644
--- a/bob/ip/binseg/data/binsegdataset.py
+++ b/bob/ip/binseg/data/binsegdataset.py
@@ -21,6 +21,11 @@ class BinSegDataset(Dataset):
         self.transform = transform
         self.split = split
     
+    @property
+    def mask(self):
+        # check if first sample contains a mask
+        return hasattr(self.database[0], 'mask')
+
     def __len__(self):
         """
         Returns
@@ -39,16 +44,19 @@ class BinSegDataset(Dataset):
         Returns
         -------
         list
-            dataitem [img, gt, mask, img_name]
+            dataitem [img_name, img, gt, mask]
         """
         img = self.database[index].img.pil_image()
         gt = self.database[index].gt.pil_image()
-        mask = self.database[index].mask.pil_image() if hasattr(self.database[index], 'mask') else None
         img_name = self.database[index].img.basename
+        sample = [img, gt]
+        if self.mask:
+            mask = self.database[index].mask.pil_image()
+            sample.append(mask)
+
+        if self.transform:
+            sample = self.transform(*sample)
+
+        sample.insert(0, img_name)
         
-        if self.transform and mask:
-            img, gt, mask = self.transform(img, gt, mask)
-        else:
-            img, gt  = self.transform(img, gt)
-            
-        return img, gt, mask, img_name
+        return sample
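
With this change a batch has a positional layout rather than a fixed arity: `[names, images, ground_truths]` plus an optional trailing `masks` entry. A sketch of the consumer side, mirroring the updated trainer/inferencer loops (the `data_loader` setup is assumed):

```python
# Sketch: unpack a batch produced by BinSegDataset via a DataLoader.
for samples in data_loader:
    names = samples[0]          # file basenames, prepended after transforms
    images = samples[1]
    ground_truths = samples[2]
    masks = samples[3] if len(samples) == 4 else None  # only if the db has masks
```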
diff --git a/bob/ip/binseg/data/transforms.py b/bob/ip/binseg/data/transforms.py
index 34b95ca7b8900cffd9fd6b45807c51f1dda06861..99e941b65099cdd1f85ba31374807f90b471131e 100644
--- a/bob/ip/binseg/data/transforms.py
+++ b/bob/ip/binseg/data/transforms.py
@@ -10,7 +10,7 @@ from torchvision.transforms.transforms import Compose as TorchVisionCompose
 import math
 from math import floor
 import warnings
-
+import collections
 
 _pil_interpolation_to_str = {
     Image.NEAREST: 'PIL.Image.NEAREST',
@@ -20,6 +20,7 @@ _pil_interpolation_to_str = {
     Image.HAMMING: 'PIL.Image.HAMMING',
     Image.BOX: 'PIL.Image.BOX',
 }
+Iterable = collections.abc.Iterable
 
 # Compose 
 
@@ -440,4 +441,32 @@ class Distortion:
                 imgs.append(img)
             return imgs
         else:
-            return args
\ No newline at end of file
+            return args
+
+
+class Resize:
+    """Resize to given size.
+    
+    Attributes
+    ----------
+    size : tuple or int
+        Desired output size. If size is a sequence like
+        (h, w), output size will be matched to this. If size is an int,
+        smaller edge of the image will be matched to this number.
+        i.e., if height > width, then the image will be rescaled to
+        (size * height / width, size)
+    interpolation : int
+        Desired interpolation. Default is ``PIL.Image.BILINEAR``
+    """
+
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
+        self.size = size
+        self.interpolation = interpolation
+
+    def __call__(self, *args):
+        return [VF.resize(img, self.size, self.interpolation) for img in args]
+
+    def __repr__(self):
+        interpolate_str = _pil_interpolation_to_str[self.interpolation]
+        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
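
A usage sketch for the new `Resize` transform: like the other transforms in this module it maps over all positional arguments, keeping paired images in sync (sizes are illustrative):

```python
from PIL import Image
from bob.ip.binseg.data.transforms import Resize

resize = Resize(544)  # int: match the smaller edge to 544, keep aspect ratio
img = Image.new('RGB', (565, 584))
gt = Image.new('1', (565, 584))
img_r, gt_r = resize(img, gt)  # returns a list, one resized image per input
```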
diff --git a/bob/ip/binseg/engine/inferencer.py b/bob/ip/binseg/engine/inferencer.py
index 3221277c11ff8545f0d3182952c329a82488ed50..bfdf9c1ba297b27e01206c0b7a77d25f0afd4d15 100644
--- a/bob/ip/binseg/engine/inferencer.py
+++ b/bob/ip/binseg/engine/inferencer.py
@@ -13,10 +13,11 @@ from tqdm import tqdm
 
 from bob.ip.binseg.utils.metric import SmoothedValue, base_metrics
 from bob.ip.binseg.utils.plot import precision_recall_f1iso
+from bob.ip.binseg.utils.summary import summary
 
 
 
-def batch_metrics(predictions, ground_truths, masks, names, output_folder, logger):
+def batch_metrics(predictions, ground_truths, names, output_folder, logger):
     """
     Calculates metrics on the batch and saves it to disc
 
@@ -26,8 +27,6 @@ def batch_metrics(predictions, ground_truths, masks, names, output_folder, logge
         tensor with pixel-wise probabilities
     ground_truths : :py:class:`torch.Tensor`
         tensor with binary ground-truth
-    mask : :py:class:`torch.Tensor`
-        tensor with mask
     names : list
         list of file names 
     output_folder : str
@@ -73,7 +72,6 @@ def batch_metrics(predictions, ground_truths, masks, names, output_folder, logge
                 # true negatives
                 tn_tensor = equals - tp_tensor
                 tn_count = torch.sum(tn_tensor).item()
-                # TODO: Substract masks from True negatives?
 
                 # false negatives
                 fn_tensor = notequals - fp_tensor
@@ -152,9 +150,10 @@ def do_inference(
     # Collect overall metrics 
     metrics = []
 
-    for images, ground_truths, masks, names in tqdm(data_loader):
-        images = images.to(device)
-        ground_truths = ground_truths.to(device)
+    for samples in tqdm(data_loader):
+        names = samples[0]
+        images = samples[1].to(device)
+        ground_truths = samples[2].to(device)
         with torch.no_grad():
             start_time = time.perf_counter()
 
@@ -166,15 +165,12 @@ def do_inference(
                 outputs = outputs[-1]
             
             probabilities = sigmoid(outputs)
-            if hasattr(masks,'dtype'):
-                masks = masks.to(device)
-                probabilities = probabilities * masks
             
             batch_time = time.perf_counter() - start_time
             times.append(batch_time)
             logger.info("Batch time: {:.5f} s".format(batch_time))
             
-            b_metrics = batch_metrics(probabilities, ground_truths, masks, names,results_subfolder, logger)
+            b_metrics = batch_metrics(probabilities, ground_truths, names, results_subfolder, logger)
             metrics.extend(b_metrics)
             
             # Create probability images
@@ -209,7 +205,7 @@ def do_inference(
     
     # Plotting
     np_avg_metrics = avg_metrics.to_numpy().T
-    fig_name = "precision_recall.pdf".format(model.name)
+    fig_name = "precision_recall.pdf"
     logger.info("saving {}".format(fig_name))
     fig = precision_recall_f1iso([np_avg_metrics[0]],[np_avg_metrics[1]], [model.name,None])
     fig_filename = os.path.join(results_subfolder, fig_name)
@@ -222,7 +218,7 @@ def do_inference(
 
     logger.info("Average batch inference time: {:.5f}s".format(average_batch_inference_time))
 
-    times_file = "Times.txt".format(model.name)
+    times_file = "Times.txt"
     logger.info("saving {}".format(times_file))
  
     with open (os.path.join(results_subfolder,times_file), "w+") as outfile:
@@ -232,4 +228,12 @@ def do_inference(
         outfile.write("Average batch inference time: {} \n".format(average_batch_inference_time))
         outfile.write("Total inference time: {} \n".format(total_inference_time))
 
+    # Save model summary
+    summary_file = 'ModelSummary.txt'
+    logger.info("saving {}".format(summary_file))
+
+    with open(os.path.join(results_subfolder, summary_file), "w+") as outfile:
+        summary(model, outfile)
+
+
 
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index ac24b55b5030ed5a51d2787fae1db5d5ef9af5e0..d703ca3060761cac010b9d99821c7e951566e349 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -70,12 +70,13 @@ def do_train(
             # Epoch time
             start_epoch_time = time.time()
 
-            for images, ground_truths, masks, _ in tqdm(data_loader):
+            for samples in tqdm(data_loader):
 
-                images = images.to(device)
-                ground_truths = ground_truths.to(device)
-                if hasattr(masks,'dtype'):
-                    masks = masks.to(device)
+                images = samples[1].to(device)
+                ground_truths = samples[2].to(device)
+                masks = None
+                if len(samples) == 4:
+                    masks = samples[-1].to(device)
                 
                 outputs = model(images)
                 
diff --git a/bob/ip/binseg/modeling/losses.py b/bob/ip/binseg/modeling/losses.py
index c57bf7fe1ab851669de7b878b3f2fa3b5383637b..3e32b0904f2eb53ac145fa59b0c5537392210ef8 100644
--- a/bob/ip/binseg/modeling/losses.py
+++ b/bob/ip/binseg/modeling/losses.py
@@ -51,7 +51,7 @@ class SoftJaccardBCELogitsLoss(_Loss):
     alpha : float
         determines the weighting of SoftJaccard and BCE. Default: ``0.3``
     """
-    def __init__(self, alpha=0.1, size_average=None, reduce=None, reduction='mean', pos_weight=None):
+    def __init__(self, alpha=0.3, size_average=None, reduce=None, reduction='mean', pos_weight=None):
         super(SoftJaccardBCELogitsLoss, self).__init__(size_average, reduce, reduction) 
         self.alpha = alpha   
 
@@ -118,4 +118,46 @@ class HEDWeightedBCELogitsLoss(_Loss):
             loss = torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, reduction=self.reduction)
             loss_over_all_inputs.append(loss.unsqueeze(0))
         final_loss = torch.cat(loss_over_all_inputs).mean()
-        return final_loss 
\ No newline at end of file
+        return final_loss 
+
+
+class HEDSoftJaccardBCELogitsLoss(_Loss):
+    """ 
+    Implements Equation 6 in [SAT17]_ for the HED network. Based on torch.nn.modules.loss.BCEWithLogitsLoss.
+
+    Attributes
+    ----------
+    alpha : float
+        determines the weighting of SoftJaccard and BCE. Default: ``0.3``
+    """
+    def __init__(self, alpha=0.3, size_average=None, reduce=None, reduction='mean', pos_weight=None):
+        super(HEDSoftJaccardBCELogitsLoss, self).__init__(size_average, reduce, reduction) 
+        self.alpha = alpha   
+
+    @weak_script_method
+    def forward(self, inputlist, target, masks=None):
+        """
+        Parameters
+        ----------
+        inputlist : list of :py:class:`torch.Tensor`
+        target : :py:class:`torch.Tensor`
+        masks : :py:class:`torch.Tensor`, optional
+        
+        Returns
+        -------
+        :py:class:`torch.Tensor`
+        """
+        eps = 1e-8
+        loss_over_all_inputs = []
+        for input in inputlist:
+            probabilities = torch.sigmoid(input)
+            intersection = (probabilities * target).sum()
+            sums = probabilities.sum() + target.sum()
+            
+            softjaccard = intersection/(sums - intersection + eps)
+    
+            bceloss = torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=None, reduction=self.reduction)
+            loss = self.alpha * bceloss + (1 - self.alpha) * (1-softjaccard)
+            loss_over_all_inputs.append(loss.unsqueeze(0))
+        final_loss = torch.cat(loss_over_all_inputs).mean()
+        return final_loss
\ No newline at end of file
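
A usage sketch for the new loss (dummy shapes): HED-style models return one logit map per side output, so the loss consumes a list of tensors against a single target:

```python
import torch
from bob.ip.binseg.modeling.losses import HEDSoftJaccardBCELogitsLoss

criterion = HEDSoftJaccardBCELogitsLoss(alpha=0.3)
side_outputs = [torch.randn(2, 1, 64, 64) for _ in range(5)]  # 5 side outputs
target = torch.randint(0, 2, (2, 1, 64, 64)).float()
loss = criterion(side_outputs, target)  # mean over the side-output losses
```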
diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py
index 8fc667b1bdd3d4c22f926116f82c61b4841f2579..7adedbae04c2b373b57f937b7c8cd26ee694f7a8 100644
--- a/bob/ip/binseg/script/binseg.py
+++ b/bob/ip/binseg/script/binseg.py
@@ -297,11 +297,12 @@ def testcheckpoints(model
     required=True,
     )
 @verbosity_option(cls=ResourceOption)
-def compare(output_path_list, output_path,**kwargs):
+def compare(output_path_list, output_path, **kwargs):
     """ Compares multiple metrics files that are stored in the format mymodel/results/Metrics.csv """
     logger.debug("Output paths: {}".format(output_path_list))
     logger.info('Plotting precision vs recall curves for {}'.format(output_path_list))
     fig = plot_overview(output_path_list)
+    os.makedirs(output_path, exist_ok=True)
     fig_filename = os.path.join(output_path, 'precision_recall_comparison.pdf')
     logger.info('saving {}'.format(fig_filename))
     fig.savefig(fig_filename)
diff --git a/bob/ip/binseg/test/test_batchmetrics.py b/bob/ip/binseg/test/test_batchmetrics.py
index 93d573e809e3ee17f2af4a61cb71a36ab49849c1..4988cab6ea8a7ffbc907f2e580ebe46b48ad9611 100644
--- a/bob/ip/binseg/test/test_batchmetrics.py
+++ b/bob/ip/binseg/test/test_batchmetrics.py
@@ -20,7 +20,6 @@ class Tester(unittest.TestCase):
         self.fn = random.randint(1, 100)
         self.predictions = torch.rand(size=(2,1,420,420))
         self.ground_truths = torch.randint(low=0, high=2, size=(2,1,420,420))
-        self.masks = None
         self.names = ['Bob','Tim'] 
         self.output_folder = tempfile.mkdtemp()
         self.logger = logging.getLogger(__name__)
@@ -30,7 +29,7 @@ class Tester(unittest.TestCase):
         shutil.rmtree(self.output_folder)
     
     def test_batch_metrics(self):
-        bm = batch_metrics(self.predictions, self.ground_truths, self.masks, self.names, self.output_folder, self.logger)
+        bm = batch_metrics(self.predictions, self.ground_truths, self.names, self.output_folder, self.logger)
         self.assertEqual(len(bm),2*100)
         for metric in bm:
             # check whether f1 score agree
diff --git a/bob/ip/binseg/test/test_models.py b/bob/ip/binseg/test/test_models.py
index a5f37d3fb3ef05c1241544bef53a79e1166a6b34..35e39db8c2e3641b22a8a4e563651b70c968b4dd 100644
--- a/bob/ip/binseg/test/test_models.py
+++ b/bob/ip/binseg/test/test_models.py
@@ -5,6 +5,7 @@ import torch
 import unittest
 import numpy as np
 from bob.ip.binseg.modeling.driu import build_driu
+from bob.ip.binseg.modeling.driuod import build_driuod
 from bob.ip.binseg.modeling.hed import build_hed
 from bob.ip.binseg.modeling.unet import build_unet
 from bob.ip.binseg.modeling.resunet import build_res50unet
@@ -24,6 +25,12 @@ class Tester(unittest.TestCase):
         self.assertEqual(self.hw.all(), out_hw.all())
 
 
+    def test_driuod(self):
+        model = build_driuod()
+        out = model(self.x)
+        out_hw = np.array(out.shape)[[2,3]]
+        self.assertEqual(self.hw.all(), out_hw.all())
+
     def test_hed(self):
         model = build_hed()
         out = model(self.x)
diff --git a/bob/ip/binseg/test/test_summary.py b/bob/ip/binseg/test/test_summary.py
new file mode 100644
index 0000000000000000000000000000000000000000..7faabf796674db6b7914d631ba41f9160c08a623
--- /dev/null
+++ b/bob/ip/binseg/test/test_summary.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import torch
+import unittest
+import numpy as np
+from bob.ip.binseg.modeling.driu import build_driu
+from bob.ip.binseg.modeling.driuod import build_driuod
+from bob.ip.binseg.modeling.hed import build_hed
+from bob.ip.binseg.modeling.unet import build_unet
+from bob.ip.binseg.modeling.resunet import build_res50unet
+from bob.ip.binseg.utils.summary import summary
+
+class Tester(unittest.TestCase):
+    """
+    Unit tests for the model summary utility
+    """    
+    def test_summary_driu(self):
+        model = build_driu()
+        param = summary(model)
+        self.assertIsInstance(param,int)
+
+
+    def test_summary_driuod(self):
+        model = build_driuod()
+        param = summary(model)
+        self.assertIsInstance(param,int)
+
+
+    def test_summary_hed(self):
+        model = build_hed()
+        param = summary(model)
+        self.assertIsInstance(param,int)
+
+    def test_summary_unet(self):
+        model = build_unet()
+        param = summary(model)
+        self.assertIsInstance(param,int)
+
+    def test_summary_resunet(self):
+        model = build_res50unet()
+        param = summary(model)
+        self.assertIsInstance(param,int)
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
diff --git a/bob/ip/binseg/utils/plot.py b/bob/ip/binseg/utils/plot.py
index 5a2677d5b055a9ed9581d08ee50bdc45c329d221..5e6fa29d35658c42cfb0191018ac7bc6d80123f7 100644
--- a/bob/ip/binseg/utils/plot.py
+++ b/bob/ip/binseg/utils/plot.py
@@ -24,7 +24,7 @@ def precision_recall_f1iso(precision, recall, names, title=None, human_perf_bsds
         of the system's recall coordinates. 
       names : :py:class:`list`
         An iterable over the names of each of the systems along the rows of
-        ``precision`` and ``recall``    
+        ``precision`` and ``recall``      
       title : :py:class:`str`, optional
         A title for the plot. If not set, omits the title   
       human_perf_bsds500 : :py:class:`bool`, optional
@@ -38,7 +38,10 @@ def precision_recall_f1iso(precision, recall, names, title=None, human_perf_bsds
     import matplotlib
     matplotlib.use('agg')
     import matplotlib.pyplot as plt 
+    from itertools import cycle
     fig, ax1 = plt.subplots(1)  
+    lines = ["-","--","-.",":"]
+    linecycler = cycle(lines)
     for p, r, n in zip(precision, recall, names):   
         # Plots only from the point where recall reaches its maximum, otherwise, we
         # don't see a curve...
@@ -47,8 +50,13 @@ def precision_recall_f1iso(precision, recall, names, title=None, human_perf_bsds
         ri = r[i:]    
         valid = (pi+ri) > 0
         f1 = 2 * (pi[valid]*ri[valid]) / (pi[valid]+ri[valid])    
+        # optimal point along the curve
+        argmax = f1.argmax()
+        opi = pi[argmax]
+        ori = ri[argmax]
         # Plot Recall/Precision as threshold changes
-        ax1.plot(ri[pi>0], pi[pi>0], label='[F={:.3f}] {}'.format(f1.max(), n)) 
+        ax1.plot(ri[pi>0], pi[pi>0], next(linecycler), label='[F={:.3f}] {}'.format(f1.max(), n))
+        ax1.plot(ori, opi, marker='o', linestyle='None', markersize=3, color='black')
     ax1.grid(linestyle='--', linewidth=1, color='gray', alpha=0.2)  
     if len(names) > 1:
         plt.legend(loc='lower left', framealpha=0.5)  
@@ -152,8 +160,7 @@ def plot_overview(outputfolders):
     Arguments
     ---------
     outputfolder : list
-                    list containing output paths of all evaluated models (e.g. ['output/model1', 'output/model2'])
-    
+                    list containing output paths of all evaluated models (e.g. ['DRIVE/model1', 'DRIVE/model2'])
     Returns
     -------
     fig : matplotlib.figure.Figure
@@ -161,11 +168,24 @@ def plot_overview(outputfolders):
     precisions = []
     recalls = []
     names = []
+    params = []
     for folder in outputfolders:
+        # metrics 
         metrics_path = os.path.join(folder,'results/Metrics.csv')
         pr, re = read_metricscsv(metrics_path)
         precisions.append(pr)
         recalls.append(re)
-        names.append(folder)
-    fig = precision_recall_f1iso(precisions,recalls,names)
+        modelname = folder.split('/')[-1]
+        # parameters
+        summary_path = os.path.join(folder,'results/ModelSummary.txt')
+        with open(summary_path, "r") as infile:
+            rows = infile.readlines()
+            lastrow = rows[-1]
+            parameter = int(lastrow.split()[1].replace(',',''))
+        name = '[P={:.2f}M] {}'.format(parameter/10**6, modelname)
+        names.append(name)
+    title = folder.split('/')[-2]
+    fig = precision_recall_f1iso(precisions,recalls,names,title)
     return fig
+
+  
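
The parsing above relies on the last line of `ModelSummary.txt` ending in `), <N,NNN,NNN> params`, so the formatted count sits at whitespace-split index 1. A self-contained check of that assumption (the count is made up):

```python
lastrow = '), 25,799,695 params\n'
parameter = int(lastrow.split()[1].replace(',', ''))   # 25799695
print('[P={:.2f}M]'.format(parameter / 10**6))         # [P=25.80M]
```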
diff --git a/bob/ip/binseg/utils/summary.py b/bob/ip/binseg/utils/summary.py
new file mode 100644
index 0000000000000000000000000000000000000000..127c5e66a3d7c97629ef95acebccb4f138ba43b6
--- /dev/null
+++ b/bob/ip/binseg/utils/summary.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488 
+import sys
+import logging
+from functools import reduce
+
+from torch.nn.modules.module import _addindent
+
+
+def summary(model, file=sys.stderr):
+    """Counts the number of paramters in each layers
+    
+    Parameters
+    ----------
+    model : :py:class:`torch.nn.Module`
+    
+    Returns
+    -------
+    int
+        number of parameters
+    """
+    def repr(model):
+        # We treat the extra repr like the sub-module, one item per line
+        extra_lines = []
+        extra_repr = model.extra_repr()
+        # empty string will be split into list ['']
+        if extra_repr:
+            extra_lines = extra_repr.split('\n')
+        child_lines = []
+        total_params = 0
+        for key, module in model._modules.items():
+            mod_str, num_params = repr(module)
+            mod_str = _addindent(mod_str, 2)
+            child_lines.append('(' + key + '): ' + mod_str)
+            total_params += num_params
+        lines = extra_lines + child_lines
+
+        for name, p in model._parameters.items():
+            if hasattr(p,'dtype'):
+                total_params += reduce(lambda x, y: x * y, p.shape, 1)
+
+        main_str = model._get_name() + '('
+        if lines:
+            # simple one-liner info, which most builtin Modules will use
+            if len(extra_lines) == 1 and not child_lines:
+                main_str += extra_lines[0]
+            else:
+                main_str += '\n  ' + '\n  '.join(lines) + '\n'
+
+        main_str += ')'
+        if file is sys.stderr:
+            main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
+        else:
+            main_str += ', {:,} params'.format(total_params)
+        return main_str, total_params
+
+    string, count = repr(model)
+    if file is not None:
+        print(string, file=file)
+    return count
\ No newline at end of file
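
A usage sketch for the new helper: print the layer-wise breakdown to a file (plain text, ANSI colors are only used when the target is stderr) and get the total parameter count back:

```python
from bob.ip.binseg.modeling.driu import build_driu
from bob.ip.binseg.utils.summary import summary

model = build_driu()
with open('ModelSummary.txt', 'w') as f:
    total = summary(model, f)  # writes the annotated repr, returns an int
print('{:,} parameters'.format(total))
```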