diff --git a/bob/ip/binseg/configs/models/driubn.py b/bob/ip/binseg/configs/models/driubn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b95501d0a61053fd74b64976f6a761255944ece
--- /dev/null
+++ b/bob/ip/binseg/configs/models/driubn.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from torch.optim.lr_scheduler import MultiStepLR
+from bob.ip.binseg.modeling.driubn import build_driu
+import torch.optim as optim
+from torch.nn import BCEWithLogitsLoss
+from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.modeling.losses import SoftJaccardBCELogitsLoss
+from bob.ip.binseg.engine.adabound import AdaBound
+
+##### Config: settings for the AdaBound optimizer and MultiStepLR scheduler below #####
+lr = 0.001  # passed to AdaBound as lr
+betas = (0.9, 0.999)  # passed to AdaBound as betas
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1  # passed to AdaBound as final_lr
+gamma = 1e-3  # AdaBound gamma; not the same setting as scheduler_gamma
+eps = 1e-8  # NOTE(review): re-assigns eps to the same value as above; redundant
+amsbound = False
+
+scheduler_milestones = [900]  # passed to MultiStepLR as milestones
+scheduler_gamma = 0.1  # passed to MultiStepLR as gamma
+
+# model: whatever build_driu() from bob.ip.binseg.modeling.driubn returns
+model = build_driu()
+
+# pretrained backbone: the modelurls entry for 'vgg16_bn' (presumably consumed by the training script -- TODO confirm)
+pretrained_backbone = modelurls['vgg16_bn']
+
+# optimizer: AdaBound over all model parameters, configured from the settings above
+optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
+                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
+# criterion: SoftJaccardBCELogitsLoss with alpha=0.7
+criterion = SoftJaccardBCELogitsLoss(alpha=0.7)
+
+# scheduler: MultiStepLR wrapped around the optimizer above
+scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
diff --git a/bob/ip/binseg/configs/models/driubnssl.py b/bob/ip/binseg/configs/models/driubnssl.py
new file mode 100644
index 0000000000000000000000000000000000000000..52b3a2b35272b99d5f47bae8f23d47da15990135
--- /dev/null
+++ b/bob/ip/binseg/configs/models/driubnssl.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from torch.optim.lr_scheduler import MultiStepLR
+from bob.ip.binseg.modeling.driubn import build_driu
+import torch.optim as optim
+from torch.nn import BCEWithLogitsLoss
+from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.modeling.losses import MixJacLoss
+from bob.ip.binseg.engine.adabound import AdaBound
+
+##### Config: settings for the AdaBound optimizer and MultiStepLR scheduler below #####
+lr = 0.001  # passed to AdaBound as lr
+betas = (0.9, 0.999)  # passed to AdaBound as betas
+eps = 1e-08
+weight_decay = 0
+final_lr = 0.1  # passed to AdaBound as final_lr
+gamma = 1e-3  # AdaBound gamma; not the same setting as scheduler_gamma
+eps = 1e-8  # NOTE(review): re-assigns eps to the same value as above; redundant
+amsbound = False
+
+scheduler_milestones = [900]  # passed to MultiStepLR as milestones
+scheduler_gamma = 0.1  # passed to MultiStepLR as gamma
+
+# model: whatever build_driu() from bob.ip.binseg.modeling.driubn returns
+model = build_driu()
+
+# pretrained backbone: the modelurls entry for 'vgg16_bn' (presumably consumed by the training script -- TODO confirm)
+pretrained_backbone = modelurls['vgg16_bn']
+
+# optimizer: AdaBound over all model parameters, configured from the settings above
+optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
+                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
+
+# criterion: MixJacLoss with lambda_u=0.05 and jacalpha=0.7
+criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7)
+
+# scheduler: MultiStepLR wrapped around the optimizer above
+scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
diff --git a/bob/ip/binseg/engine/inferencer.py b/bob/ip/binseg/engine/inferencer.py
index a1017ed5d192980b431a846c18567a40a4159906..ccff70d4019c6856334e982fb861d315c2d60196 100644
--- a/bob/ip/binseg/engine/inferencer.py
+++ b/bob/ip/binseg/engine/inferencer.py
@@ -192,10 +192,13 @@ def do_inference(
     logger.info("Saving average over all input images: {}".format(metrics_file))
     
     avg_metrics = df_metrics.groupby('threshold').mean()
+    std_metrics = df_metrics.groupby('threshold').std()
 
     avg_metrics["f1_score"] =  (2* avg_metrics["precision"]*avg_metrics["recall"])/ \
         (avg_metrics["precision"]+avg_metrics["recall"])
     
+    avg_metrics["std_f1"] = std_metrics["f1_score"]
+    
     avg_metrics.to_csv(metrics_path)
     maxf1 = avg_metrics['f1_score'].max()
     optimal_f1_threshold = avg_metrics['f1_score'].idxmax()
diff --git a/bob/ip/binseg/modeling/driubn.py b/bob/ip/binseg/modeling/driubn.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c70dc4bc13c1a1f4101fa16fba64a4ff9d66e70
--- /dev/null
+++ b/bob/ip/binseg/modeling/driubn.py
@@ -0,0 +1,82 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import torch
+from torch import nn
+from collections import OrderedDict
+from bob.ip.binseg.modeling.backbones.vgg import vgg16_bn
+from bob.ip.binseg.modeling.make_layers import conv_with_kaiming_uniform,convtrans_with_kaiming_uniform, UpsampleCropBlock
+
+class ConcatFuseBlock(nn.Module):
+    """
+    Takes in four feature maps with 16 channels each, concatenates them along dim=1
+    and applies a 1x1 convolution with 1 output channel followed by ``BatchNorm2d(1)``.
+    """
+    def __init__(self):
+        super().__init__()
+        self.conv = nn.Sequential(
+            conv_with_kaiming_uniform(4*16,1,1,1,0)  # 4*16 = channels after concatenating the four 16-channel maps
+            ,nn.BatchNorm2d(1)  # batch norm over the single fused output channel
+        )
+    def forward(self,x1,x2,x3,x4):
+        x_cat = torch.cat([x1,x2,x3,x4],dim=1)  # inputs must agree in every dimension except dim=1
+        x = self.conv(x_cat)
+        return x 
+            
+class DRIU(nn.Module):
+    """
+    DRIU head module, based on the paper by `Maninis et al. (2016)`_
+
+    Parameters
+    ----------
+    in_channels_list : list
+        four channel counts, one per backbone feature map, in forward order; must be given (``None`` cannot be unpacked)
+    """
+    def __init__(self, in_channels_list=None):
+        super(DRIU, self).__init__()
+        in_conv_1_2_16, in_upsample2, in_upsample_4, in_upsample_8 = in_channels_list  # TypeError if left as None
+
+        self.conv1_2_16 = nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
+        # Upsample layers
+        self.upsample2 = UpsampleCropBlock(in_upsample2, 16, 4, 2, 0)
+        self.upsample4 = UpsampleCropBlock(in_upsample_4, 16, 8, 4, 0)
+        self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0)
+        
+        # Concat and Fuse
+        self.concatfuse = ConcatFuseBlock()
+
+    def forward(self,x):
+        """Apply the DRIU head to the outputs of the backbone network.
+
+        Parameters
+        ----------
+        x : list
+            list as returned from the backbone network: ``x[0]`` is the height and
+            width of the input image, ``x[1]`` to ``x[4]`` are the feature maps.
+
+        Returns
+        -------
+        :py:class:`torch.Tensor`
+        """
+        hw = x[0]
+        conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16   
+        upsample2 = self.upsample2(x[2], hw) # side-multi2-up
+        upsample4 = self.upsample4(x[3], hw) # side-multi3-up
+        upsample8 = self.upsample8(x[4], hw) # side-multi4-up
+        out = self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)
+        return out
+
+def build_driu():
+    """
+    Adds a ``vgg16_bn`` backbone (``pretrained=False``) and the DRIU head together in an ``nn.Sequential`` named "DRIUBN"
+
+    Returns
+    -------
+    :py:class:`torch.nn.Module`
+    """
+    backbone = vgg16_bn(pretrained=False, return_features = [3, 8, 14, 22])
+    driu_head = DRIU([64, 128, 256, 512])
+
+    model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", driu_head)]))
+    model.name = "DRIUBN"
+    return model
\ No newline at end of file
diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py
index 7ddfe2afb6ed569cb03b770b5408a44b6b66d14d..ee8118afaa3dab9385e8e440a8ad37519b29c29a 100644
--- a/bob/ip/binseg/script/binseg.py
+++ b/bob/ip/binseg/script/binseg.py
@@ -113,6 +113,14 @@ def binseg():
     required=True,
     default='cpu',
     cls=ResourceOption)
+@click.option(
+    '--seed',
+    '-s',
+    help='torch random seed',
+    show_default=True,
+    required=False,
+    default=42,
+    cls=ResourceOption)
 
 @verbosity_option(cls=ResourceOption)
 def train(model
@@ -126,11 +134,12 @@ def train(model
         ,dataset
         ,checkpoint_period
         ,device
+        ,seed
         ,**kwargs):
     """ Train a model """
     
     if not os.path.exists(output_path): os.makedirs(output_path)
-    
+    torch.manual_seed(seed)
     # PyTorch dataloader
     data_loader = DataLoader(
         dataset = dataset
@@ -481,6 +490,14 @@ def visualize(dataset, output_path, **kwargs):
     required=True,
     default='900',
     cls=ResourceOption)
+@click.option(
+    '--seed',
+    '-s',
+    help='torch random seed',
+    show_default=True,
+    required=False,
+    default=42,
+    cls=ResourceOption)
 
 @verbosity_option(cls=ResourceOption)
 def ssltrain(model
@@ -495,11 +512,12 @@ def ssltrain(model
         ,checkpoint_period
         ,device
         ,rampup
+        ,seed
         ,**kwargs):
     """ Train a model """
     
     if not os.path.exists(output_path): os.makedirs(output_path)
-    
+    torch.manual_seed(seed)
     # PyTorch dataloader
     data_loader = DataLoader(
         dataset = dataset
diff --git a/bob/ip/binseg/utils/plot.py b/bob/ip/binseg/utils/plot.py
index e8eab7da9c94b06ed82f353ef3ba626e2879b9d7..594b97f06f2391c2524104f74d751b28a6ffa7db 100644
--- a/bob/ip/binseg/utils/plot.py
+++ b/bob/ip/binseg/utils/plot.py
@@ -190,7 +190,7 @@ def plot_overview(outputfolders,title):
           rows = outfile.readlines()
           lastrow = rows[-1]
           parameter = int(lastrow.split()[1].replace(',',''))
-        name = '[P={:.2f}M] {} {}'.format(parameter/100**3, modelname, datasetname)
+        name = '[P={:.2f}M] {} {}'.format(parameter/100**3, modelname, "")
         names.append(name)
     #title = folder.split('/')[-4]
     fig = precision_recall_f1iso(precisions,recalls,names,title)
diff --git a/doc/benchmarkresults.rst b/doc/benchmarkresults.rst
index fb30961cb203ea3f2fe1f83f96c34d99e26d2403..ad608b2e3d1e8c598057c8ff7751ff86fb9d8758 100644
--- a/doc/benchmarkresults.rst
+++ b/doc/benchmarkresults.rst
@@ -6,20 +6,62 @@
 Benchmark Results
 ==================
 
-Dice Scores
+F1 Scores
 ===========
 
 * Benchmark results for models: DRIU, HED, M2UNet and U-Net.
-* Train-Test split as indicated in :ref:`bob.ip.binseg.datasets`
-
-+--------+----------+--------+---------+--------+--------+
-|        | CHASEDB1 | DRIVE  | HRF1168 | IOSTAR | STARE  |
-+--------+----------+--------+---------+--------+--------+
-| DRIU   | 0.8114   | 0.8226 | 0.7865  | 0.8273 | 0.8286 |
-+--------+----------+--------+---------+--------+--------+
-| HED    | 0.8111   | 0.8192 | 0.7868  | 0.8275 | 0.8250 |
-+--------+----------+--------+---------+--------+--------+
-| M2UNet | 0.8035   | 0.8051 | 0.7838  | 0.8194 | 0.8174 |
-+--------+----------+--------+---------+--------+--------+
-| UNet   | 0.8136   | 0.8237 | 0.7941  | 0.8203 | 0.8306 |
-+--------+----------+--------+---------+--------+--------+
+* Models are trained and tested on the same dataset using the train-test split as indicated in :ref:`bob.ip.binseg.datasets`
+* Standard deviations across all test images are indicated in brackets
+
++----------+-----------------+-----------------+-----------------+-----------------+-----------------+
+| F1 (std) | CHASEDB1        | DRIVE           | HRF1168         | IOSTAR          | STARE           |
++----------+-----------------+-----------------+-----------------+-----------------+-----------------+
+| DRIU     | 0.8114 (0.0206) | 0.8226 (0.0142) | 0.7865 (0.0545) | 0.8273 (0.0199) | 0.8286 (0.0368) |
++----------+-----------------+-----------------+-----------------+-----------------+-----------------+
+| HED      | 0.8111 (0.0214) | 0.8192 (0.0136) | 0.7868 (0.0576) | 0.8275 (0.0201) | 0.8250 (0.0375) |
++----------+-----------------+-----------------+-----------------+-----------------+-----------------+
+| M2UNet   | 0.8035 (0.0195) | 0.8051 (0.0141) | 0.7838 (0.0572) | 0.8194 (0.0201) | 0.8174 (0.0409) |
++----------+-----------------+-----------------+-----------------+-----------------+-----------------+
+| UNet     | 0.8136 (0.0209) | 0.8237 (0.0145) | 0.7914 (0.0516) | 0.8203 (0.0190) | 0.8306 (0.0421) |
++----------+-----------------+-----------------+-----------------+-----------------+-----------------+
+
+
+.. figure:: img/pr_CHASEDB1.png
+   :scale: 30 %
+   :align: center
+   :alt: model comparisons
+
+   CHASE_DB1: Precision vs Recall curve, F1 scores and
+   number of parameters of each model.
+
+.. figure:: img/pr_DRIVE.png
+   :scale: 30 %
+   :align: center
+   :alt: model comparisons
+
+   DRIVE: Precision vs Recall curve, F1 scores and
+   number of parameters of each model.
+
+.. figure:: img/pr_HRF.png
+   :scale: 30 %
+   :align: center
+   :alt: model comparisons
+
+   HRF: Precision vs Recall curve, F1 scores and
+   number of parameters of each model.
+
+.. figure:: img/pr_IOSTARVESSEL.png
+   :scale: 30 %
+   :align: center
+   :alt: model comparisons
+
+   IOSTAR: Precision vs Recall curve, F1 scores and
+   number of parameters of each model.
+
+.. figure:: img/pr_STARE.png
+   :scale: 30 %
+   :align: center
+   :alt: model comparisons
+
+   STARE: Precision vs Recall curve, F1 scores and
+   number of parameters of each model.
diff --git a/doc/covdresults.rst b/doc/covdresults.rst
index 05c48c75481c55b76f3a505c153ae38a15641bea..94ed0ab7c3871e8e985abe01896267f9a9511acc 100644
--- a/doc/covdresults.rst
+++ b/doc/covdresults.rst
@@ -6,7 +6,7 @@
 COVD- and COVD-SLL Results
 ==========================
 
-Dice Scores
+F1 Scores
 ===========
 
 +-------------------+---------------+---------+
diff --git a/doc/img/pr_CHASEDB1.png b/doc/img/pr_CHASEDB1.png
new file mode 100644
index 0000000000000000000000000000000000000000..7fe74f4e6178af9abc8fdda8c3d1142c992110c8
Binary files /dev/null and b/doc/img/pr_CHASEDB1.png differ
diff --git a/doc/img/pr_DRIVE.png b/doc/img/pr_DRIVE.png
new file mode 100644
index 0000000000000000000000000000000000000000..fc9e739e31c47bf319981dc6a561e335acfb261b
Binary files /dev/null and b/doc/img/pr_DRIVE.png differ
diff --git a/doc/img/pr_HRF.png b/doc/img/pr_HRF.png
new file mode 100644
index 0000000000000000000000000000000000000000..ac6f870ece6c4fe9d439ba5c0d5e3914eea3bcbb
Binary files /dev/null and b/doc/img/pr_HRF.png differ
diff --git a/doc/img/pr_IOSTARVESSEL.png b/doc/img/pr_IOSTARVESSEL.png
new file mode 100644
index 0000000000000000000000000000000000000000..97ed5c7a6b8f0d7ab6c0786db55588d7b163e9bb
Binary files /dev/null and b/doc/img/pr_IOSTARVESSEL.png differ
diff --git a/doc/img/pr_STARE.png b/doc/img/pr_STARE.png
new file mode 100644
index 0000000000000000000000000000000000000000..14603d2d3782292e66c813685fc61bca60953976
Binary files /dev/null and b/doc/img/pr_STARE.png differ
diff --git a/setup.py b/setup.py
index 44a23e1f588c69e2d631a9666f54957a21b8f175..859a8ded5789b9c9e000e931151144083a10669a 100644
--- a/setup.py
+++ b/setup.py
@@ -58,7 +58,9 @@ setup(
          #bob train configurations
         'bob.ip.binseg.config': [
           'DRIU = bob.ip.binseg.configs.models.driu',
+          'DRIUBN = bob.ip.binseg.configs.models.driubn',
           'DRIUSSL = bob.ip.binseg.configs.models.driussl',
+          'DRIUBNSSL = bob.ip.binseg.configs.models.driubnssl',
           'DRIUOD = bob.ip.binseg.configs.models.driuod',
           'HED = bob.ip.binseg.configs.models.hed',
           'M2UNet = bob.ip.binseg.configs.models.m2unet',