From 833335daf186216dd45825ebaa9601506aee6c3a Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 19 Mar 2020 17:14:34 +0100
Subject: [PATCH] [all] Ran black on all Python files

---
 bob/__init__.py                               |   1 +
 bob/ip/__init__.py                            |   3 +-
 bob/ip/binseg/configs/datasets/amdrive.py     |  59 +-
 bob/ip/binseg/configs/datasets/amdrivetest.py |  46 +-
 bob/ip/binseg/configs/datasets/chasedb1.py    |  22 +-
 .../binseg/configs/datasets/chasedb11024.py   |  24 +-
 .../binseg/configs/datasets/chasedb11168.py   |  24 +-
 bob/ip/binseg/configs/datasets/chasedb1544.py |  24 +-
 bob/ip/binseg/configs/datasets/chasedb1608.py |  24 +-
 .../binseg/configs/datasets/chasedb1test.py   |   9 +-
 bob/ip/binseg/configs/datasets/drionsdb.py    |  22 +-
 .../binseg/configs/datasets/drionsdbtest.py   |   9 +-
 .../binseg/configs/datasets/dristhigs1cup.py  |  22 +-
 .../configs/datasets/dristhigs1cuptest.py     |   9 +-
 .../binseg/configs/datasets/dristhigs1od.py   |  22 +-
 .../configs/datasets/dristhigs1odtest.py      |   9 +-
 bob/ip/binseg/configs/datasets/drive.py       |  22 +-
 bob/ip/binseg/configs/datasets/drive1024.py   |  24 +-
 .../binseg/configs/datasets/drive1024test.py  |  10 +-
 bob/ip/binseg/configs/datasets/drive1168.py   |  26 +-
 bob/ip/binseg/configs/datasets/drive608.py    |  26 +-
 bob/ip/binseg/configs/datasets/drive960.py    |  24 +-
 .../datasets/drivechasedb1iostarhrf608.py     |   2 +-
 .../drivechasedb1iostarhrf608sslstare.py      |  34 +-
 .../datasets/drivestarechasedb11168.py        |   2 +-
 .../datasets/drivestarechasedb1hrf1024.py     |   2 +-
 .../drivestarechasedb1hrf1024ssliostar.py     |  26 +-
 .../datasets/drivestarechasedb1iostar1168.py  |   2 +-
 .../drivestarechasedb1iostar1168sslhrf.py     |  36 +-
 .../datasets/drivestareiostarhrf960.py        |   2 +-
 .../drivestareiostarhrf960sslchase.py         |  34 +-
 bob/ip/binseg/configs/datasets/drivetest.py   |   9 +-
 bob/ip/binseg/configs/datasets/hrf.py         |  22 +-
 bob/ip/binseg/configs/datasets/hrf1024.py     |  24 +-
 bob/ip/binseg/configs/datasets/hrf1168.py     |  24 +-
 bob/ip/binseg/configs/datasets/hrf1168test.py |  10 +-
 bob/ip/binseg/configs/datasets/hrf544.py      |  24 +-
 bob/ip/binseg/configs/datasets/hrf544test.py  |  10 +-
 bob/ip/binseg/configs/datasets/hrf608.py      |  24 +-
 bob/ip/binseg/configs/datasets/hrf960.py      |  24 +-
 bob/ip/binseg/configs/datasets/hrftest.py     |   9 +-
 bob/ip/binseg/configs/datasets/imagefolder.py |  22 +-
 .../configs/datasets/imagefolderinference.py  |  10 +-
 .../configs/datasets/imagefoldertest.py       |   9 +-
 bob/ip/binseg/configs/datasets/iostarod.py    |  14 +-
 .../binseg/configs/datasets/iostarodtest.py   |   8 +-
 .../binseg/configs/datasets/iostarvessel.py   |  14 +-
 .../configs/datasets/iostarvessel1168.py      |  26 +-
 .../configs/datasets/iostarvessel544.py       |  22 +-
 .../configs/datasets/iostarvessel544test.py   |   9 +-
 .../configs/datasets/iostarvessel608.py       |  24 +-
 .../configs/datasets/iostarvessel960.py       |  22 +-
 .../configs/datasets/iostarvesseltest.py      |   8 +-
 bob/ip/binseg/configs/datasets/refugecup.py   |  24 +-
 .../binseg/configs/datasets/refugecuptest.py  |   9 +-
 bob/ip/binseg/configs/datasets/refugeod.py    |  24 +-
 .../binseg/configs/datasets/refugeodtest.py   |   9 +-
 bob/ip/binseg/configs/datasets/rimoner3cup.py |  22 +-
 .../configs/datasets/rimoner3cuptest.py       |   9 +-
 bob/ip/binseg/configs/datasets/rimoner3od.py  |  22 +-
 .../binseg/configs/datasets/rimoner3odtest.py |   9 +-
 bob/ip/binseg/configs/datasets/stare.py       |  22 +-
 bob/ip/binseg/configs/datasets/stare1024.py   |  26 +-
 bob/ip/binseg/configs/datasets/stare1168.py   |  26 +-
 bob/ip/binseg/configs/datasets/stare544.py    |  23 +-
 bob/ip/binseg/configs/datasets/stare960.py    |  26 +-
 .../datasets/starechasedb1iostarhrf544.py     |   2 +-
 .../starechasedb1iostarhrf544ssldrive.py      |  34 +-
 bob/ip/binseg/configs/datasets/staretest.py   |   9 +-
 bob/ip/binseg/configs/models/driubn.py        |  18 +-
 bob/ip/binseg/configs/models/driubnssl.py     |  18 +-
 bob/ip/binseg/configs/models/driuod.py        |  18 +-
 bob/ip/binseg/configs/models/driussl.py       |  18 +-
 bob/ip/binseg/configs/models/hed.py           |  18 +-
 bob/ip/binseg/configs/models/m2unet.py        |  20 +-
 bob/ip/binseg/configs/models/m2unetssl.py     |  20 +-
 bob/ip/binseg/configs/models/resunet.py       |  20 +-
 bob/ip/binseg/configs/models/unet.py          |  20 +-
 bob/ip/binseg/data/binsegdataset.py           |  47 +-
 bob/ip/binseg/data/imagefolder.py             |  43 +-
 bob/ip/binseg/data/imagefolderinference.py    |  11 +-
 bob/ip/binseg/data/transforms.py              | 107 ++-
 bob/ip/binseg/engine/adabound.py              | 143 ++--
 bob/ip/binseg/engine/inferencer.py            | 142 ++--
 bob/ip/binseg/engine/predicter.py             |  30 +-
 bob/ip/binseg/engine/ssltrainer.py            | 151 ++--
 bob/ip/binseg/engine/trainer.py               |  79 +-
 .../binseg/modeling/backbones/mobilenetv2.py  |  45 +-
 bob/ip/binseg/modeling/backbones/resnet.py    |  27 +-
 bob/ip/binseg/modeling/backbones/vgg.py       | 120 ++-
 bob/ip/binseg/modeling/driu.py                |   7 +-
 bob/ip/binseg/modeling/driubn.py              |  33 +-
 bob/ip/binseg/modeling/driuod.py              |  34 +-
 bob/ip/binseg/modeling/driupix.py             |  39 +-
 bob/ip/binseg/modeling/hed.py                 |  57 +-
 bob/ip/binseg/modeling/losses.py              |  24 +-
 bob/ip/binseg/modeling/m2u.py                 |  65 +-
 bob/ip/binseg/modeling/make_layers.py         |  90 ++-
 bob/ip/binseg/modeling/resunet.py             |  18 +-
 bob/ip/binseg/modeling/unet.py                |  14 +-
 bob/ip/binseg/script/binseg.py                | 689 +++++++-----------
 bob/ip/binseg/test/test_basemetrics.py        |  27 +-
 bob/ip/binseg/test/test_batchmetrics.py       |  29 +-
 bob/ip/binseg/test/test_checkpointer.py       |   8 +-
 bob/ip/binseg/test/test_summary.py            |  21 +-
 bob/ip/binseg/test/test_transforms.py         |  39 +-
 bob/ip/binseg/utils/checkpointer.py           |   4 +-
 bob/ip/binseg/utils/click.py                  |   8 +-
 bob/ip/binseg/utils/evaluate.py               | 122 ++--
 bob/ip/binseg/utils/metric.py                 |  17 +-
 bob/ip/binseg/utils/model_serialization.py    |   3 +-
 bob/ip/binseg/utils/model_zoo.py              |  18 +-
 bob/ip/binseg/utils/plot.py                   | 301 +++++---
 bob/ip/binseg/utils/rsttable.py               |  37 +-
 bob/ip/binseg/utils/summary.py                |  21 +-
 bob/ip/binseg/utils/transformfolder.py        |  11 +-
 116 files changed, 2241 insertions(+), 1875 deletions(-)

diff --git a/bob/__init__.py b/bob/__init__.py
index 2ab1e28b..edbb4090 100644
--- a/bob/__init__.py
+++ b/bob/__init__.py
@@ -1,3 +1,4 @@
 # see https://docs.python.org/3/library/pkgutil.html
 from pkgutil import extend_path
+
 __path__ = extend_path(__path__, __name__)
diff --git a/bob/ip/__init__.py b/bob/ip/__init__.py
index 2ca5e07c..edbb4090 100644
--- a/bob/ip/__init__.py
+++ b/bob/ip/__init__.py
@@ -1,3 +1,4 @@
 # see https://docs.python.org/3/library/pkgutil.html
 from pkgutil import extend_path
-__path__ = extend_path(__path__, __name__)
\ No newline at end of file
+
+__path__ = extend_path(__path__, __name__)
diff --git a/bob/ip/binseg/configs/datasets/amdrive.py b/bob/ip/binseg/configs/datasets/amdrive.py
index 6f7cc310..0b1b1639 100644
--- a/bob/ip/binseg/configs/datasets/amdrive.py
+++ b/bob/ip/binseg/configs/datasets/amdrive.py
@@ -11,68 +11,55 @@ import torch
 
 # Target size: 544x544 (DRIVE)
 
-defaulttransforms = [RandomHFlip()
-                    ,RandomVFlip()
-                    ,RandomRotation()
-                    ,ColorJitter()
-                    ,ToTensor()]
+defaulttransforms = [
+    RandomHFlip(),
+    RandomVFlip(),
+    RandomRotation(),
+    ColorJitter(),
+    ToTensor(),
+]
 
 
-
-# CHASE_DB1 
-transforms_chase = Compose([      
-                        Resize(544)
-                        ,Crop(0,12,544,544)
-                        ,*defaulttransforms
-                    ])
+# CHASE_DB1
+transforms_chase = Compose([Resize(544), Crop(0, 12, 544, 544), *defaulttransforms])
 
 # bob.db.dataset init
-bobdb_chase = CHASEDB1(protocol = 'default')
+bobdb_chase = CHASEDB1(protocol="default")
 
 # PyTorch dataset
-torch_chase = BinSegDataset(bobdb_chase, split='train', transform=transforms_chase)
+torch_chase = BinSegDataset(bobdb_chase, split="train", transform=transforms_chase)
 
 
 # IOSTAR VESSEL
-transforms_iostar = Compose([  
-                        Resize(544)
-                        ,*defaulttransforms
-                    ])
+transforms_iostar = Compose([Resize(544), *defaulttransforms])
 
 # bob.db.dataset init
-bobdb_iostar = IOSTAR(protocol='default_vessel')
+bobdb_iostar = IOSTAR(protocol="default_vessel")
 
 # PyTorch dataset
-torch_iostar = BinSegDataset(bobdb_iostar, split='train', transform=transforms_iostar)
+torch_iostar = BinSegDataset(bobdb_iostar, split="train", transform=transforms_iostar)
 
 # STARE
-transforms = Compose([  
-                        Resize(471)
-                        ,Pad((0,37,0,36))
-                        ,*defaulttransforms
-                    ])
+transforms = Compose([Resize(471), Pad((0, 37, 0, 36)), *defaulttransforms])
 
 # bob.db.dataset init
-bobdb_stare = STARE(protocol = 'default')
+bobdb_stare = STARE(protocol="default")
 
 # PyTorch dataset
-torch_stare = BinSegDataset(bobdb_stare, split='train', transform=transforms)
+torch_stare = BinSegDataset(bobdb_stare, split="train", transform=transforms)
 
 
 # HRF
-transforms_hrf = Compose([  
-                        Resize((363))
-                        ,Pad((0,90,0,91))
-                        ,*defaulttransforms
-                    ])
+transforms_hrf = Compose([Resize((363)), Pad((0, 90, 0, 91)), *defaulttransforms])
 
 # bob.db.dataset init
-bobdb_hrf = HRF(protocol = 'default')
+bobdb_hrf = HRF(protocol="default")
 
 # PyTorch dataset
-torch_hrf = BinSegDataset(bobdb_hrf, split='train', transform=transforms_hrf)
-
+torch_hrf = BinSegDataset(bobdb_hrf, split="train", transform=transforms_hrf)
 
 
 # Merge
-dataset = torch.utils.data.ConcatDataset([torch_stare, torch_chase, torch_iostar, torch_hrf])
\ No newline at end of file
+dataset = torch.utils.data.ConcatDataset(
+    [torch_stare, torch_chase, torch_iostar, torch_hrf]
+)
diff --git a/bob/ip/binseg/configs/datasets/amdrivetest.py b/bob/ip/binseg/configs/datasets/amdrivetest.py
index 026ac236..5a6cc4af 100644
--- a/bob/ip/binseg/configs/datasets/amdrivetest.py
+++ b/bob/ip/binseg/configs/datasets/amdrivetest.py
@@ -14,60 +14,46 @@ import torch
 defaulttransforms = [ToTensor()]
 
 
-# CHASE_DB1 
-transforms_chase = Compose([      
-                        Resize(544)
-                        ,Crop(0,12,544,544)
-                        ,*defaulttransforms
-                    ])
+# CHASE_DB1
+transforms_chase = Compose([Resize(544), Crop(0, 12, 544, 544), *defaulttransforms])
 
 # bob.db.dataset init
-bobdb_chase = CHASEDB1(protocol = 'default')
+bobdb_chase = CHASEDB1(protocol="default")
 
 # PyTorch dataset
-torch_chase = BinSegDataset(bobdb_chase, split='test', transform=transforms_chase)
+torch_chase = BinSegDataset(bobdb_chase, split="test", transform=transforms_chase)
 
 
 # IOSTAR VESSEL
-transforms_iostar = Compose([  
-                        Resize(544)
-                        ,*defaulttransforms
-                    ])
+transforms_iostar = Compose([Resize(544), *defaulttransforms])
 
 # bob.db.dataset init
-bobdb_iostar = IOSTAR(protocol='default_vessel')
+bobdb_iostar = IOSTAR(protocol="default_vessel")
 
 # PyTorch dataset
-torch_iostar = BinSegDataset(bobdb_iostar, split='test', transform=transforms_iostar)
+torch_iostar = BinSegDataset(bobdb_iostar, split="test", transform=transforms_iostar)
 
 # STARE
-transforms = Compose([  
-                        Resize(471)
-                        ,Pad((0,37,0,36))
-                        ,*defaulttransforms
-                    ])
+transforms = Compose([Resize(471), Pad((0, 37, 0, 36)), *defaulttransforms])
 
 # bob.db.dataset init
-bobdb_stare = STARE(protocol = 'default')
+bobdb_stare = STARE(protocol="default")
 
 # PyTorch dataset
-torch_stare = BinSegDataset(bobdb_stare, split='test', transform=transforms)
+torch_stare = BinSegDataset(bobdb_stare, split="test", transform=transforms)
 
 
 # HRF
-transforms_hrf = Compose([  
-                        Resize((363))
-                        ,Pad((0,90,0,91))
-                        ,*defaulttransforms
-                    ])
+transforms_hrf = Compose([Resize((363)), Pad((0, 90, 0, 91)), *defaulttransforms])
 
 # bob.db.dataset init
-bobdb_hrf = HRF(protocol = 'default')
+bobdb_hrf = HRF(protocol="default")
 
 # PyTorch dataset
-torch_hrf = BinSegDataset(bobdb_hrf, split='test', transform=transforms_hrf)
-
+torch_hrf = BinSegDataset(bobdb_hrf, split="test", transform=transforms_hrf)
 
 
 # Merge
-dataset = torch.utils.data.ConcatDataset([torch_stare, torch_chase, torch_iostar, torch_hrf])
\ No newline at end of file
+dataset = torch.utils.data.ConcatDataset(
+    [torch_stare, torch_chase, torch_iostar, torch_hrf]
+)
diff --git a/bob/ip/binseg/configs/datasets/chasedb1.py b/bob/ip/binseg/configs/datasets/chasedb1.py
index 7fa0dc09..605fd0a7 100644
--- a/bob/ip/binseg/configs/datasets/chasedb1.py
+++ b/bob/ip/binseg/configs/datasets/chasedb1.py
@@ -7,17 +7,19 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Crop(0,18,960,960)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Crop(0, 18, 960, 960),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = CHASEDB1(protocol = 'default')
+bobdb = CHASEDB1(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/chasedb11024.py b/bob/ip/binseg/configs/datasets/chasedb11024.py
index 028f10fb..27b7ab37 100644
--- a/bob/ip/binseg/configs/datasets/chasedb11024.py
+++ b/bob/ip/binseg/configs/datasets/chasedb11024.py
@@ -7,18 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        RandomRotation()
-                        ,Crop(0,18,960,960)
-                        ,Resize(1024)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        RandomRotation(),
+        Crop(0, 18, 960, 960),
+        Resize(1024),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = CHASEDB1(protocol = 'default')
+bobdb = CHASEDB1(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/chasedb11168.py b/bob/ip/binseg/configs/datasets/chasedb11168.py
index d221ea48..b85726e4 100644
--- a/bob/ip/binseg/configs/datasets/chasedb11168.py
+++ b/bob/ip/binseg/configs/datasets/chasedb11168.py
@@ -7,18 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        RandomRotation()
-                        ,Crop(140,18,680,960)
-                        ,Resize(1168)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        RandomRotation(),
+        Crop(140, 18, 680, 960),
+        Resize(1168),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = CHASEDB1(protocol = 'default')
+bobdb = CHASEDB1(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/chasedb1544.py b/bob/ip/binseg/configs/datasets/chasedb1544.py
index 9632d539..8ea0a9c6 100644
--- a/bob/ip/binseg/configs/datasets/chasedb1544.py
+++ b/bob/ip/binseg/configs/datasets/chasedb1544.py
@@ -7,18 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Resize(544)
-                        ,Crop(0,12,544,544)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Resize(544),
+        Crop(0, 12, 544, 544),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = CHASEDB1(protocol = 'default')
+bobdb = CHASEDB1(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/chasedb1608.py b/bob/ip/binseg/configs/datasets/chasedb1608.py
index 9a475cae..1800574b 100644
--- a/bob/ip/binseg/configs/datasets/chasedb1608.py
+++ b/bob/ip/binseg/configs/datasets/chasedb1608.py
@@ -7,18 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        RandomRotation()
-                        ,CenterCrop((829,960))                    
-                        ,Resize(608)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        RandomRotation(),
+        CenterCrop((829, 960)),
+        Resize(608),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = CHASEDB1(protocol = 'default')
+bobdb = CHASEDB1(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/chasedb1test.py b/bob/ip/binseg/configs/datasets/chasedb1test.py
index 4b267b0f..17be7aa1 100644
--- a/bob/ip/binseg/configs/datasets/chasedb1test.py
+++ b/bob/ip/binseg/configs/datasets/chasedb1test.py
@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Crop(0,18,960,960)
-                        ,ToTensor()
-                    ])
+transforms = Compose([Crop(0, 18, 960, 960), ToTensor()])
 
 # bob.db.dataset init
-bobdb = CHASEDB1(protocol = 'default')
+bobdb = CHASEDB1(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/drionsdb.py b/bob/ip/binseg/configs/datasets/drionsdb.py
index cd33c1d2..0a03dadf 100644
--- a/bob/ip/binseg/configs/datasets/drionsdb.py
+++ b/bob/ip/binseg/configs/datasets/drionsdb.py
@@ -7,17 +7,19 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Pad((4,8,4,8))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Pad((4, 8, 4, 8)),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = DRIONS(protocol = 'default')
+bobdb = DRIONS(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/drionsdbtest.py b/bob/ip/binseg/configs/datasets/drionsdbtest.py
index b65100a6..75bcbb58 100644
--- a/bob/ip/binseg/configs/datasets/drionsdbtest.py
+++ b/bob/ip/binseg/configs/datasets/drionsdbtest.py
@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Pad((4,8,4,8))
-                        ,ToTensor()
-                    ])
+transforms = Compose([Pad((4, 8, 4, 8)), ToTensor()])
 
 # bob.db.dataset init
-bobdb = DRIONS(protocol = 'default')
+bobdb = DRIONS(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/dristhigs1cup.py b/bob/ip/binseg/configs/datasets/dristhigs1cup.py
index f7a69dad..a1da30ad 100644
--- a/bob/ip/binseg/configs/datasets/dristhigs1cup.py
+++ b/bob/ip/binseg/configs/datasets/dristhigs1cup.py
@@ -7,17 +7,19 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        CenterCrop((1760,2048))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        CenterCrop((1760, 2048)),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = DRISHTI(protocol = 'default_cup')
+bobdb = DRISHTI(protocol="default_cup")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/dristhigs1cuptest.py b/bob/ip/binseg/configs/datasets/dristhigs1cuptest.py
index 5c2b634e..e35eabf0 100644
--- a/bob/ip/binseg/configs/datasets/dristhigs1cuptest.py
+++ b/bob/ip/binseg/configs/datasets/dristhigs1cuptest.py
@@ -6,13 +6,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        CenterCrop((1760,2048))
-                        ,ToTensor()
-                    ])
+transforms = Compose([CenterCrop((1760, 2048)), ToTensor()])
 
 # bob.db.dataset init
-bobdb = DRISHTI(protocol = 'default_cup')
+bobdb = DRISHTI(protocol="default_cup")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/dristhigs1od.py b/bob/ip/binseg/configs/datasets/dristhigs1od.py
index 0bd483c1..3421ebe6 100644
--- a/bob/ip/binseg/configs/datasets/dristhigs1od.py
+++ b/bob/ip/binseg/configs/datasets/dristhigs1od.py
@@ -7,17 +7,19 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        CenterCrop((1760,2048))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        CenterCrop((1760, 2048)),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = DRISHTI(protocol = 'default_od')
+bobdb = DRISHTI(protocol="default_od")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/dristhigs1odtest.py b/bob/ip/binseg/configs/datasets/dristhigs1odtest.py
index ab1edd65..1fdc8a28 100644
--- a/bob/ip/binseg/configs/datasets/dristhigs1odtest.py
+++ b/bob/ip/binseg/configs/datasets/dristhigs1odtest.py
@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        CenterCrop((1760,2048))
-                        ,ToTensor()
-                    ])
+transforms = Compose([CenterCrop((1760, 2048)), ToTensor()])
 
 # bob.db.dataset init
-bobdb = DRISHTI(protocol = 'default_od')
+bobdb = DRISHTI(protocol="default_od")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive.py b/bob/ip/binseg/configs/datasets/drive.py
index 5b6fa356..04819dc0 100644
--- a/bob/ip/binseg/configs/datasets/drive.py
+++ b/bob/ip/binseg/configs/datasets/drive.py
@@ -7,17 +7,19 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        CenterCrop((544,544))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        CenterCrop((544, 544)),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = DRIVE(protocol = 'default')
+bobdb = DRIVE(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive1024.py b/bob/ip/binseg/configs/datasets/drive1024.py
index dae199f5..ea99feb0 100644
--- a/bob/ip/binseg/configs/datasets/drive1024.py
+++ b/bob/ip/binseg/configs/datasets/drive1024.py
@@ -7,18 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        RandomRotation()
-                        ,CenterCrop((540,540))
-                        ,Resize(1024)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        RandomRotation(),
+        CenterCrop((540, 540)),
+        Resize(1024),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = DRIVE(protocol = 'default')
+bobdb = DRIVE(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive1024test.py b/bob/ip/binseg/configs/datasets/drive1024test.py
index 9e9cb3e9..c409dae5 100644
--- a/bob/ip/binseg/configs/datasets/drive1024test.py
+++ b/bob/ip/binseg/configs/datasets/drive1024test.py
@@ -7,14 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        CenterCrop((540,540))
-                        ,Resize(1024)
-                        ,ToTensor()
-                    ])
+transforms = Compose([CenterCrop((540, 540)), Resize(1024), ToTensor()])
 
 # bob.db.dataset init
-bobdb = DRIVE(protocol = 'default')
+bobdb = DRIVE(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive1168.py b/bob/ip/binseg/configs/datasets/drive1168.py
index 3f0f0537..f4a51d95 100644
--- a/bob/ip/binseg/configs/datasets/drive1168.py
+++ b/bob/ip/binseg/configs/datasets/drive1168.py
@@ -7,19 +7,21 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        RandomRotation()
-                        ,Crop(75,10,416,544)
-                        ,Pad((21,0,22,0))
-                        ,Resize(1168)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        RandomRotation(),
+        Crop(75, 10, 416, 544),
+        Pad((21, 0, 22, 0)),
+        Resize(1168),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = DRIVE(protocol = 'default')
+bobdb = DRIVE(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive608.py b/bob/ip/binseg/configs/datasets/drive608.py
index 65bc5e65..e251930b 100644
--- a/bob/ip/binseg/configs/datasets/drive608.py
+++ b/bob/ip/binseg/configs/datasets/drive608.py
@@ -7,19 +7,21 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        RandomRotation()
-                        ,CenterCrop((470,544))
-                        ,Pad((10,9,10,8))
-                        ,Resize(608)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        RandomRotation(),
+        CenterCrop((470, 544)),
+        Pad((10, 9, 10, 8)),
+        Resize(608),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = DRIVE(protocol = 'default')
+bobdb = DRIVE(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/drive960.py b/bob/ip/binseg/configs/datasets/drive960.py
index ab3ac5a9..1a29eeee 100644
--- a/bob/ip/binseg/configs/datasets/drive960.py
+++ b/bob/ip/binseg/configs/datasets/drive960.py
@@ -7,18 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        RandomRotation()
-                        ,CenterCrop((544,544))
-                        ,Resize(960)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        RandomRotation(),
+        CenterCrop((544, 544)),
+        Resize(960),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = DRIVE(protocol = 'default')
+bobdb = DRIVE(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/drivechasedb1iostarhrf608.py b/bob/ip/binseg/configs/datasets/drivechasedb1iostarhrf608.py
index 5ab57b8a..a3410d46 100644
--- a/bob/ip/binseg/configs/datasets/drivechasedb1iostarhrf608.py
+++ b/bob/ip/binseg/configs/datasets/drivechasedb1iostarhrf608.py
@@ -7,4 +7,4 @@ import torch
 #### Config ####
 
 # PyTorch dataset
-dataset = torch.utils.data.ConcatDataset([drive,chase,iostar,hrf])
\ No newline at end of file
+dataset = torch.utils.data.ConcatDataset([drive, chase, iostar, hrf])
diff --git a/bob/ip/binseg/configs/datasets/drivechasedb1iostarhrf608sslstare.py b/bob/ip/binseg/configs/datasets/drivechasedb1iostarhrf608sslstare.py
index 928452f4..37c60f8a 100644
--- a/bob/ip/binseg/configs/datasets/drivechasedb1iostarhrf608sslstare.py
+++ b/bob/ip/binseg/configs/datasets/drivechasedb1iostarhrf608sslstare.py
@@ -5,30 +5,38 @@ from bob.ip.binseg.configs.datasets.hrf608 import dataset as hrf
 from bob.db.stare import Database as STARE
 from bob.ip.binseg.data.transforms import *
 import torch
-from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset
+from bob.ip.binseg.data.binsegdataset import (
+    BinSegDataset,
+    SSLBinSegDataset,
+    UnLabeledBinSegDataset,
+)
 
 
 #### Config ####
 
 # PyTorch dataset
-labeled_dataset = torch.utils.data.ConcatDataset([drive,chase,iostar,hrf])
+labeled_dataset = torch.utils.data.ConcatDataset([drive, chase, iostar, hrf])
 
 #### Unlabeled STARE TRAIN ####
-unlabeled_transforms = Compose([  
-                        Pad((2,1,2,2))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+unlabeled_transforms = Compose(
+    [
+        Pad((2, 1, 2, 2)),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-starebobdb = STARE(protocol = 'default')
+starebobdb = STARE(protocol="default")
 
 # PyTorch dataset
-unlabeled_dataset = UnLabeledBinSegDataset(starebobdb, split='train', transform=unlabeled_transforms)
+unlabeled_dataset = UnLabeledBinSegDataset(
+    starebobdb, split="train", transform=unlabeled_transforms
+)
 
 # SSL Dataset
 
-dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset)
\ No newline at end of file
+dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset)
diff --git a/bob/ip/binseg/configs/datasets/drivestarechasedb11168.py b/bob/ip/binseg/configs/datasets/drivestarechasedb11168.py
index 0e36eff7..1a7274ee 100644
--- a/bob/ip/binseg/configs/datasets/drivestarechasedb11168.py
+++ b/bob/ip/binseg/configs/datasets/drivestarechasedb11168.py
@@ -6,4 +6,4 @@ import torch
 #### Config ####
 
 # PyTorch dataset
-dataset = torch.utils.data.ConcatDataset([drive,stare,chase])
\ No newline at end of file
+dataset = torch.utils.data.ConcatDataset([drive, stare, chase])
diff --git a/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py b/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py
index 74628abe..4bff79a5 100644
--- a/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py
+++ b/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py
@@ -7,4 +7,4 @@ import torch
 #### Config ####
 
 # PyTorch dataset
-dataset = torch.utils.data.ConcatDataset([drive,stare,hrf,chase])
+dataset = torch.utils.data.ConcatDataset([drive, stare, hrf, chase])
diff --git a/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024ssliostar.py b/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024ssliostar.py
index dd9016b7..f9a9107f 100644
--- a/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024ssliostar.py
+++ b/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024ssliostar.py
@@ -5,29 +5,31 @@ from bob.ip.binseg.configs.datasets.chasedb11024 import dataset as chasedb
 from bob.db.iostar import Database as IOSTAR
 from bob.ip.binseg.data.transforms import *
 import torch
-from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset
+from bob.ip.binseg.data.binsegdataset import (
+    BinSegDataset,
+    SSLBinSegDataset,
+    UnLabeledBinSegDataset,
+)
 
 
 #### Config ####
 
 # PyTorch dataset
-labeled_dataset = torch.utils.data.ConcatDataset([drive,stare,hrf,chasedb])
+labeled_dataset = torch.utils.data.ConcatDataset([drive, stare, hrf, chasedb])
 
 #### Unlabeled IOSTAR Train ####
-unlabeled_transforms = Compose([  
-                        RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+unlabeled_transforms = Compose(
+    [RandomHFlip(), RandomVFlip(), RandomRotation(), ColorJitter(), ToTensor()]
+)
 
 # bob.db.dataset init
-iostarbobdb = IOSTAR(protocol='default_vessel')
+iostarbobdb = IOSTAR(protocol="default_vessel")
 
 # PyTorch dataset
-unlabeled_dataset = UnLabeledBinSegDataset(iostarbobdb, split='train', transform=unlabeled_transforms)
+unlabeled_dataset = UnLabeledBinSegDataset(
+    iostarbobdb, split="train", transform=unlabeled_transforms
+)
 
 # SSL Dataset
 
-dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset)
\ No newline at end of file
+dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset)
diff --git a/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168.py b/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168.py
index 62e2972d..3bb4f28e 100644
--- a/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168.py
+++ b/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168.py
@@ -7,4 +7,4 @@ import torch
 #### Config ####
 
 # PyTorch dataset
-dataset = torch.utils.data.ConcatDataset([drive,stare,chase,iostar])
\ No newline at end of file
+dataset = torch.utils.data.ConcatDataset([drive, stare, chase, iostar])
diff --git a/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168sslhrf.py b/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168sslhrf.py
index 01705e15..3b3d2f2e 100644
--- a/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168sslhrf.py
+++ b/bob/ip/binseg/configs/datasets/drivestarechasedb1iostar1168sslhrf.py
@@ -5,31 +5,39 @@ from bob.ip.binseg.configs.datasets.iostarvessel1168 import dataset as iostar
 from bob.db.hrf import Database as HRF
 from bob.ip.binseg.data.transforms import *
 import torch
-from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset
+from bob.ip.binseg.data.binsegdataset import (
+    BinSegDataset,
+    SSLBinSegDataset,
+    UnLabeledBinSegDataset,
+)
 
 
 #### Config ####
 
 # PyTorch dataset
-labeled_dataset = torch.utils.data.ConcatDataset([drive,stare,iostar,chasedb])
+labeled_dataset = torch.utils.data.ConcatDataset([drive, stare, iostar, chasedb])
 
 #### Unlabeled HRF TRAIN ####
-unlabeled_transforms = Compose([  
-                        RandomRotation()
-                        ,Crop(0,108,2336,3296)
-                        ,Resize((1168))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+unlabeled_transforms = Compose(
+    [
+        RandomRotation(),
+        Crop(0, 108, 2336, 3296),
+        Resize((1168)),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-hrfbobdb = HRF(protocol='default')
+hrfbobdb = HRF(protocol="default")
 
 # PyTorch dataset
-unlabeled_dataset = UnLabeledBinSegDataset(hrfbobdb, split='train', transform=unlabeled_transforms)
+unlabeled_dataset = UnLabeledBinSegDataset(
+    hrfbobdb, split="train", transform=unlabeled_transforms
+)
 
 # SSL Dataset
 
-dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset)
\ No newline at end of file
+dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset)
diff --git a/bob/ip/binseg/configs/datasets/drivestareiostarhrf960.py b/bob/ip/binseg/configs/datasets/drivestareiostarhrf960.py
index 9343cf22..5ea90c4d 100644
--- a/bob/ip/binseg/configs/datasets/drivestareiostarhrf960.py
+++ b/bob/ip/binseg/configs/datasets/drivestareiostarhrf960.py
@@ -7,4 +7,4 @@ import torch
 #### Config ####
 
 # PyTorch dataset
-dataset = torch.utils.data.ConcatDataset([drive,stare,hrf,iostar])
+dataset = torch.utils.data.ConcatDataset([drive, stare, hrf, iostar])
diff --git a/bob/ip/binseg/configs/datasets/drivestareiostarhrf960sslchase.py b/bob/ip/binseg/configs/datasets/drivestareiostarhrf960sslchase.py
index a7bd4576..46a351d7 100644
--- a/bob/ip/binseg/configs/datasets/drivestareiostarhrf960sslchase.py
+++ b/bob/ip/binseg/configs/datasets/drivestareiostarhrf960sslchase.py
@@ -6,30 +6,38 @@ from bob.db.chasedb1 import Database as CHASE
 from bob.db.hrf import Database as HRF
 from bob.ip.binseg.data.transforms import *
 import torch
-from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset
+from bob.ip.binseg.data.binsegdataset import (
+    BinSegDataset,
+    SSLBinSegDataset,
+    UnLabeledBinSegDataset,
+)
 
 
 #### Config ####
 
 # PyTorch dataset
-labeled_dataset = torch.utils.data.ConcatDataset([drive,stare,hrf,iostar])
+labeled_dataset = torch.utils.data.ConcatDataset([drive, stare, hrf, iostar])
 
 #### Unlabeled CHASE TRAIN ####
-unlabeled_transforms = Compose([  
-                        Crop(0,18,960,960)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+unlabeled_transforms = Compose(
+    [
+        Crop(0, 18, 960, 960),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-chasebobdb = CHASE(protocol = 'default')
+chasebobdb = CHASE(protocol="default")
 
 # PyTorch dataset
-unlabeled_dataset = UnLabeledBinSegDataset(chasebobdb, split='train', transform=unlabeled_transforms)
+unlabeled_dataset = UnLabeledBinSegDataset(
+    chasebobdb, split="train", transform=unlabeled_transforms
+)
 
 # SSL Dataset
 
-dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset)
\ No newline at end of file
+dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset)
diff --git a/bob/ip/binseg/configs/datasets/drivetest.py b/bob/ip/binseg/configs/datasets/drivetest.py
index 230598dc..c6bff8ca 100644
--- a/bob/ip/binseg/configs/datasets/drivetest.py
+++ b/bob/ip/binseg/configs/datasets/drivetest.py
@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        CenterCrop((544,544))
-                        ,ToTensor()
-                    ])
+transforms = Compose([CenterCrop((544, 544)), ToTensor()])
 
 # bob.db.dataset init
-bobdb = DRIVE(protocol = 'default')
+bobdb = DRIVE(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/hrf.py b/bob/ip/binseg/configs/datasets/hrf.py
index cb008f7d..b1330209 100644
--- a/bob/ip/binseg/configs/datasets/hrf.py
+++ b/bob/ip/binseg/configs/datasets/hrf.py
@@ -7,17 +7,19 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Crop(0,108,2336,3296)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Crop(0, 108, 2336, 3296),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = HRF(protocol = 'default')
+bobdb = HRF(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/hrf1024.py b/bob/ip/binseg/configs/datasets/hrf1024.py
index 48168445..e07bd883 100644
--- a/bob/ip/binseg/configs/datasets/hrf1024.py
+++ b/bob/ip/binseg/configs/datasets/hrf1024.py
@@ -7,18 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Pad((0,584,0,584))                    
-                        ,Resize((1024))
-                        ,RandomRotation()
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Pad((0, 584, 0, 584)),
+        Resize((1024)),
+        RandomRotation(),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = HRF(protocol = 'default')
+bobdb = HRF(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/hrf1168.py b/bob/ip/binseg/configs/datasets/hrf1168.py
index 4d0c4d9e..4467c02c 100644
--- a/bob/ip/binseg/configs/datasets/hrf1168.py
+++ b/bob/ip/binseg/configs/datasets/hrf1168.py
@@ -7,18 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Crop(0,108,2336,3296)
-                        ,Resize((1168))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Crop(0, 108, 2336, 3296),
+        Resize((1168)),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = HRF(protocol = 'default')
+bobdb = HRF(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/hrf1168test.py b/bob/ip/binseg/configs/datasets/hrf1168test.py
index 86014b75..86d968f4 100644
--- a/bob/ip/binseg/configs/datasets/hrf1168test.py
+++ b/bob/ip/binseg/configs/datasets/hrf1168test.py
@@ -7,14 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Crop(0,108,2336,3296)
-                        ,Resize((1168))
-                        ,ToTensor()
-                    ])
+transforms = Compose([Crop(0, 108, 2336, 3296), Resize((1168)), ToTensor()])
 
 # bob.db.dataset init
-bobdb = HRF(protocol = 'default')
+bobdb = HRF(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/hrf544.py b/bob/ip/binseg/configs/datasets/hrf544.py
index 0e2cc051..6cd9ccf0 100644
--- a/bob/ip/binseg/configs/datasets/hrf544.py
+++ b/bob/ip/binseg/configs/datasets/hrf544.py
@@ -7,18 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Resize((363))
-                        ,Pad((0,90,0,91))
-                        ,RandomRotation()
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Resize((363)),
+        Pad((0, 90, 0, 91)),
+        RandomRotation(),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = HRF(protocol = 'default')
+bobdb = HRF(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/hrf544test.py b/bob/ip/binseg/configs/datasets/hrf544test.py
index 86da428b..45a3e61d 100644
--- a/bob/ip/binseg/configs/datasets/hrf544test.py
+++ b/bob/ip/binseg/configs/datasets/hrf544test.py
@@ -7,14 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Resize((363))
-                        ,Pad((0,90,0,91))
-                        ,ToTensor()
-                    ])
+transforms = Compose([Resize((363)), Pad((0, 90, 0, 91)), ToTensor()])
 
 # bob.db.dataset init
-bobdb = HRF(protocol = 'default')
+bobdb = HRF(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/hrf608.py b/bob/ip/binseg/configs/datasets/hrf608.py
index b26e772a..7b232ea7 100644
--- a/bob/ip/binseg/configs/datasets/hrf608.py
+++ b/bob/ip/binseg/configs/datasets/hrf608.py
@@ -7,18 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Pad((0,345,0,345))
-                        ,Resize(608)
-                        ,RandomRotation()
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Pad((0, 345, 0, 345)),
+        Resize(608),
+        RandomRotation(),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = HRF(protocol = 'default')
+bobdb = HRF(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/hrf960.py b/bob/ip/binseg/configs/datasets/hrf960.py
index dd43cf00..059a831c 100644
--- a/bob/ip/binseg/configs/datasets/hrf960.py
+++ b/bob/ip/binseg/configs/datasets/hrf960.py
@@ -7,18 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Pad((0,584,0,584))                    
-                        ,Resize((960))
-                        ,RandomRotation()
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Pad((0, 584, 0, 584)),
+        Resize((960)),
+        RandomRotation(),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = HRF(protocol = 'default')
+bobdb = HRF(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/hrftest.py b/bob/ip/binseg/configs/datasets/hrftest.py
index 45f95272..d7c32c2a 100644
--- a/bob/ip/binseg/configs/datasets/hrftest.py
+++ b/bob/ip/binseg/configs/datasets/hrftest.py
@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Crop(0,108,2336,3296)
-                        ,ToTensor()
-                    ])
+transforms = Compose([Crop(0, 108, 2336, 3296), ToTensor()])
 
 # bob.db.dataset init
-bobdb = HRF(protocol = 'default')
+bobdb = HRF(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/imagefolder.py b/bob/ip/binseg/configs/datasets/imagefolder.py
index de90eeb0..d73de0c4 100644
--- a/bob/ip/binseg/configs/datasets/imagefolder.py
+++ b/bob/ip/binseg/configs/datasets/imagefolder.py
@@ -7,15 +7,17 @@ from bob.ip.binseg.data.imagefolder import ImageFolder
 #### Config ####
 
 # add your transforms below
-transforms = Compose([
-                        CenterCrop((544,544))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        CenterCrop((544, 544)),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # PyTorch dataset
-path = '/path/to/dataset'
-dataset = ImageFolder(path,transform=transforms)
+path = "/path/to/dataset"
+dataset = ImageFolder(path, transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/imagefolderinference.py b/bob/ip/binseg/configs/datasets/imagefolderinference.py
index 634566b9..869f7e51 100644
--- a/bob/ip/binseg/configs/datasets/imagefolderinference.py
+++ b/bob/ip/binseg/configs/datasets/imagefolderinference.py
@@ -7,12 +7,8 @@ from bob.ip.binseg.data.imagefolderinference import ImageFolderInference
 #### Config ####
 
 # add your transforms below
-transforms = Compose([
-                        ToRGB(),
-                        CenterCrop((544,544))
-                        ,ToTensor()
-                    ])
+transforms = Compose([ToRGB(), CenterCrop((544, 544)), ToTensor()])
 
 # PyTorch dataset
-path = '/path/to/folder/containing/images'
-dataset = ImageFolderInference(path,transform=transforms)
+path = "/path/to/folder/containing/images"
+dataset = ImageFolderInference(path, transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/imagefoldertest.py b/bob/ip/binseg/configs/datasets/imagefoldertest.py
index 15d9b038..474b0384 100644
--- a/bob/ip/binseg/configs/datasets/imagefoldertest.py
+++ b/bob/ip/binseg/configs/datasets/imagefoldertest.py
@@ -7,11 +7,8 @@ from bob.ip.binseg.data.imagefolder import ImageFolder
 #### Config ####
 
 # add your transforms below
-transforms = Compose([  
-                        CenterCrop((544,544))
-                        ,ToTensor()
-                    ])
+transforms = Compose([CenterCrop((544, 544)), ToTensor()])
 
 # PyTorch dataset
-path = '/path/to/testdataset'
-dataset = ImageFolder(path,transform=transforms)
+path = "/path/to/testdataset"
+dataset = ImageFolder(path, transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/iostarod.py b/bob/ip/binseg/configs/datasets/iostarod.py
index 334df2a4..e043f416 100644
--- a/bob/ip/binseg/configs/datasets/iostarod.py
+++ b/bob/ip/binseg/configs/datasets/iostarod.py
@@ -7,16 +7,12 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [RandomHFlip(), RandomVFlip(), RandomRotation(), ColorJitter(), ToTensor()]
+)
 
 # bob.db.dataset init
-bobdb = IOSTAR(protocol='default_od')
+bobdb = IOSTAR(protocol="default_od")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/iostarodtest.py b/bob/ip/binseg/configs/datasets/iostarodtest.py
index ba064507..a4e9b4c8 100644
--- a/bob/ip/binseg/configs/datasets/iostarodtest.py
+++ b/bob/ip/binseg/configs/datasets/iostarodtest.py
@@ -7,12 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        ToTensor()
-                    ])
+transforms = Compose([ToTensor()])
 
 # bob.db.dataset init
-bobdb = IOSTAR(protocol='default_od')
+bobdb = IOSTAR(protocol="default_od")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/iostarvessel.py b/bob/ip/binseg/configs/datasets/iostarvessel.py
index ded01bb4..5fa8ebb6 100644
--- a/bob/ip/binseg/configs/datasets/iostarvessel.py
+++ b/bob/ip/binseg/configs/datasets/iostarvessel.py
@@ -7,16 +7,12 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [RandomHFlip(), RandomVFlip(), RandomRotation(), ColorJitter(), ToTensor()]
+)
 
 # bob.db.dataset init
-bobdb = IOSTAR(protocol='default_vessel')
+bobdb = IOSTAR(protocol="default_vessel")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/iostarvessel1168.py b/bob/ip/binseg/configs/datasets/iostarvessel1168.py
index 5da5ed1e..c58bdf66 100644
--- a/bob/ip/binseg/configs/datasets/iostarvessel1168.py
+++ b/bob/ip/binseg/configs/datasets/iostarvessel1168.py
@@ -7,19 +7,21 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        RandomRotation()
-                        ,Crop(144,0,768,1024)
-                        ,Pad((30,0,30,0))
-                        ,Resize(1168)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        RandomRotation(),
+        Crop(144, 0, 768, 1024),
+        Pad((30, 0, 30, 0)),
+        Resize(1168),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = IOSTAR(protocol='default_vessel')
+bobdb = IOSTAR(protocol="default_vessel")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/iostarvessel544.py b/bob/ip/binseg/configs/datasets/iostarvessel544.py
index aa03abe2..31c51395 100644
--- a/bob/ip/binseg/configs/datasets/iostarvessel544.py
+++ b/bob/ip/binseg/configs/datasets/iostarvessel544.py
@@ -7,17 +7,19 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Resize(544)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Resize(544),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = IOSTAR(protocol='default_vessel')
+bobdb = IOSTAR(protocol="default_vessel")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/iostarvessel544test.py b/bob/ip/binseg/configs/datasets/iostarvessel544test.py
index e3ccd854..321579bf 100644
--- a/bob/ip/binseg/configs/datasets/iostarvessel544test.py
+++ b/bob/ip/binseg/configs/datasets/iostarvessel544test.py
@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Resize(544)
-                        ,ToTensor()
-                    ])
+transforms = Compose([Resize(544), ToTensor()])
 
 # bob.db.dataset init
-bobdb = IOSTAR(protocol='default_vessel')
+bobdb = IOSTAR(protocol="default_vessel")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/iostarvessel608.py b/bob/ip/binseg/configs/datasets/iostarvessel608.py
index 7fce4507..04280727 100644
--- a/bob/ip/binseg/configs/datasets/iostarvessel608.py
+++ b/bob/ip/binseg/configs/datasets/iostarvessel608.py
@@ -7,18 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Pad((81,0,81,0))
-                        ,Resize(608)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Pad((81, 0, 81, 0)),
+        Resize(608),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = IOSTAR(protocol='default_vessel')
+bobdb = IOSTAR(protocol="default_vessel")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/iostarvessel960.py b/bob/ip/binseg/configs/datasets/iostarvessel960.py
index 32feec85..600a8cff 100644
--- a/bob/ip/binseg/configs/datasets/iostarvessel960.py
+++ b/bob/ip/binseg/configs/datasets/iostarvessel960.py
@@ -7,17 +7,19 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Resize(960)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Resize(960),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = IOSTAR(protocol='default_vessel')
+bobdb = IOSTAR(protocol="default_vessel")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/iostarvesseltest.py b/bob/ip/binseg/configs/datasets/iostarvesseltest.py
index d8fe1371..18ec9f2e 100644
--- a/bob/ip/binseg/configs/datasets/iostarvesseltest.py
+++ b/bob/ip/binseg/configs/datasets/iostarvesseltest.py
@@ -7,12 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        ToTensor()
-                    ])
+transforms = Compose([ToTensor()])
 
 # bob.db.dataset init
-bobdb = IOSTAR(protocol='default_vessel')
+bobdb = IOSTAR(protocol="default_vessel")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/refugecup.py b/bob/ip/binseg/configs/datasets/refugecup.py
index 9efac529..1100f508 100644
--- a/bob/ip/binseg/configs/datasets/refugecup.py
+++ b/bob/ip/binseg/configs/datasets/refugecup.py
@@ -7,18 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Resize((1539))
-                        ,Pad((21,46,22,47))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Resize((1539)),
+        Pad((21, 46, 22, 47)),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = REFUGE(protocol = 'default_cup')
+bobdb = REFUGE(protocol="default_cup")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/refugecuptest.py b/bob/ip/binseg/configs/datasets/refugecuptest.py
index 8ff916e3..5e600307 100644
--- a/bob/ip/binseg/configs/datasets/refugecuptest.py
+++ b/bob/ip/binseg/configs/datasets/refugecuptest.py
@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        CenterCrop(1632)
-                        ,ToTensor()
-                    ])
+transforms = Compose([CenterCrop(1632), ToTensor()])
 
 # bob.db.dataset init
-bobdb = REFUGE(protocol = 'default_cup')
+bobdb = REFUGE(protocol="default_cup")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/refugeod.py b/bob/ip/binseg/configs/datasets/refugeod.py
index 5faaf05a..4435640e 100644
--- a/bob/ip/binseg/configs/datasets/refugeod.py
+++ b/bob/ip/binseg/configs/datasets/refugeod.py
@@ -7,18 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Resize((1539))
-                        ,Pad((21,46,22,47))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Resize((1539)),
+        Pad((21, 46, 22, 47)),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = REFUGE(protocol = 'default_od')
+bobdb = REFUGE(protocol="default_od")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/refugeodtest.py b/bob/ip/binseg/configs/datasets/refugeodtest.py
index 30085a2f..b77d3e28 100644
--- a/bob/ip/binseg/configs/datasets/refugeodtest.py
+++ b/bob/ip/binseg/configs/datasets/refugeodtest.py
@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        CenterCrop(1632)
-                        ,ToTensor()
-                    ])
+transforms = Compose([CenterCrop(1632), ToTensor()])
 
 # bob.db.dataset init
-bobdb = REFUGE(protocol = 'default_od')
+bobdb = REFUGE(protocol="default_od")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/rimoner3cup.py b/bob/ip/binseg/configs/datasets/rimoner3cup.py
index 47b62ba0..0fad0285 100644
--- a/bob/ip/binseg/configs/datasets/rimoner3cup.py
+++ b/bob/ip/binseg/configs/datasets/rimoner3cup.py
@@ -7,17 +7,19 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Pad((8,8,8,8))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Pad((8, 8, 8, 8)),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = RIMONER3(protocol = 'default_cup')
+bobdb = RIMONER3(protocol="default_cup")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/rimoner3cuptest.py b/bob/ip/binseg/configs/datasets/rimoner3cuptest.py
index 9f227be8..86465331 100644
--- a/bob/ip/binseg/configs/datasets/rimoner3cuptest.py
+++ b/bob/ip/binseg/configs/datasets/rimoner3cuptest.py
@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Pad((8,8,8,8))
-                        ,ToTensor()
-                    ])
+transforms = Compose([Pad((8, 8, 8, 8)), ToTensor()])
 
 # bob.db.dataset init
-bobdb = RIMONER3(protocol = 'default_cup')
+bobdb = RIMONER3(protocol="default_cup")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/rimoner3od.py b/bob/ip/binseg/configs/datasets/rimoner3od.py
index 4905bec3..a465342a 100644
--- a/bob/ip/binseg/configs/datasets/rimoner3od.py
+++ b/bob/ip/binseg/configs/datasets/rimoner3od.py
@@ -7,17 +7,19 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Pad((8,8,8,8))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Pad((8, 8, 8, 8)),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = RIMONER3(protocol = 'default_od')
+bobdb = RIMONER3(protocol="default_od")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/rimoner3odtest.py b/bob/ip/binseg/configs/datasets/rimoner3odtest.py
index 390f20d7..6e4dd1a6 100644
--- a/bob/ip/binseg/configs/datasets/rimoner3odtest.py
+++ b/bob/ip/binseg/configs/datasets/rimoner3odtest.py
@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Pad((8,8,8,8))
-                        ,ToTensor()
-                    ])
+transforms = Compose([Pad((8, 8, 8, 8)), ToTensor()])
 
 # bob.db.dataset init
-bobdb = RIMONER3(protocol = 'default_od')
+bobdb = RIMONER3(protocol="default_od")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/stare.py b/bob/ip/binseg/configs/datasets/stare.py
index f2c784a9..0f93cc78 100644
--- a/bob/ip/binseg/configs/datasets/stare.py
+++ b/bob/ip/binseg/configs/datasets/stare.py
@@ -7,17 +7,19 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Pad((2,1,2,2))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        Pad((2, 1, 2, 2)),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = STARE(protocol = 'default')
+bobdb = STARE(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/stare1024.py b/bob/ip/binseg/configs/datasets/stare1024.py
index 8f6df507..a8931ff2 100644
--- a/bob/ip/binseg/configs/datasets/stare1024.py
+++ b/bob/ip/binseg/configs/datasets/stare1024.py
@@ -7,19 +7,21 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        RandomRotation()
-                        ,Pad((0,32,0,32))
-                        ,Resize(1024)
-                        ,CenterCrop(1024)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        RandomRotation(),
+        Pad((0, 32, 0, 32)),
+        Resize(1024),
+        CenterCrop(1024),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = STARE(protocol = 'default')
+bobdb = STARE(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/stare1168.py b/bob/ip/binseg/configs/datasets/stare1168.py
index 77e934bf..516a9267 100644
--- a/bob/ip/binseg/configs/datasets/stare1168.py
+++ b/bob/ip/binseg/configs/datasets/stare1168.py
@@ -7,19 +7,21 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        RandomRotation()
-                        ,Crop(50,0,500,705)
-                        ,Resize(1168)
-                        ,Pad((1,0,1,0))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        RandomRotation(),
+        Crop(50, 0, 500, 705),
+        Resize(1168),
+        Pad((1, 0, 1, 0)),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = STARE(protocol = 'default')
+bobdb = STARE(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/stare544.py b/bob/ip/binseg/configs/datasets/stare544.py
index f03fcefb..b972d1f3 100644
--- a/bob/ip/binseg/configs/datasets/stare544.py
+++ b/bob/ip/binseg/configs/datasets/stare544.py
@@ -7,17 +7,20 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  RandomRotation()
-                        ,Resize(471)
-                        ,Pad((0,37,0,36))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        RandomRotation(),
+        Resize(471),
+        Pad((0, 37, 0, 36)),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = STARE(protocol = 'default')
+bobdb = STARE(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/stare960.py b/bob/ip/binseg/configs/datasets/stare960.py
index 0d1ed788..211a8448 100644
--- a/bob/ip/binseg/configs/datasets/stare960.py
+++ b/bob/ip/binseg/configs/datasets/stare960.py
@@ -7,19 +7,21 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        RandomRotation()
-                        ,Pad((0,32,0,32))
-                        ,Resize(960)
-                        ,CenterCrop(960)
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+transforms = Compose(
+    [
+        RandomRotation(),
+        Pad((0, 32, 0, 32)),
+        Resize(960),
+        CenterCrop(960),
+        RandomHFlip(),
+        RandomVFlip(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-bobdb = STARE(protocol = 'default')
+bobdb = STARE(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="train", transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544.py b/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544.py
index 72349b3b..200d7842 100644
--- a/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544.py
+++ b/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544.py
@@ -7,4 +7,4 @@ import torch
 #### Config ####
 
 # PyTorch dataset
-dataset = torch.utils.data.ConcatDataset([stare,chase,hrf,iostar])
\ No newline at end of file
+dataset = torch.utils.data.ConcatDataset([stare, chase, hrf, iostar])
diff --git a/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544ssldrive.py b/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544ssldrive.py
index 3a5e3008..f126871f 100644
--- a/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544ssldrive.py
+++ b/bob/ip/binseg/configs/datasets/starechasedb1iostarhrf544ssldrive.py
@@ -5,30 +5,38 @@ from bob.ip.binseg.configs.datasets.hrf544 import dataset as hrf
 from bob.db.drive import Database as DRIVE
 from bob.ip.binseg.data.transforms import *
 import torch
-from bob.ip.binseg.data.binsegdataset import BinSegDataset, SSLBinSegDataset, UnLabeledBinSegDataset
+from bob.ip.binseg.data.binsegdataset import (
+    BinSegDataset,
+    SSLBinSegDataset,
+    UnLabeledBinSegDataset,
+)
 
 
 #### Config ####
 
 # PyTorch dataset
-labeled_dataset = torch.utils.data.ConcatDataset([stare,chase,iostar,hrf])
+labeled_dataset = torch.utils.data.ConcatDataset([stare, chase, iostar, hrf])
 
 #### Unlabeled STARE TRAIN ####
-unlabeled_transforms = Compose([  
-                        CenterCrop((544,544))
-                        ,RandomHFlip()
-                        ,RandomVFlip()
-                        ,RandomRotation()
-                        ,ColorJitter()
-                        ,ToTensor()
-                    ])
+unlabeled_transforms = Compose(
+    [
+        CenterCrop((544, 544)),
+        RandomHFlip(),
+        RandomVFlip(),
+        RandomRotation(),
+        ColorJitter(),
+        ToTensor(),
+    ]
+)
 
 # bob.db.dataset init
-drivebobdb = DRIVE(protocol = 'default')
+drivebobdb = DRIVE(protocol="default")
 
 # PyTorch dataset
-unlabeled_dataset = UnLabeledBinSegDataset(drivebobdb, split='train', transform=unlabeled_transforms)
+unlabeled_dataset = UnLabeledBinSegDataset(
+    drivebobdb, split="train", transform=unlabeled_transforms
+)
 
 # SSL Dataset
 
-dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset)
\ No newline at end of file
+dataset = SSLBinSegDataset(labeled_dataset, unlabeled_dataset)
diff --git a/bob/ip/binseg/configs/datasets/staretest.py b/bob/ip/binseg/configs/datasets/staretest.py
index aab80b9b..ac03e2a7 100644
--- a/bob/ip/binseg/configs/datasets/staretest.py
+++ b/bob/ip/binseg/configs/datasets/staretest.py
@@ -7,13 +7,10 @@ from bob.ip.binseg.data.binsegdataset import BinSegDataset
 
 #### Config ####
 
-transforms = Compose([  
-                        Pad((2,1,2,2))
-                        ,ToTensor()
-                    ])
+transforms = Compose([Pad((2, 1, 2, 2)), ToTensor()])
 
 # bob.db.dataset init
-bobdb = STARE(protocol = 'default')
+bobdb = STARE(protocol="default")
 
 # PyTorch dataset
-dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
+dataset = BinSegDataset(bobdb, split="test", transform=transforms)
diff --git a/bob/ip/binseg/configs/models/driubn.py b/bob/ip/binseg/configs/models/driubn.py
index 0b95501d..aedf52ed 100644
--- a/bob/ip/binseg/configs/models/driubn.py
+++ b/bob/ip/binseg/configs/models/driubn.py
@@ -26,13 +26,23 @@ scheduler_gamma = 0.1
 model = build_driu()
 
 # pretrained backbone
-pretrained_backbone = modelurls['vgg16_bn']
+pretrained_backbone = modelurls["vgg16_bn"]
 
 # optimizer
-optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
-                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
+optimizer = AdaBound(
+    model.parameters(),
+    lr=lr,
+    betas=betas,
+    final_lr=final_lr,
+    gamma=gamma,
+    eps=eps,
+    weight_decay=weight_decay,
+    amsbound=amsbound,
+)
 # criterion
 criterion = SoftJaccardBCELogitsLoss(alpha=0.7)
 
 # scheduler
-scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
+scheduler = MultiStepLR(
+    optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
+)
diff --git a/bob/ip/binseg/configs/models/driubnssl.py b/bob/ip/binseg/configs/models/driubnssl.py
index 52b3a2b3..429a0500 100644
--- a/bob/ip/binseg/configs/models/driubnssl.py
+++ b/bob/ip/binseg/configs/models/driubnssl.py
@@ -26,14 +26,24 @@ scheduler_gamma = 0.1
 model = build_driu()
 
 # pretrained backbone
-pretrained_backbone = modelurls['vgg16_bn']
+pretrained_backbone = modelurls["vgg16_bn"]
 
 # optimizer
-optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
-                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
+optimizer = AdaBound(
+    model.parameters(),
+    lr=lr,
+    betas=betas,
+    final_lr=final_lr,
+    gamma=gamma,
+    eps=eps,
+    weight_decay=weight_decay,
+    amsbound=amsbound,
+)
 
 # criterion
 criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7)
 
 # scheduler
-scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
+scheduler = MultiStepLR(
+    optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
+)
diff --git a/bob/ip/binseg/configs/models/driuod.py b/bob/ip/binseg/configs/models/driuod.py
index 7ad9bb83..b53fc751 100644
--- a/bob/ip/binseg/configs/models/driuod.py
+++ b/bob/ip/binseg/configs/models/driuod.py
@@ -26,13 +26,23 @@ scheduler_gamma = 0.1
 model = build_driuod()
 
 # pretrained backbone
-pretrained_backbone = modelurls['vgg16']
+pretrained_backbone = modelurls["vgg16"]
 
 # optimizer
-optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
-                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
+optimizer = AdaBound(
+    model.parameters(),
+    lr=lr,
+    betas=betas,
+    final_lr=final_lr,
+    gamma=gamma,
+    eps=eps,
+    weight_decay=weight_decay,
+    amsbound=amsbound,
+)
 # criterion
 criterion = SoftJaccardBCELogitsLoss(alpha=0.7)
 
 # scheduler
-scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
+scheduler = MultiStepLR(
+    optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
+)
diff --git a/bob/ip/binseg/configs/models/driussl.py b/bob/ip/binseg/configs/models/driussl.py
index 39afd4a0..a5d49950 100644
--- a/bob/ip/binseg/configs/models/driussl.py
+++ b/bob/ip/binseg/configs/models/driussl.py
@@ -26,14 +26,24 @@ scheduler_gamma = 0.1
 model = build_driu()
 
 # pretrained backbone
-pretrained_backbone = modelurls['vgg16']
+pretrained_backbone = modelurls["vgg16"]
 
 # optimizer
-optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
-                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
+optimizer = AdaBound(
+    model.parameters(),
+    lr=lr,
+    betas=betas,
+    final_lr=final_lr,
+    gamma=gamma,
+    eps=eps,
+    weight_decay=weight_decay,
+    amsbound=amsbound,
+)
 
 # criterion
 criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7)
 
 # scheduler
-scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
+scheduler = MultiStepLR(
+    optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
+)
diff --git a/bob/ip/binseg/configs/models/hed.py b/bob/ip/binseg/configs/models/hed.py
index eeb0e599..ee905048 100644
--- a/bob/ip/binseg/configs/models/hed.py
+++ b/bob/ip/binseg/configs/models/hed.py
@@ -27,13 +27,23 @@ scheduler_gamma = 0.1
 model = build_hed()
 
 # pretrained backbone
-pretrained_backbone = modelurls['vgg16']
+pretrained_backbone = modelurls["vgg16"]
 
 # optimizer
-optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
-                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
+optimizer = AdaBound(
+    model.parameters(),
+    lr=lr,
+    betas=betas,
+    final_lr=final_lr,
+    gamma=gamma,
+    eps=eps,
+    weight_decay=weight_decay,
+    amsbound=amsbound,
+)
 # criterion
 criterion = HEDSoftJaccardBCELogitsLoss(alpha=0.7)
 
 # scheduler
-scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
+scheduler = MultiStepLR(
+    optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
+)
diff --git a/bob/ip/binseg/configs/models/m2unet.py b/bob/ip/binseg/configs/models/m2unet.py
index b15a2779..4dd0da54 100644
--- a/bob/ip/binseg/configs/models/m2unet.py
+++ b/bob/ip/binseg/configs/models/m2unet.py
@@ -26,14 +26,24 @@ scheduler_gamma = 0.1
 model = build_m2unet()
 
 # pretrained backbone
-pretrained_backbone = modelurls['mobilenetv2']
+pretrained_backbone = modelurls["mobilenetv2"]
 
 # optimizer
-optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
-                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
-    
+optimizer = AdaBound(
+    model.parameters(),
+    lr=lr,
+    betas=betas,
+    final_lr=final_lr,
+    gamma=gamma,
+    eps=eps,
+    weight_decay=weight_decay,
+    amsbound=amsbound,
+)
+
 # criterion
 criterion = SoftJaccardBCELogitsLoss(alpha=0.7)
 
 # scheduler
-scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
+scheduler = MultiStepLR(
+    optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
+)
diff --git a/bob/ip/binseg/configs/models/m2unetssl.py b/bob/ip/binseg/configs/models/m2unetssl.py
index 3497cea2..4eab3c6c 100644
--- a/bob/ip/binseg/configs/models/m2unetssl.py
+++ b/bob/ip/binseg/configs/models/m2unetssl.py
@@ -26,14 +26,24 @@ scheduler_gamma = 0.1
 model = build_m2unet()
 
 # pretrained backbone
-pretrained_backbone = modelurls['mobilenetv2']
+pretrained_backbone = modelurls["mobilenetv2"]
 
 # optimizer
-optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
-                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
-    
+optimizer = AdaBound(
+    model.parameters(),
+    lr=lr,
+    betas=betas,
+    final_lr=final_lr,
+    gamma=gamma,
+    eps=eps,
+    weight_decay=weight_decay,
+    amsbound=amsbound,
+)
+
 # criterion
 criterion = MixJacLoss(lambda_u=0.05, jacalpha=0.7)
 
 # scheduler
-scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
+scheduler = MultiStepLR(
+    optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
+)
diff --git a/bob/ip/binseg/configs/models/resunet.py b/bob/ip/binseg/configs/models/resunet.py
index a1db473c..b2129d5e 100644
--- a/bob/ip/binseg/configs/models/resunet.py
+++ b/bob/ip/binseg/configs/models/resunet.py
@@ -26,14 +26,24 @@ scheduler_gamma = 0.1
 model = build_res50unet()
 
 # pretrained backbone
-pretrained_backbone = modelurls['resnet50']
+pretrained_backbone = modelurls["resnet50"]
 
 # optimizer
-optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
-                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
-    
+optimizer = AdaBound(
+    model.parameters(),
+    lr=lr,
+    betas=betas,
+    final_lr=final_lr,
+    gamma=gamma,
+    eps=eps,
+    weight_decay=weight_decay,
+    amsbound=amsbound,
+)
+
 # criterion
 criterion = SoftJaccardBCELogitsLoss(alpha=0.7)
 
 # scheduler
-scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
+scheduler = MultiStepLR(
+    optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
+)
diff --git a/bob/ip/binseg/configs/models/unet.py b/bob/ip/binseg/configs/models/unet.py
index 8182c7fa..c129c5d9 100644
--- a/bob/ip/binseg/configs/models/unet.py
+++ b/bob/ip/binseg/configs/models/unet.py
@@ -26,14 +26,24 @@ scheduler_gamma = 0.1
 model = build_unet()
 
 # pretrained backbone
-pretrained_backbone = modelurls['vgg16']
+pretrained_backbone = modelurls["vgg16"]
 
 # optimizer
-optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
-                 eps=eps, weight_decay=weight_decay, amsbound=amsbound) 
-    
+optimizer = AdaBound(
+    model.parameters(),
+    lr=lr,
+    betas=betas,
+    final_lr=final_lr,
+    gamma=gamma,
+    eps=eps,
+    weight_decay=weight_decay,
+    amsbound=amsbound,
+)
+
 # criterion
 criterion = SoftJaccardBCELogitsLoss(alpha=0.7)
 
 # scheduler
-scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
+scheduler = MultiStepLR(
+    optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma
+)
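A note on the scheduler pattern shared by all of these model configs: MultiStepLR multiplies each parameter group's learning rate by gamma whenever a milestone epoch is reached. The self-contained sketch below is not part of this patch; the toy parameter and plain SGD optimizer stand in for AdaBound purely to keep the example short.

import torch
from torch.optim.lr_scheduler import MultiStepLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.001)
scheduler = MultiStepLR(optimizer, milestones=[900], gamma=0.1)

for epoch in range(1000):
    optimizer.step()   # the real training step would go here
    scheduler.step()   # lr stays at 1e-3 until epoch 900, then drops to 1e-4
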
diff --git a/bob/ip/binseg/data/binsegdataset.py b/bob/ip/binseg/data/binsegdataset.py
index 2917203c..e2977d37 100644
--- a/bob/ip/binseg/data/binsegdataset.py
+++ b/bob/ip/binseg/data/binsegdataset.py
@@ -3,6 +3,7 @@
 from torch.utils.data import Dataset
 import random
 
+
 class BinSegDataset(Dataset):
     """PyTorch dataset wrapper around bob.db binary segmentation datasets. 
     A transform object can be passed that will be applied to the image, ground truth and mask (if present). 
@@ -19,18 +20,19 @@ class BinSegDataset(Dataset):
     mask : bool
         whether dataset contains masks or not
     """
-    def __init__(self, bobdb, split = 'train', transform = None,index_to = None):
+
+    def __init__(self, bobdb, split="train", transform=None, index_to=None):
         if index_to:
             self.database = bobdb.samples(split)[:index_to]
         else:
             self.database = bobdb.samples(split)
         self.transform = transform
         self.split = split
-    
+
     @property
     def mask(self):
         # check if first sample contains a mask
-        return hasattr(self.database[0], 'mask')
+        return hasattr(self.database[0], "mask")
 
     def __len__(self):
         """
@@ -40,8 +42,8 @@ class BinSegDataset(Dataset):
             size of the dataset
         """
         return len(self.database)
-    
-    def __getitem__(self,index):
+
+    def __getitem__(self, index):
         """
         Parameters
         ----------
@@ -56,12 +58,12 @@ class BinSegDataset(Dataset):
         gt = self.database[index].gt.pil_image()
         img_name = self.database[index].img.basename
         sample = [img, gt]
-        
-        if self.transform :
+
+        if self.transform:
             sample = self.transform(*sample)
-        
-        sample.insert(0,img_name)
-        
+
+        sample.insert(0, img_name)
+
         return sample
 
 
@@ -77,10 +79,10 @@ class SSLBinSegDataset(Dataset):
     unlabeled_dataset : :py:class:`torch.utils.data.Dataset`
         UnLabeledBinSegDataset with unlabeled data
     """
+
     def __init__(self, labeled_dataset, unlabeled_dataset):
         self.labeled_dataset = labeled_dataset
         self.unlabeled_dataset = unlabeled_dataset
-    
 
     def __len__(self):
         """
@@ -90,8 +92,8 @@ class SSLBinSegDataset(Dataset):
             size of the dataset
         """
         return len(self.labeled_dataset)
-    
-    def __getitem__(self,index):
+
+    def __getitem__(self, index):
         """
         Parameters
         ----------
@@ -123,13 +125,14 @@ class UnLabeledBinSegDataset(Dataset):
     transform : :py:mod:`bob.ip.binseg.data.transforms`, optional
         A transform or composition of transforms. Defaults to ``None``.
     """
-    def __init__(self, db, split = 'train', transform = None,index_from= None):
+
+    def __init__(self, db, split="train", transform=None, index_from=None):
         if index_from:
             self.database = db.samples(split)[index_from:]
         else:
             self.database = db.samples(split)
         self.transform = transform
-        self.split = split   
+        self.split = split
 
     def __len__(self):
         """
@@ -139,8 +142,8 @@ class UnLabeledBinSegDataset(Dataset):
             size of the dataset
         """
         return len(self.database)
-    
-    def __getitem__(self,index):
+
+    def __getitem__(self, index):
         """
         Parameters
         ----------
@@ -155,9 +158,9 @@ class UnLabeledBinSegDataset(Dataset):
         img = self.database[index].img.pil_image()
         img_name = self.database[index].img.basename
         sample = [img]
-        if self.transform :
+        if self.transform:
             sample = self.transform(img)
-        
-        sample.insert(0,img_name)
-        
-        return sample
\ No newline at end of file
+
+        sample.insert(0, img_name)
+
+        return sample
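For readers without the bob.db packages installed, the wrapper above can be exercised with a minimal stand-in database. Everything in the sketch below (the fake classes, file names and image sizes) is made up for illustration and only mirrors the attributes BinSegDataset actually touches: samples(), .img.pil_image(), .gt.pil_image() and .img.basename.

from PIL import Image

from bob.ip.binseg.data.binsegdataset import BinSegDataset


class _FakeFile:
    def __init__(self, name, image):
        self.basename = name
        self._image = image

    def pil_image(self):
        return self._image


class _FakeSample:
    def __init__(self, name):
        self.img = _FakeFile(name, Image.new("RGB", (64, 64)))
        self.gt = _FakeFile(name, Image.new("1", (64, 64)))


class _FakeDB:
    def samples(self, split):
        return [_FakeSample("img-%02d.png" % i) for i in range(4)]


dataset = BinSegDataset(_FakeDB(), split="train", transform=None)
name, img, gt = dataset[0]  # ("img-00.png", RGB PIL image, binary PIL image)
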
diff --git a/bob/ip/binseg/data/imagefolder.py b/bob/ip/binseg/data/imagefolder.py
index 7ec9dd9d..9794f5c1 100644
--- a/bob/ip/binseg/data/imagefolder.py
+++ b/bob/ip/binseg/data/imagefolder.py
@@ -8,16 +8,18 @@ import torch
 import torchvision.transforms.functional as VF
 import bob.io.base
 
+
 def get_file_lists(data_path):
     data_path = Path(data_path)
-    
-    image_path = data_path.joinpath('images')
-    image_file_names = np.array(sorted(list(image_path.glob('*'))))
 
-    gt_path = data_path.joinpath('gt')
-    gt_file_names = np.array(sorted(list(gt_path.glob('*'))))
+    image_path = data_path.joinpath("images")
+    image_file_names = np.array(sorted(list(image_path.glob("*"))))
+
+    gt_path = data_path.joinpath("gt")
+    gt_file_names = np.array(sorted(list(gt_path.glob("*"))))
     return image_file_names, gt_file_names
 
+
 class ImageFolder(Dataset):
     """
     Generic ImageFolder dataset that contains two folders:
@@ -32,7 +34,8 @@ class ImageFolder(Dataset):
         full path to root of dataset
     
     """
-    def __init__(self, path, transform = None):
+
+    def __init__(self, path, transform=None):
         self.transform = transform
         self.img_file_list, self.gt_file_list = get_file_lists(path)
 
@@ -44,8 +47,8 @@ class ImageFolder(Dataset):
             size of the dataset
         """
         return len(self.img_file_list)
-    
-    def __getitem__(self,index):
+
+    def __getitem__(self, index):
         """
         Parameters
         ----------
@@ -58,22 +61,22 @@ class ImageFolder(Dataset):
         """
         img_path = self.img_file_list[index]
         img_name = img_path.name
-        img = Image.open(img_path).convert(mode='RGB')
-    
+        img = Image.open(img_path).convert(mode="RGB")
+
         gt_path = self.gt_file_list[index]
-        if gt_path.suffix == '.hdf5':
-            gt = bob.io.base.load(str(gt_path)).astype('float32')
+        if gt_path.suffix == ".hdf5":
+            gt = bob.io.base.load(str(gt_path)).astype("float32")
             # not elegant but since transforms require PIL images we do this hacky workaround here
             gt = torch.from_numpy(gt)
-            gt = VF.to_pil_image(gt).convert(mode='1', dither=None)
+            gt = VF.to_pil_image(gt).convert(mode="1", dither=None)
         else:
-            gt = Image.open(gt_path).convert(mode='1', dither=None)
-        
+            gt = Image.open(gt_path).convert(mode="1", dither=None)
+
         sample = [img, gt]
-        
-        if self.transform :
+
+        if self.transform:
             sample = self.transform(*sample)
-        
-        sample.insert(0,img_name)
-        
+
+        sample.insert(0, img_name)
+
         return sample
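The folder layout this class expects follows directly from get_file_lists above: a root directory holding an images/ sub-folder and a matching gt/ sub-folder whose files sort in the same order. The path and file names in this sketch are made-up examples, not part of the patch.

# /path/to/dataset/
#     images/   retina_01.png   retina_02.png   ...
#     gt/       retina_01.png   retina_02.png   ...
from bob.ip.binseg.data.imagefolder import ImageFolder

dataset = ImageFolder("/path/to/dataset", transform=None)
name, img, gt = dataset[0]  # filename, RGB PIL image, binary ("1" mode) ground truth
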
diff --git a/bob/ip/binseg/data/imagefolderinference.py b/bob/ip/binseg/data/imagefolderinference.py
index c4218755..5a3fdfa0 100644
--- a/bob/ip/binseg/data/imagefolderinference.py
+++ b/bob/ip/binseg/data/imagefolderinference.py
@@ -8,6 +8,7 @@ import torch
 import torchvision.transforms.functional as VF
 import bob.io.base
 
+
 def get_file_lists(data_path, glob):
     """
     Recursively retrieves file lists from a given path, matching a given glob
@@ -20,6 +21,7 @@ def get_file_lists(data_path, glob):
     image_file_names = np.array(sorted(list(data_path.rglob(glob))))
     return image_file_names
 
+
 class ImageFolderInference(Dataset):
     """
     Generic ImageFolder containing images for inference
@@ -43,7 +45,8 @@ class ImageFolderInference(Dataset):
         List of transformations to apply to every input sample
 
     """
-    def __init__(self, path, glob='*', transform = None):
+
+    def __init__(self, path, glob="*", transform=None):
         self.transform = transform
         self.path = path
         self.img_file_list = get_file_lists(path, glob)
@@ -57,7 +60,7 @@ class ImageFolderInference(Dataset):
         """
         return len(self.img_file_list)
 
-    def __getitem__(self,index):
+    def __getitem__(self, index):
         """
         Parameters
         ----------
@@ -74,9 +77,9 @@ class ImageFolderInference(Dataset):
 
         sample = [img]
 
-        if self.transform :
+        if self.transform:
             sample = self.transform(*sample)
 
-        sample.insert(0,img_name)
+        sample.insert(0, img_name)
 
         return sample
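A short usage sketch for the inference dataset above; the directory, glob pattern and transforms are illustrative assumptions mirroring the imagefolderinference config earlier in this patch.

from bob.ip.binseg.data.imagefolderinference import ImageFolderInference
from bob.ip.binseg.data.transforms import Compose, ToRGB, ToTensor

transforms = Compose([ToRGB(), ToTensor()])
dataset = ImageFolderInference(
    "/path/to/folder/containing/images", glob="*.png", transform=transforms
)
name, img = dataset[0]  # image file name and image tensor
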
diff --git a/bob/ip/binseg/data/transforms.py b/bob/ip/binseg/data/transforms.py
index 659a090e..05040e46 100644
--- a/bob/ip/binseg/data/transforms.py
+++ b/bob/ip/binseg/data/transforms.py
@@ -22,17 +22,18 @@ import collections
 import bob.core
 
 _pil_interpolation_to_str = {
-    Image.NEAREST: 'PIL.Image.NEAREST',
-    Image.BILINEAR: 'PIL.Image.BILINEAR',
-    Image.BICUBIC: 'PIL.Image.BICUBIC',
-    Image.LANCZOS: 'PIL.Image.LANCZOS',
-    Image.HAMMING: 'PIL.Image.HAMMING',
-    Image.BOX: 'PIL.Image.BOX',
+    Image.NEAREST: "PIL.Image.NEAREST",
+    Image.BILINEAR: "PIL.Image.BILINEAR",
+    Image.BICUBIC: "PIL.Image.BICUBIC",
+    Image.LANCZOS: "PIL.Image.LANCZOS",
+    Image.HAMMING: "PIL.Image.HAMMING",
+    Image.BOX: "PIL.Image.BOX",
 }
 Iterable = collections.abc.Iterable
 
 # Compose
 
+
 class Compose:
     """Composes several transforms.
 
@@ -51,15 +52,17 @@ class Compose:
         return args
 
     def __repr__(self):
-        format_string = self.__class__.__name__ + '('
+        format_string = self.__class__.__name__ + "("
         for t in self.transforms:
-            format_string += '\n'
-            format_string += '    {0}'.format(t)
-        format_string += '\n)'
+            format_string += "\n"
+            format_string += "    {0}".format(t)
+        format_string += "\n)"
         return format_string
 
+
 # Preprocessing
 
+
 class CenterCrop:
     """
     Crop at the center.
@@ -69,6 +72,7 @@ class CenterCrop:
     size : int
         target size
     """
+
     def __init__(self, size):
         self.size = size
 
@@ -91,6 +95,7 @@ class Crop:
     w : int
         width of the cropped image.
     """
+
     def __init__(self, i, j, h, w):
         self.i = i
         self.j = j
@@ -98,7 +103,10 @@ class Crop:
         self.w = w
 
     def __call__(self, *args):
-        return [img.crop((self.j, self.i, self.j + self.w, self.i + self.h)) for img in args]
+        return [
+            img.crop((self.j, self.i, self.j + self.w, self.i + self.h)) for img in args
+        ]
+
 
 class Pad:
     """
@@ -115,12 +123,17 @@ class Pad:
         pixel fill value for constant fill. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
         This value is only used when the padding_mode is constant
     """
+
     def __init__(self, padding, fill=0):
         self.padding = padding
         self.fill = fill
 
     def __call__(self, *args):
-        return [VF.pad(img, self.padding, self.fill, padding_mode='constant') for img in args]
+        return [
+            VF.pad(img, self.padding, self.fill, padding_mode="constant")
+            for img in args
+        ]
+
 
 class AutoLevel16to8:
     """Converts a 16-bit image to 8-bit representation using "auto-level"
@@ -131,13 +144,16 @@ class AutoLevel16to8:
     consider such a range should be mapped to the [0,255] range of the
     destination image.
     """
+
     def _process_one(self, img):
-        return Image.fromarray(bob.core.convert(img, 'uint8', (0,255),
-            img.getextrema()))
+        return Image.fromarray(
+            bob.core.convert(img, "uint8", (0, 255), img.getextrema())
+        )
 
     def __call__(self, *args):
         return [self._process_one(img) for img in args]
 
+
 class ToRGB:
     """Converts from any input format to RGB, using an ADAPTIVE conversion.
 
@@ -146,17 +162,21 @@ class ToRGB:
     defaults.  This may be aggressive if applied to 16-bit images without
     further considerations.
     """
+
     def __call__(self, *args):
         return [img.convert(mode="RGB") for img in args]
 
+
 class ToTensor:
     """Converts :py:class:`PIL.Image.Image` to :py:class:`torch.Tensor` """
+
     def __call__(self, *args):
         return [VF.to_tensor(img) for img in args]
 
 
 # Augmentations
 
+
 class RandomHFlip:
     """
     Flips horizontally
@@ -166,7 +186,8 @@ class RandomHFlip:
     prob : float
         probability at which image is flipped. Defaults to ``0.5``
     """
-    def __init__(self, prob = 0.5):
+
+    def __init__(self, prob=0.5):
         self.prob = prob
 
     def __call__(self, *args):
@@ -186,7 +207,8 @@ class RandomVFlip:
     prob : float
         probability at which image is flipped. Defaults to ``0.5``
     """
-    def __init__(self, prob = 0.5):
+
+    def __init__(self, prob=0.5):
         self.prob = prob
 
     def __call__(self, *args):
@@ -208,17 +230,19 @@ class RandomRotation:
     prob : float
         probability at which image is rotated. Defaults to ``0.5``
     """
-    def __init__(self, degree_range = (-15, +15), prob = 0.5):
+
+    def __init__(self, degree_range=(-15, +15), prob=0.5):
         self.prob = prob
         self.degree_range = degree_range
 
     def __call__(self, *args):
         if random.random() < self.prob:
             degree = random.randint(*self.degree_range)
-            return [VF.rotate(img, degree, resample = Image.BILINEAR) for img in args]
+            return [VF.rotate(img, degree, resample=Image.BILINEAR) for img in args]
         else:
             return args
 
+
 class ColorJitter(object):
     """
     Randomly change the brightness, contrast, saturation and hue
@@ -240,7 +264,10 @@ class ColorJitter(object):
     prob : float
         probability at which the operation is applied
     """
-    def __init__(self, brightness=0.3, contrast=0.3, saturation=0.02, hue=0.02, prob=0.5):
+
+    def __init__(
+        self, brightness=0.3, contrast=0.3, saturation=0.02, hue=0.02, prob=0.5
+    ):
         self.brightness = brightness
         self.contrast = contrast
         self.saturation = saturation
@@ -252,15 +279,21 @@ class ColorJitter(object):
         transforms = []
         if brightness > 0:
             brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness)
-            transforms.append(Lambda(lambda img: VF.adjust_brightness(img, brightness_factor)))
+            transforms.append(
+                Lambda(lambda img: VF.adjust_brightness(img, brightness_factor))
+            )
 
         if contrast > 0:
             contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
-            transforms.append(Lambda(lambda img: VF.adjust_contrast(img, contrast_factor)))
+            transforms.append(
+                Lambda(lambda img: VF.adjust_contrast(img, contrast_factor))
+            )
 
         if saturation > 0:
             saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation)
-            transforms.append(Lambda(lambda img: VF.adjust_saturation(img, saturation_factor)))
+            transforms.append(
+                Lambda(lambda img: VF.adjust_saturation(img, saturation_factor))
+            )
 
         if hue > 0:
             hue_factor = random.uniform(-hue, hue)
@@ -273,8 +306,9 @@ class ColorJitter(object):
 
     def __call__(self, *args):
         if random.random() < self.prob:
-            transform = self.get_params(self.brightness, self.contrast,
-                                        self.saturation, self.hue)
+            transform = self.get_params(
+                self.brightness, self.contrast, self.saturation, self.hue
+            )
             trans_img = transform(args[0])
             return [trans_img, *args[1:]]
         else:
@@ -301,7 +335,14 @@ class RandomResizedCrop:
         probability at which the operation is applied. Defaults to ``0.5``
     """
 
-    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR, prob = 0.5):
+    def __init__(
+        self,
+        size,
+        scale=(0.08, 1.0),
+        ratio=(3.0 / 4.0, 4.0 / 3.0),
+        interpolation=Image.BILINEAR,
+        prob=0.5,
+    ):
         if isinstance(size, tuple):
             self.size = size
         else:
@@ -333,10 +374,10 @@ class RandomResizedCrop:
 
         # Fallback to central crop
         in_ratio = img.size[0] / img.size[1]
-        if (in_ratio < min(ratio)):
+        if in_ratio < min(ratio):
             w = img.size[0]
             h = w / min(ratio)
-        elif (in_ratio > max(ratio)):
+        elif in_ratio > max(ratio):
             h = img.size[1]
             w = h * max(ratio)
         else:  # whole image
@@ -359,10 +400,10 @@ class RandomResizedCrop:
 
     def __repr__(self):
         interpolate_str = _pil_interpolation_to_str[self.interpolation]
-        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
-        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
-        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
-        format_string += ', interpolation={0})'.format(interpolate_str)
+        format_string = self.__class__.__name__ + "(size={0}".format(self.size)
+        format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale))
+        format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio))
+        format_string += ", interpolation={0})".format(interpolate_str)
         return format_string
 
 
@@ -391,4 +432,6 @@ class Resize:
 
     def __repr__(self):
         interpolate_str = _pil_interpolation_to_str[self.interpolation]
-        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
+        return self.__class__.__name__ + "(size={0}, interpolation={1})".format(
+            self.size, interpolate_str
+        )
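One point worth illustrating about these transforms: every callable accepts an arbitrary number of images and returns a list, so random operations are applied identically to an image and its ground truth in a single call. A minimal sketch with made-up blank PIL images, assuming the module is importable as bob.ip.binseg.data.transforms:

from PIL import Image

from bob.ip.binseg.data.transforms import CenterCrop, Compose, RandomHFlip, ToTensor

img = Image.new("RGB", (600, 600))
gt = Image.new("1", (600, 600))

transforms = Compose([CenterCrop((544, 544)), RandomHFlip(), ToTensor()])
img_t, gt_t = transforms(img, gt)  # both flipped (or not) together, then converted to tensors
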
diff --git a/bob/ip/binseg/engine/adabound.py b/bob/ip/binseg/engine/adabound.py
index 9e658b32..683bd76f 100644
--- a/bob/ip/binseg/engine/adabound.py
+++ b/bob/ip/binseg/engine/adabound.py
@@ -54,8 +54,17 @@ class AdaBound(torch.optim.Optimizer):
 
     """
 
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1,
-            gamma=1e-3, eps=1e-8, weight_decay=0, amsbound=False):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        final_lr=0.1,
+        gamma=1e-3,
+        eps=1e-8,
+        weight_decay=0,
+        amsbound=False,
+    ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -68,16 +77,23 @@ class AdaBound(torch.optim.Optimizer):
             raise ValueError("Invalid final learning rate: {}".format(final_lr))
         if not 0.0 <= gamma < 1.0:
             raise ValueError("Invalid gamma parameter: {}".format(gamma))
-        defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps,
-                        weight_decay=weight_decay, amsbound=amsbound)
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            final_lr=final_lr,
+            gamma=gamma,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsbound=amsbound,
+        )
         super(AdaBound, self).__init__(params, defaults)
 
-        self.base_lrs = list(map(lambda group: group['lr'], self.param_groups))
+        self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))
 
     def __setstate__(self, state):
         super(AdaBound, self).__setstate__(state)
         for group in self.param_groups:
-            group.setdefault('amsbound', False)
+            group.setdefault("amsbound", False)
 
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -94,37 +110,38 @@ class AdaBound(torch.optim.Optimizer):
             loss = closure()
 
         for group, base_lr in zip(self.param_groups, self.base_lrs):
-            for p in group['params']:
+            for p in group["params"]:
                 if p.grad is None:
                     continue
                 grad = p.grad.data
                 if grad.is_sparse:
                     raise RuntimeError(
-                        'Adam does not support sparse gradients, please consider SparseAdam instead')
-                amsbound = group['amsbound']
+                        "Adam does not support sparse gradients, please consider SparseAdam instead"
+                    )
+                amsbound = group["amsbound"]
 
                 state = self.state[p]
 
                 # State initialization
                 if len(state) == 0:
-                    state['step'] = 0
+                    state["step"] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state["exp_avg"] = torch.zeros_like(p.data)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                     if amsbound:
                         # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)
 
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                 if amsbound:
-                    max_exp_avg_sq = state['max_exp_avg_sq']
-                beta1, beta2 = group['betas']
+                    max_exp_avg_sq = state["max_exp_avg_sq"]
+                beta1, beta2 = group["betas"]
 
-                state['step'] += 1
+                state["step"] += 1
 
-                if group['weight_decay'] != 0:
-                    grad = grad.add(group['weight_decay'], p.data)
+                if group["weight_decay"] != 0:
+                    grad = grad.add(group["weight_decay"], p.data)
 
                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
@@ -133,19 +150,19 @@ class AdaBound(torch.optim.Optimizer):
                     # Maintains the maximum of all 2nd moment running avg. till now
                     torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                     # Use the max. for normalizing running avg. of gradient
-                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                 else:
-                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    denom = exp_avg_sq.sqrt().add_(group["eps"])
 
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
-                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+                bias_correction1 = 1 - beta1 ** state["step"]
+                bias_correction2 = 1 - beta2 ** state["step"]
+                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
 
                 # Applies bounds on actual learning rate
                 # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay
-                final_lr = group['final_lr'] * group['lr'] / base_lr
-                lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1))
-                upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step']))
+                final_lr = group["final_lr"] * group["lr"] / base_lr
+                lower_bound = final_lr * (1 - 1 / (group["gamma"] * state["step"] + 1))
+                upper_bound = final_lr * (1 + 1 / (group["gamma"] * state["step"]))
                 step_size = torch.full_like(denom, step_size)
                 step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
 
@@ -153,6 +170,7 @@ class AdaBound(torch.optim.Optimizer):
 
         return loss
 
+
 class AdaBoundW(torch.optim.Optimizer):
     """Implements AdaBound algorithm with Decoupled Weight Decay
     (See https://arxiv.org/abs/1711.05101)
@@ -187,8 +205,17 @@ class AdaBoundW(torch.optim.Optimizer):
 
     """
 
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1,
-            gamma=1e-3, eps=1e-8, weight_decay=0, amsbound=False):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        final_lr=0.1,
+        gamma=1e-3,
+        eps=1e-8,
+        weight_decay=0,
+        amsbound=False,
+    ):
 
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -202,16 +229,23 @@ class AdaBoundW(torch.optim.Optimizer):
             raise ValueError("Invalid final learning rate: {}".format(final_lr))
         if not 0.0 <= gamma < 1.0:
             raise ValueError("Invalid gamma parameter: {}".format(gamma))
-        defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
-                eps=eps, weight_decay=weight_decay, amsbound=amsbound)
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            final_lr=final_lr,
+            gamma=gamma,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsbound=amsbound,
+        )
         super(AdaBoundW, self).__init__(params, defaults)
 
-        self.base_lrs = list(map(lambda group: group['lr'], self.param_groups))
+        self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))
 
     def __setstate__(self, state):
         super(AdaBoundW, self).__setstate__(state)
         for group in self.param_groups:
-            group.setdefault('amsbound', False)
+            group.setdefault("amsbound", False)
 
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -229,34 +263,35 @@ class AdaBoundW(torch.optim.Optimizer):
             loss = closure()
 
         for group, base_lr in zip(self.param_groups, self.base_lrs):
-            for p in group['params']:
+            for p in group["params"]:
                 if p.grad is None:
                     continue
                 grad = p.grad.data
                 if grad.is_sparse:
                     raise RuntimeError(
-                        'Adam does not support sparse gradients, please consider SparseAdam instead')
-                amsbound = group['amsbound']
+                        "Adam does not support sparse gradients, please consider SparseAdam instead"
+                    )
+                amsbound = group["amsbound"]
 
                 state = self.state[p]
 
                 # State initialization
                 if len(state) == 0:
-                    state['step'] = 0
+                    state["step"] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state["exp_avg"] = torch.zeros_like(p.data)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                     if amsbound:
                         # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)
 
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                 if amsbound:
-                    max_exp_avg_sq = state['max_exp_avg_sq']
-                beta1, beta2 = group['betas']
+                    max_exp_avg_sq = state["max_exp_avg_sq"]
+                beta1, beta2 = group["betas"]
 
-                state['step'] += 1
+                state["step"] += 1
 
                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
@@ -265,25 +300,25 @@ class AdaBoundW(torch.optim.Optimizer):
                     # Maintains the maximum of all 2nd moment running avg. till now
                     torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                     # Use the max. for normalizing running avg. of gradient
-                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                 else:
-                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    denom = exp_avg_sq.sqrt().add_(group["eps"])
 
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
-                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+                bias_correction1 = 1 - beta1 ** state["step"]
+                bias_correction2 = 1 - beta2 ** state["step"]
+                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
 
                 # Applies bounds on actual learning rate
                 # lr_scheduler cannot affect final_lr, this is a workaround to
                 # apply lr decay
-                final_lr = group['final_lr'] * group['lr'] / base_lr
-                lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1))
-                upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step']))
+                final_lr = group["final_lr"] * group["lr"] / base_lr
+                lower_bound = final_lr * (1 - 1 / (group["gamma"] * state["step"] + 1))
+                upper_bound = final_lr * (1 + 1 / (group["gamma"] * state["step"]))
                 step_size = torch.full_like(denom, step_size)
                 step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
 
-                if group['weight_decay'] != 0:
-                    decayed_weights = torch.mul(p.data, group['weight_decay'])
+                if group["weight_decay"] != 0:
+                    decayed_weights = torch.mul(p.data, group["weight_decay"])
                     p.data.add_(-step_size)
                     p.data.sub_(decayed_weights)
                 else:
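Editorial note on the two optimizer hunks above: the part that is easy to lose in the reformatting noise is the learning-rate clamping that AdaBound and AdaBoundW share. The sketch below is not part of the patch and the function name `adabound_step_size` is invented; it only restates the bound formulas exactly as they appear in both `step()` methods.

    import math
    import torch

    def adabound_step_size(lr, base_lr, final_lr, gamma, step,
                           denom, exp_avg, beta1, beta2):
        # Re-scale final_lr so that schedulers acting on ``lr`` still decay it.
        final_lr = final_lr * lr / base_lr
        # Both bounds converge towards final_lr as ``step`` grows.
        lower_bound = final_lr * (1 - 1 / (gamma * step + 1))
        upper_bound = final_lr * (1 + 1 / (gamma * step))
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1
        # Element-wise Adam-style step, clipped into [lower_bound, upper_bound].
        step_size = torch.full_like(denom, step_size)
        return step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)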
diff --git a/bob/ip/binseg/engine/inferencer.py b/bob/ip/binseg/engine/inferencer.py
index e76153c0..43f2dd7e 100644
--- a/bob/ip/binseg/engine/inferencer.py
+++ b/bob/ip/binseg/engine/inferencer.py
@@ -18,7 +18,6 @@ from bob.ip.binseg.utils.plot import precision_recall_f1iso_confintval
 from bob.ip.binseg.utils.summary import summary
 
 
-
 def batch_metrics(predictions, ground_truths, names, output_folder, logger):
     """
     Calculates metrics on the batch and saves them to disk
@@ -51,21 +50,23 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
         file_name = "{}.csv".format(names[j])
         logger.info("saving {}".format(file_name))
 
-        with open (os.path.join(output_folder,file_name), "w+") as outfile:
+        with open(os.path.join(output_folder, file_name), "w+") as outfile:
 
-            outfile.write("threshold, precision, recall, specificity, accuracy, jaccard, f1_score\n")
+            outfile.write(
+                "threshold, precision, recall, specificity, accuracy, jaccard, f1_score\n"
+            )
 
-            for threshold in np.arange(0.0,1.0,step_size):
+            for threshold in np.arange(0.0, 1.0, step_size):
                 # threshold
                 binary_pred = torch.gt(predictions[j], threshold).byte()
 
                 # equals and not-equals
-                equals = torch.eq(binary_pred, gts).type(torch.uint8) # tensor
-                notequals = torch.ne(binary_pred, gts).type(torch.uint8) # tensor
+                equals = torch.eq(binary_pred, gts).type(torch.uint8)  # tensor
+                notequals = torch.ne(binary_pred, gts).type(torch.uint8)  # tensor
 
                 # true positives
-                tp_tensor = (gts * binary_pred ) # tensor
-                tp_count = torch.sum(tp_tensor).item() # scalar
+                tp_tensor = gts * binary_pred  # tensor
+                tp_count = torch.sum(tp_tensor).item()  # scalar
 
                 # false positives
                 fp_tensor = torch.eq((binary_pred + tp_tensor), 1)
@@ -83,10 +84,13 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
                 metrics = base_metrics(tp_count, fp_count, tn_count, fn_count)
 
                 # write to disk
-                outfile.write("{:.2f},{:.5f},{:.5f},{:.5f},{:.5f},{:.5f},{:.5f} \n".format(threshold, *metrics))
-
-                batch_metrics.append([names[j],threshold, *metrics ])
+                outfile.write(
+                    "{:.2f},{:.5f},{:.5f},{:.5f},{:.5f},{:.5f},{:.5f} \n".format(
+                        threshold, *metrics
+                    )
+                )
 
+                batch_metrics.append([names[j], threshold, *metrics])
 
     return batch_metrics
 
@@ -106,16 +110,18 @@ def save_probability_images(predictions, names, output_folder, logger):
     logger : :py:class:`logging.Logger`
         python logger
     """
-    images_subfolder = os.path.join(output_folder,'images')
+    images_subfolder = os.path.join(output_folder, "images")
     for j in range(predictions.size()[0]):
         img = VF.to_pil_image(predictions.cpu().data[j])
-        filename = '{}.png'.format(names[j].split(".")[0])
+        filename = "{}.png".format(names[j].split(".")[0])
         fullpath = os.path.join(images_subfolder, filename)
         logger.info("saving {}".format(fullpath))
         fulldir = os.path.dirname(fullpath)
-        if not os.path.exists(fulldir): os.makedirs(fulldir)
+        if not os.path.exists(fulldir):
+            os.makedirs(fulldir)
         img.save(fullpath)
 
+
 def save_hdf(predictions, names, output_folder, logger):
     """
     Saves probability maps in HDF5 format, one file per test image
@@ -131,23 +137,21 @@ def save_hdf(predictions, names, output_folder, logger):
     logger : :py:class:`logging.Logger`
         python logger
     """
-    hdf5_subfolder = os.path.join(output_folder,'hdf5')
-    if not os.path.exists(hdf5_subfolder): os.makedirs(hdf5_subfolder)
+    hdf5_subfolder = os.path.join(output_folder, "hdf5")
+    if not os.path.exists(hdf5_subfolder):
+        os.makedirs(hdf5_subfolder)
     for j in range(predictions.size()[0]):
         img = predictions.cpu().data[j].squeeze(0).numpy()
-        filename = '{}.hdf5'.format(names[j].split(".")[0])
+        filename = "{}.hdf5".format(names[j].split(".")[0])
         fullpath = os.path.join(hdf5_subfolder, filename)
         logger.info("saving {}".format(filename))
         fulldir = os.path.dirname(fullpath)
-        if not os.path.exists(fulldir): os.makedirs(fulldir)
+        if not os.path.exists(fulldir):
+            os.makedirs(fulldir)
         bob.io.base.save(img, fullpath)
 
-def do_inference(
-    model,
-    data_loader,
-    device,
-    output_folder = None
-):
+
+def do_inference(model, data_loader, device, output_folder=None):
 
     """
     Run inference and calculate metrics
@@ -164,8 +168,8 @@ def do_inference(
     logger = logging.getLogger("bob.ip.binseg.engine.inference")
     logger.info("Start evaluation")
     logger.info("Output folder: {}, Device: {}".format(output_folder, device))
-    results_subfolder = os.path.join(output_folder,'results')
-    os.makedirs(results_subfolder,exist_ok=True)
+    results_subfolder = os.path.join(output_folder, "results")
+    os.makedirs(results_subfolder, exist_ok=True)
 
     model.eval().to(device)
     # Sigmoid for probabilities
@@ -189,7 +193,7 @@ def do_inference(
 
             # necessary check for hed architecture that uses several outputs
             # for loss calculation instead of just the last concatfuse block
-            if isinstance(outputs,list):
+            if isinstance(outputs, list):
                 outputs = outputs[-1]
 
             probabilities = sigmoid(outputs)
@@ -198,7 +202,9 @@ def do_inference(
             times.append(batch_time)
             logger.info("Batch time: {:.5f} s".format(batch_time))
 
-            b_metrics = batch_metrics(probabilities, ground_truths, names,results_subfolder, logger)
+            b_metrics = batch_metrics(
+                probabilities, ground_truths, names, results_subfolder, logger
+            )
             metrics.extend(b_metrics)
 
             # Create probability images
@@ -207,74 +213,94 @@ def do_inference(
             save_hdf(probabilities, names, output_folder, logger)
 
     # DataFrame
-    df_metrics = pd.DataFrame(metrics,columns= \
-                           ["name",
-                            "threshold",
-                            "precision",
-                            "recall",
-                            "specificity",
-                            "accuracy",
-                            "jaccard",
-                            "f1_score"])
+    df_metrics = pd.DataFrame(
+        metrics,
+        columns=[
+            "name",
+            "threshold",
+            "precision",
+            "recall",
+            "specificity",
+            "accuracy",
+            "jaccard",
+            "f1_score",
+        ],
+    )
 
     # Report and Averages
     metrics_file = "Metrics.csv"
     metrics_path = os.path.join(results_subfolder, metrics_file)
     logger.info("Saving average over all input images: {}".format(metrics_file))
 
-    avg_metrics = df_metrics.groupby('threshold').mean()
-    std_metrics = df_metrics.groupby('threshold').std()
+    avg_metrics = df_metrics.groupby("threshold").mean()
+    std_metrics = df_metrics.groupby("threshold").std()
 
     # Uncomment below to compute the F1-score from the average precision and recall instead of
     # averaging the F1-scores of individual images. This method is in line with Maninis et al. (2016)
-    #avg_metrics["f1_score"] =  (2* avg_metrics["precision"]*avg_metrics["recall"])/ \
+    # avg_metrics["f1_score"] =  (2* avg_metrics["precision"]*avg_metrics["recall"])/ \
     #    (avg_metrics["precision"]+avg_metrics["recall"])
 
     avg_metrics["std_pr"] = std_metrics["precision"]
-    avg_metrics["pr_upper"] = avg_metrics['precision'] + avg_metrics["std_pr"]
-    avg_metrics["pr_lower"] = avg_metrics['precision'] - avg_metrics["std_pr"]
+    avg_metrics["pr_upper"] = avg_metrics["precision"] + avg_metrics["std_pr"]
+    avg_metrics["pr_lower"] = avg_metrics["precision"] - avg_metrics["std_pr"]
     avg_metrics["std_re"] = std_metrics["recall"]
-    avg_metrics["re_upper"] = avg_metrics['recall'] + avg_metrics["std_re"]
-    avg_metrics["re_lower"] = avg_metrics['recall'] - avg_metrics["std_re"]
+    avg_metrics["re_upper"] = avg_metrics["recall"] + avg_metrics["std_re"]
+    avg_metrics["re_lower"] = avg_metrics["recall"] - avg_metrics["std_re"]
     avg_metrics["std_f1"] = std_metrics["f1_score"]
 
     avg_metrics.to_csv(metrics_path)
-    maxf1 = avg_metrics['f1_score'].max()
-    optimal_f1_threshold = avg_metrics['f1_score'].idxmax()
+    maxf1 = avg_metrics["f1_score"].max()
+    optimal_f1_threshold = avg_metrics["f1_score"].idxmax()
 
-    logger.info("Highest F1-score of {:.5f}, achieved at threshold {}".format(maxf1, optimal_f1_threshold))
+    logger.info(
+        "Highest F1-score of {:.5f}, achieved at threshold {}".format(
+            maxf1, optimal_f1_threshold
+        )
+    )
 
     # Plotting
     np_avg_metrics = avg_metrics.to_numpy().T
     fig_name = "precision_recall.pdf"
     logger.info("saving {}".format(fig_name))
-    fig = precision_recall_f1iso_confintval([np_avg_metrics[0]],[np_avg_metrics[1]],[np_avg_metrics[7]],[np_avg_metrics[8]],[np_avg_metrics[10]],[np_avg_metrics[11]], [model.name,None], title=output_folder)
+    fig = precision_recall_f1iso_confintval(
+        [np_avg_metrics[0]],
+        [np_avg_metrics[1]],
+        [np_avg_metrics[7]],
+        [np_avg_metrics[8]],
+        [np_avg_metrics[10]],
+        [np_avg_metrics[11]],
+        [model.name, None],
+        title=output_folder,
+    )
     fig_filename = os.path.join(results_subfolder, fig_name)
     fig.savefig(fig_filename)
 
     # Report times
     total_inference_time = str(datetime.timedelta(seconds=int(sum(times))))
     average_batch_inference_time = np.mean(times)
-    total_evalution_time = str(datetime.timedelta(seconds=int(time.time() - start_total_time )))
+    total_evalution_time = str(
+        datetime.timedelta(seconds=int(time.time() - start_total_time))
+    )
 
-    logger.info("Average batch inference time: {:.5f}s".format(average_batch_inference_time))
+    logger.info(
+        "Average batch inference time: {:.5f}s".format(average_batch_inference_time)
+    )
 
     times_file = "Times.txt"
     logger.info("saving {}".format(times_file))
 
-    with open (os.path.join(results_subfolder,times_file), "w+") as outfile:
+    with open(os.path.join(results_subfolder, times_file), "w+") as outfile:
         date = datetime.datetime.now()
         outfile.write("Date: {} \n".format(date.strftime("%Y-%m-%d %H:%M:%S")))
         outfile.write("Total evaluation run-time: {} \n".format(total_evalution_time))
-        outfile.write("Average batch inference time: {} \n".format(average_batch_inference_time))
+        outfile.write(
+            "Average batch inference time: {} \n".format(average_batch_inference_time)
+        )
         outfile.write("Total inference time: {} \n".format(total_inference_time))
 
     # Save model summary
-    summary_file = 'ModelSummary.txt'
+    summary_file = "ModelSummary.txt"
     logger.info("saving {}".format(summary_file))
 
-    with open (os.path.join(results_subfolder,summary_file), "w+") as outfile:
-        summary(model,outfile)
-
-
-
+    with open(os.path.join(results_subfolder, summary_file), "w+") as outfile:
+        summary(model, outfile)
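The `batch_metrics` changes above are purely cosmetic: the function still sweeps a range of thresholds, binarises each probability map and derives precision, recall, specificity, accuracy, Jaccard and F1 from the confusion counts via `base_metrics`. A minimal editorial sketch of the counting step (the helper `confusion_counts` is invented here and does not exist in the package):

    import torch

    def confusion_counts(probabilities, ground_truth, threshold):
        # Binarise the probability map at the given threshold, as in batch_metrics.
        pred = torch.gt(probabilities, threshold).byte()
        gt = ground_truth.byte()
        tp = torch.sum((pred == 1) & (gt == 1)).item()
        fp = torch.sum((pred == 1) & (gt == 0)).item()
        tn = torch.sum((pred == 0) & (gt == 0)).item()
        fn = torch.sum((pred == 0) & (gt == 1)).item()
        return tp, fp, tn, fn

From these counts, precision is tp / (tp + fp) and recall is tp / (tp + fn), which is what the per-threshold CSV rows written above contain.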
diff --git a/bob/ip/binseg/engine/predicter.py b/bob/ip/binseg/engine/predicter.py
index ebd09ac5..d8fb2de3 100644
--- a/bob/ip/binseg/engine/predicter.py
+++ b/bob/ip/binseg/engine/predicter.py
@@ -15,12 +15,7 @@ from bob.ip.binseg.engine.inferencer import save_probability_images
 from bob.ip.binseg.engine.inferencer import save_hdf
 
 
-def do_predict(
-    model,
-    data_loader,
-    device,
-    output_folder = None
-):
+def do_predict(model, data_loader, device, output_folder=None):
 
     """
     Run inference and calculate metrics
@@ -37,8 +32,8 @@ def do_predict(
     logger = logging.getLogger("bob.ip.binseg.engine.inference")
     logger.info("Start evaluation")
     logger.info("Output folder: {}, Device: {}".format(output_folder, device))
-    results_subfolder = os.path.join(output_folder,'results')
-    os.makedirs(results_subfolder,exist_ok=True)
+    results_subfolder = os.path.join(output_folder, "results")
+    os.makedirs(results_subfolder, exist_ok=True)
 
     model.eval().to(device)
     # Sigmoid for probabilities
@@ -58,7 +53,7 @@ def do_predict(
 
             # necessary check for hed architecture that uses several outputs
             # for loss calculation instead of just the last concatfuse block
-            if isinstance(outputs,list):
+            if isinstance(outputs, list):
                 outputs = outputs[-1]
 
             probabilities = sigmoid(outputs)
@@ -72,22 +67,25 @@ def do_predict(
             # Save hdf5
             save_hdf(probabilities, names, output_folder, logger)
 
-
     # Report times
     total_inference_time = str(datetime.timedelta(seconds=int(sum(times))))
     average_batch_inference_time = np.mean(times)
-    total_evalution_time = str(datetime.timedelta(seconds=int(time.time() - start_total_time )))
+    total_evalution_time = str(
+        datetime.timedelta(seconds=int(time.time() - start_total_time))
+    )
 
-    logger.info("Average batch inference time: {:.5f}s".format(average_batch_inference_time))
+    logger.info(
+        "Average batch inference time: {:.5f}s".format(average_batch_inference_time)
+    )
 
     times_file = "Times.txt"
     logger.info("saving {}".format(times_file))
 
-    with open (os.path.join(results_subfolder,times_file), "w+") as outfile:
+    with open(os.path.join(results_subfolder, times_file), "w+") as outfile:
         date = datetime.datetime.now()
         outfile.write("Date: {} \n".format(date.strftime("%Y-%m-%d %H:%M:%S")))
         outfile.write("Total evaluation run-time: {} \n".format(total_evalution_time))
-        outfile.write("Average batch inference time: {} \n".format(average_batch_inference_time))
+        outfile.write(
+            "Average batch inference time: {} \n".format(average_batch_inference_time)
+        )
         outfile.write("Total inference time: {} \n".format(total_inference_time))
-
-
diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index dfc73c86..46083544 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -13,10 +13,12 @@ import numpy as np
 from bob.ip.binseg.utils.metric import SmoothedValue
 from bob.ip.binseg.utils.plot import loss_curve
 
+
 def sharpen(x, T):
-    temp = x**(1/T)
+    temp = x ** (1 / T)
     return temp / temp.sum(dim=1, keepdim=True)
 
+
 def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     """Applies mix up as described in [MIXMATCH_19].
 
@@ -41,21 +43,30 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     """
     # TODO:
     with torch.no_grad():
-        l = np.random.beta(alpha, alpha) # Eq (8)
-        l = max(l, 1 - l) # Eq (9)
+        l = np.random.beta(alpha, alpha)  # Eq (8)
+        l = max(l, 1 - l)  # Eq (9)
         # Shuffle and concat. Alg. 1 Line: 12
-        w_inputs = torch.cat([input,unlabeled_input],0)
-        w_targets = torch.cat([target,unlabled_target],0)
-        idx = torch.randperm(w_inputs.size(0)) # get random index
+        w_inputs = torch.cat([input, unlabeled_input], 0)
+        w_targets = torch.cat([target, unlabled_target], 0)
+        idx = torch.randperm(w_inputs.size(0))  # get random index
 
         # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
-        input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input):]]
-        target_mixedup = l * target + (1 - l) * w_targets[idx[len(target):]]
+        input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input) :]]
+        target_mixedup = l * target + (1 - l) * w_targets[idx[len(target) :]]
 
         # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
-        unlabeled_input_mixedup = l * unlabeled_input + (1 - l) * w_inputs[idx[:len(unlabeled_input)]]
-        unlabled_target_mixedup =  l * unlabled_target + (1 - l) * w_targets[idx[:len(unlabled_target)]]
-        return input_mixedup, target_mixedup, unlabeled_input_mixedup, unlabled_target_mixedup
+        unlabeled_input_mixedup = (
+            l * unlabeled_input + (1 - l) * w_inputs[idx[: len(unlabeled_input)]]
+        )
+        unlabled_target_mixedup = (
+            l * unlabled_target + (1 - l) * w_targets[idx[: len(unlabled_target)]]
+        )
+        return (
+            input_mixedup,
+            target_mixedup,
+            unlabeled_input_mixedup,
+            unlabled_target_mixedup,
+        )
 
 
 def square_rampup(current, rampup_length=16):
@@ -80,9 +91,10 @@ def square_rampup(current, rampup_length=16):
     if rampup_length == 0:
         return 1.0
     else:
-        current = np.clip((current/ float(rampup_length))**2, 0.0, 1.0)
+        current = np.clip((current / float(rampup_length)) ** 2, 0.0, 1.0)
     return float(current)
 
+
 def linear_rampup(current, rampup_length=16):
     """slowly ramp-up ``lambda_u``
 
@@ -107,6 +119,7 @@ def linear_rampup(current, rampup_length=16):
         current = np.clip(current / rampup_length, 0.0, 1.0)
     return float(current)
 
+
 def guess_labels(unlabeled_images, model):
     """
     Calculates the average prediction over two augmentations: horizontal and vertical flips
@@ -130,15 +143,16 @@ def guess_labels(unlabeled_images, model):
         guess1 = torch.sigmoid(model(unlabeled_images)).unsqueeze(0)
         # Horizontal flip and unsqueeze to work with batches (increase flip dimension by 1)
         hflip = torch.sigmoid(model(unlabeled_images.flip(2))).unsqueeze(0)
-        guess2  = hflip.flip(3)
+        guess2 = hflip.flip(3)
         # Vertical flip and unsqueeze to work with batches (increase flip dimension by 1)
         vflip = torch.sigmoid(model(unlabeled_images.flip(3))).unsqueeze(0)
         guess3 = vflip.flip(4)
         # Concat
-        concat = torch.cat([guess1,guess2,guess3],0)
-        avg_guess = torch.mean(concat,0)
+        concat = torch.cat([guess1, guess2, guess3], 0)
+        avg_guess = torch.mean(concat, 0)
         return avg_guess
 
+
 def do_ssltrain(
     model,
     data_loader,
@@ -150,7 +164,7 @@ def do_ssltrain(
     device,
     arguments,
     output_folder,
-    rampup_length
+    rampup_length,
 ):
     """
     Train model and save to disk.
@@ -196,7 +210,9 @@ def do_ssltrain(
     max_epoch = arguments["max_epoch"]
 
     # Log to file
-    with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+",1) as outfile:
+    with open(
+        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)), "a+", 1
+    ) as outfile:
         for state in optimizer.state.values():
             for k, v in state.items():
                 if isinstance(v, torch.Tensor):
@@ -226,11 +242,17 @@ def do_ssltrain(
                 unlabeled_outputs = model(unlabeled_images)
                 # guessed unlabeled outputs
                 unlabeled_ground_truths = guess_labels(unlabeled_images, model)
-                #unlabeled_ground_truths = sharpen(unlabeled_ground_truths,0.5)
-                #images, ground_truths, unlabeled_images, unlabeled_ground_truths = mix_up(0.75, images, ground_truths, unlabeled_images, unlabeled_ground_truths)
-                ramp_up_factor = square_rampup(epoch,rampup_length=rampup_length)
-
-                loss, ll, ul = criterion(outputs, ground_truths, unlabeled_outputs, unlabeled_ground_truths, ramp_up_factor)
+                # unlabeled_ground_truths = sharpen(unlabeled_ground_truths,0.5)
+                # images, ground_truths, unlabeled_images, unlabeled_ground_truths = mix_up(0.75, images, ground_truths, unlabeled_images, unlabeled_ground_truths)
+                ramp_up_factor = square_rampup(epoch, rampup_length=rampup_length)
+
+                loss, ll, ul = criterion(
+                    outputs,
+                    ground_truths,
+                    unlabeled_outputs,
+                    unlabeled_ground_truths,
+                    ramp_up_factor,
+                )
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
@@ -247,60 +269,77 @@ def do_ssltrain(
 
             epoch_time = time.time() - start_epoch_time
 
-
             eta_seconds = epoch_time * (max_epoch - epoch)
             eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
 
-            outfile.write(("{epoch}, "
-                        "{avg_loss:.6f}, "
-                        "{median_loss:.6f}, "
-                        "{median_labeled_loss},"
-                        "{median_unlabeled_loss},"
-                        "{lr:.6f}, "
-                        "{memory:.0f}"
-                        "\n"
-                        ).format(
+            outfile.write(
+                (
+                    "{epoch}, "
+                    "{avg_loss:.6f}, "
+                    "{median_loss:.6f}, "
+                    "{median_labeled_loss},"
+                    "{median_unlabeled_loss},"
+                    "{lr:.6f}, "
+                    "{memory:.0f}"
+                    "\n"
+                ).format(
                     eta=eta_string,
                     epoch=epoch,
                     avg_loss=losses.avg,
                     median_loss=losses.median,
-                    median_labeled_loss = labeled_loss.median,
-                    median_unlabeled_loss = unlabeled_loss.median,
+                    median_labeled_loss=labeled_loss.median,
+                    median_unlabeled_loss=unlabeled_loss.median,
                     lr=optimizer.param_groups[0]["lr"],
-                    memory = (torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else .0,
-                    )
+                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
+                    if torch.cuda.is_available()
+                    else 0.0,
                 )
-            logger.info(("eta: {eta}, "
-                        "epoch: {epoch}, "
-                        "avg. loss: {avg_loss:.6f}, "
-                        "median loss: {median_loss:.6f}, "
-                        "labeled loss: {median_labeled_loss}, "
-                        "unlabeled loss: {median_unlabeled_loss}, "
-                        "lr: {lr:.6f}, "
-                        "max mem: {memory:.0f}"
-                        ).format(
+            )
+            logger.info(
+                (
+                    "eta: {eta}, "
+                    "epoch: {epoch}, "
+                    "avg. loss: {avg_loss:.6f}, "
+                    "median loss: {median_loss:.6f}, "
+                    "labeled loss: {median_labeled_loss}, "
+                    "unlabeled loss: {median_unlabeled_loss}, "
+                    "lr: {lr:.6f}, "
+                    "max mem: {memory:.0f}"
+                ).format(
                     eta=eta_string,
                     epoch=epoch,
                     avg_loss=losses.avg,
                     median_loss=losses.median,
-                    median_labeled_loss = labeled_loss.median,
-                    median_unlabeled_loss = unlabeled_loss.median,
+                    median_labeled_loss=labeled_loss.median,
+                    median_unlabeled_loss=unlabeled_loss.median,
                     lr=optimizer.param_groups[0]["lr"],
-                    memory = (torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else .0
-                    )
+                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
+                    if torch.cuda.is_available()
+                    else 0.0,
                 )
-
+            )
 
         total_training_time = time.time() - start_training_time
         total_time_str = str(datetime.timedelta(seconds=total_training_time))
         logger.info(
             "Total training time: {} ({:.4f} s / epoch)".format(
                 total_time_str, total_training_time / (max_epoch)
-            ))
-
-    log_plot_file = os.path.join(output_folder,"{}_trainlog.pdf".format(model.name))
-    logdf = pd.read_csv(os.path.join(output_folder,"{}_trainlog.csv".format(model.name)),header=None, names=["avg. loss", "median loss", "labeled loss", "unlabeled loss", "lr","max memory"])
-    fig = loss_curve(logdf,output_folder)
+            )
+        )
+
+    log_plot_file = os.path.join(output_folder, "{}_trainlog.pdf".format(model.name))
+    logdf = pd.read_csv(
+        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)),
+        header=None,
+        names=[
+            "avg. loss",
+            "median loss",
+            "labeled loss",
+            "unlabeled loss",
+            "lr",
+            "max memory",
+        ],
+    )
+    fig = loss_curve(logdf, output_folder)
     logger.info("saving {}".format(log_plot_file))
     fig.savefig(log_plot_file)
-
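For orientation in the semi-supervised trainer above: `square_rampup` (and its linear variant) return a factor in [0, 1] that grows with the epoch and is passed to the criterion to weight the unlabeled loss term. The restatement below mirrors the function from the hunk; the combined-loss line is an assumption for illustration only, since the actual combination happens inside the criterion and is not shown in this patch.

    import numpy as np

    def square_rampup(current, rampup_length=16):
        # Quadratic ramp from 0 to 1 over ``rampup_length`` epochs, then constant.
        if rampup_length == 0:
            return 1.0
        return float(np.clip((current / float(rampup_length)) ** 2, 0.0, 1.0))

    # With rampup_length=16: epoch 4 -> 0.0625, epoch 8 -> 0.25, epoch >= 16 -> 1.0.
    # Assumed combination (not taken from this patch):
    #   total_loss = labeled_loss + square_rampup(epoch, 16) * unlabeled_loss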
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index e083ec85..2b35528e 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import os 
+import os
 import logging
 import time
 import datetime
@@ -23,7 +23,7 @@ def do_train(
     checkpoint_period,
     device,
     arguments,
-    output_folder
+    output_folder,
 ):
     """ 
     Train model and save to disk.
@@ -55,8 +55,10 @@ def do_train(
     max_epoch = arguments["max_epoch"]
 
     # Log to file
-    with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+") as outfile:
-        
+    with open(
+        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)), "a+"
+    ) as outfile:
+
         model.train().to(device)
         for state in optimizer.state.values():
             for k, v in state.items():
@@ -70,7 +72,7 @@ def do_train(
             losses = SmoothedValue(len(data_loader))
             epoch = epoch + 1
             arguments["epoch"] = epoch
-            
+
             # Epoch time
             start_epoch_time = time.time()
 
@@ -81,9 +83,9 @@ def do_train(
                 masks = None
                 if len(samples) == 4:
                     masks = samples[-1].to(device)
-                
+
                 outputs = model(images)
-                
+
                 loss = criterion(outputs, ground_truths, masks)
                 optimizer.zero_grad()
                 loss.backward()
@@ -100,51 +102,62 @@ def do_train(
 
             epoch_time = time.time() - start_epoch_time
 
-
             eta_seconds = epoch_time * (max_epoch - epoch)
             eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
 
-            outfile.write(("{epoch}, "
-                        "{avg_loss:.6f}, "
-                        "{median_loss:.6f}, "
-                        "{lr:.6f}, "
-                        "{memory:.0f}"
-                        "\n"
-                        ).format(
+            outfile.write(
+                (
+                    "{epoch}, "
+                    "{avg_loss:.6f}, "
+                    "{median_loss:.6f}, "
+                    "{lr:.6f}, "
+                    "{memory:.0f}"
+                    "\n"
+                ).format(
                     eta=eta_string,
                     epoch=epoch,
                     avg_loss=losses.avg,
                     median_loss=losses.median,
                     lr=optimizer.param_groups[0]["lr"],
-                    memory = (torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else .0,
-                    )
-                )  
-            logger.info(("eta: {eta}, " 
-                        "epoch: {epoch}, "
-                        "avg. loss: {avg_loss:.6f}, "
-                        "median loss: {median_loss:.6f}, "
-                        "lr: {lr:.6f}, "
-                        "max mem: {memory:.0f}"
-                        ).format(
+                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
+                    if torch.cuda.is_available()
+                    else 0.0,
+                )
+            )
+            logger.info(
+                (
+                    "eta: {eta}, "
+                    "epoch: {epoch}, "
+                    "avg. loss: {avg_loss:.6f}, "
+                    "median loss: {median_loss:.6f}, "
+                    "lr: {lr:.6f}, "
+                    "max mem: {memory:.0f}"
+                ).format(
                     eta=eta_string,
                     epoch=epoch,
                     avg_loss=losses.avg,
                     median_loss=losses.median,
                     lr=optimizer.param_groups[0]["lr"],
-                    memory = (torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else .0
-                    )
+                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
+                    if torch.cuda.is_available()
+                    else 0.0,
                 )
-
+            )
 
         total_training_time = time.time() - start_training_time
         total_time_str = str(datetime.timedelta(seconds=total_training_time))
         logger.info(
             "Total training time: {} ({:.4f} s / epoch)".format(
                 total_time_str, total_training_time / (max_epoch)
-            ))
-        
-    log_plot_file = os.path.join(output_folder,"{}_trainlog.pdf".format(model.name))
-    logdf = pd.read_csv(os.path.join(output_folder,"{}_trainlog.csv".format(model.name)),header=None, names=["avg. loss", "median loss","lr","max memory"])
-    fig = loss_curve(logdf,output_folder)
+            )
+        )
+
+    log_plot_file = os.path.join(output_folder, "{}_trainlog.pdf".format(model.name))
+    logdf = pd.read_csv(
+        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)),
+        header=None,
+        names=["avg. loss", "median loss", "lr", "max memory"],
+    )
+    fig = loss_curve(logdf, output_folder)
     logger.info("saving {}".format(log_plot_file))
     fig.savefig(log_plot_file)
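A short note on the training-log handling just above: `do_train` writes one CSV row per epoch (epoch, average loss, median loss, learning rate, max memory) and re-reads the file with pandas to plot the loss curve. The snippet below is an editorial usage sketch; the path and model name are hypothetical. Because `names` lists one column fewer than the row has fields, pandas will normally treat the leading epoch field as the index.

    import pandas as pd

    logdf = pd.read_csv(
        "output_folder/driu_trainlog.csv",  # hypothetical output path / model name
        header=None,
        names=["avg. loss", "median loss", "lr", "max memory"],
    )
    print(logdf[["avg. loss", "median loss"]].tail())  # epoch number is the index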
diff --git a/bob/ip/binseg/modeling/backbones/mobilenetv2.py b/bob/ip/binseg/modeling/backbones/mobilenetv2.py
index 5d87f496..2f4b8851 100644
--- a/bob/ip/binseg/modeling/backbones/mobilenetv2.py
+++ b/bob/ip/binseg/modeling/backbones/mobilenetv2.py
@@ -12,7 +12,7 @@ def conv_bn(inp, oup, stride):
     return torch.nn.Sequential(
         torch.nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
         torch.nn.BatchNorm2d(oup),
-        torch.nn.ReLU6(inplace=True)
+        torch.nn.ReLU6(inplace=True),
     )
 
 
@@ -20,7 +20,7 @@ def conv_1x1_bn(inp, oup):
     return torch.nn.Sequential(
         torch.nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
         torch.nn.BatchNorm2d(oup),
-        torch.nn.ReLU6(inplace=True)
+        torch.nn.ReLU6(inplace=True),
     )
 
 
@@ -36,7 +36,9 @@ class InvertedResidual(torch.nn.Module):
         if expand_ratio == 1:
             self.conv = torch.nn.Sequential(
                 # dw
-                torch.nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+                torch.nn.Conv2d(
+                    hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False
+                ),
                 torch.nn.BatchNorm2d(hidden_dim),
                 torch.nn.ReLU6(inplace=True),
                 # pw-linear
@@ -50,7 +52,9 @@ class InvertedResidual(torch.nn.Module):
                 torch.nn.BatchNorm2d(hidden_dim),
                 torch.nn.ReLU6(inplace=True),
                 # dw
-                torch.nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+                torch.nn.Conv2d(
+                    hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False
+                ),
                 torch.nn.BatchNorm2d(hidden_dim),
                 torch.nn.ReLU6(inplace=True),
                 # pw-linear
@@ -66,7 +70,14 @@ class InvertedResidual(torch.nn.Module):
 
 
 class MobileNetV2(torch.nn.Module):
-    def __init__(self, n_class=1000, input_size=224, width_mult=1., return_features = None, m2u=True):
+    def __init__(
+        self,
+        n_class=1000,
+        input_size=224,
+        width_mult=1.0,
+        return_features=None,
+        m2u=True,
+    ):
         super(MobileNetV2, self).__init__()
         self.return_features = return_features
         self.m2u = m2u
@@ -80,34 +91,38 @@ class MobileNetV2(torch.nn.Module):
             [6, 32, 3, 2],
             [6, 64, 4, 2],
             [6, 96, 3, 1],
-            #[6, 160, 3, 2],
-            #[6, 320, 1, 1],
+            # [6, 160, 3, 2],
+            # [6, 320, 1, 1],
         ]
 
         # building first layer
         assert input_size % 32 == 0
         input_channel = int(input_channel * width_mult)
-        #self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
+        # self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
         self.features = [conv_bn(3, input_channel, 2)]
         # building inverted residual blocks
         for t, c, n, s in interverted_residual_setting:
             output_channel = int(c * width_mult)
             for i in range(n):
                 if i == 0:
-                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
+                    self.features.append(
+                        block(input_channel, output_channel, s, expand_ratio=t)
+                    )
                 else:
-                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
+                    self.features.append(
+                        block(input_channel, output_channel, 1, expand_ratio=t)
+                    )
                 input_channel = output_channel
         # building last several layers
-        #self.features.append(conv_1x1_bn(input_channel, self.last_channel))
+        # self.features.append(conv_1x1_bn(input_channel, self.last_channel))
         # make it torch.nn.Sequential
         self.features = torch.nn.Sequential(*self.features)
 
         # building classifier
-        #self.classifier = torch.nn.Sequential(
+        # self.classifier = torch.nn.Sequential(
         #    torch.nn.Dropout(0.2),
         #    torch.nn.Linear(self.last_channel, n_class),
-        #)
+        # )
 
         self._initialize_weights()
 
@@ -117,7 +132,7 @@ class MobileNetV2(torch.nn.Module):
         outputs.append(x.shape[2:4])
         if self.m2u:
             outputs.append(x)
-        for index,m in enumerate(self.features):
+        for index, m in enumerate(self.features):
             x = m(x)
             # extract layers
             if index in self.return_features:
@@ -128,7 +143,7 @@ class MobileNetV2(torch.nn.Module):
         for m in self.modules():
             if isinstance(m, torch.nn.Conv2d):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
+                m.weight.data.normal_(0, math.sqrt(2.0 / n))
                 if m.bias is not None:
                     m.bias.data.zero_()
             elif isinstance(m, torch.nn.BatchNorm2d):
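To make the reformatted MobileNetV2 constructor easier to follow: each row of `interverted_residual_setting` is `(t, c, n, s)`, i.e. expand ratio, output channels, number of repeats and stride, and `width_mult` scales the channel counts. The sketch below is editorial (`expand_setting` is an invented name) and mirrors the nested loop from the hunk above without building any modules.

    def expand_setting(setting, input_channel, width_mult=1.0):
        # Yields (in_channels, out_channels, stride, expand_ratio) per block.
        blocks = []
        for t, c, n, s in setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                stride = s if i == 0 else 1  # only the first repeat may downsample
                blocks.append((input_channel, output_channel, stride, t))
                input_channel = output_channel
        return blocks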
diff --git a/bob/ip/binseg/modeling/backbones/resnet.py b/bob/ip/binseg/modeling/backbones/resnet.py
index 285a5a15..f26c5d57 100644
--- a/bob/ip/binseg/modeling/backbones/resnet.py
+++ b/bob/ip/binseg/modeling/backbones/resnet.py
@@ -18,20 +18,13 @@ model_urls = {
 def _conv3x3(in_planes, out_planes, stride=1):
     """3x3 convolution with padding"""
     return nn.Conv2d(
-        in_planes,
-        out_planes,
-        kernel_size=3,
-        stride=stride,
-        padding=1,
-        bias=False,
+        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False,
     )
 
 
 def _conv1x1(in_planes, out_planes, stride=1):
     """1x1 convolution"""
-    return nn.Conv2d(
-        in_planes, out_planes, kernel_size=1, stride=stride, bias=False
-    )
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
 
 
 class _BasicBlock(nn.Module):
@@ -105,9 +98,7 @@ class _Bottleneck(nn.Module):
 
 
 class ResNet(nn.Module):
-    def __init__(
-        self, block, layers, return_features, zero_init_residual=False
-    ):
+    def __init__(self, block, layers, return_features, zero_init_residual=False):
         """
         Generic ResNet network with layer return.
         Attributes
@@ -118,9 +109,7 @@ class ResNet(nn.Module):
         super(ResNet, self).__init__()
         self.inplanes = 64
         self.return_features = return_features
-        self.conv1 = nn.Conv2d(
-            3, 64, kernel_size=7, stride=2, padding=3, bias=False
-        )
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
         self.bn1 = nn.BatchNorm2d(64)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -142,9 +131,7 @@ class ResNet(nn.Module):
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(
-                    m.weight, mode="fan_out", nonlinearity="relu"
-                )
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
             elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
@@ -229,9 +216,7 @@ def shaperesnet50(pretrained=False, **kwargs):
     if pretrained:
         model.load_state_dict(
             model_zoo.load_url(
-                model_urls[
-                    "resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN"
-                ]
+                model_urls["resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN"]
             )
         )
     return model
diff --git a/bob/ip/binseg/modeling/backbones/vgg.py b/bob/ip/binseg/modeling/backbones/vgg.py
index 7736a4c1..390f6ef4 100644
--- a/bob/ip/binseg/modeling/backbones/vgg.py
+++ b/bob/ip/binseg/modeling/backbones/vgg.py
@@ -8,19 +8,18 @@ import torch.utils.model_zoo as model_zoo
 
 
 model_urls = {
-    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
-    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
-    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
-    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
-    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
-    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
-    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
-    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
+    "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
+    "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
+    "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
+    "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
+    "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
+    "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
+    "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
+    "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
 }
 
 
 class VGG(nn.Module):
-
     def __init__(self, features, return_features, init_weights=True):
         super(VGG, self).__init__()
         self.features = features
@@ -32,7 +31,7 @@ class VGG(nn.Module):
         outputs = []
         # hw of input, needed for DRIU and HED
         outputs.append(x.shape[2:4])
-        for index,m in enumerate(self.features):
+        for index, m in enumerate(self.features):
             x = m(x)
             # extract layers
             if index in self.return_features:
@@ -42,7 +41,7 @@ class VGG(nn.Module):
     def _initialize_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)
             elif isinstance(m, nn.BatchNorm2d):
@@ -57,7 +56,7 @@ def _make_layers(cfg, batch_norm=False):
     layers = []
     in_channels = 3
     for v in cfg:
-        if v == 'M':
+        if v == "M":
             layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
         else:
             conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
@@ -70,10 +69,51 @@ def _make_layers(cfg, batch_norm=False):
 
 
 _cfg = {
-    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
-    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
-    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
-    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
+    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
+    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
+    "D": [
+        64,
+        64,
+        "M",
+        128,
+        128,
+        "M",
+        256,
+        256,
+        256,
+        "M",
+        512,
+        512,
+        512,
+        "M",
+        512,
+        512,
+        512,
+        "M",
+    ],
+    "E": [
+        64,
+        64,
+        "M",
+        128,
+        128,
+        "M",
+        256,
+        256,
+        256,
+        256,
+        "M",
+        512,
+        512,
+        512,
+        512,
+        "M",
+        512,
+        512,
+        512,
+        512,
+        "M",
+    ],
 }
 
 
@@ -83,10 +123,10 @@ def vgg11(pretrained=False, **kwargs):
         pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(_make_layers(_cfg['A']), **kwargs)
+        kwargs["init_weights"] = False
+    model = VGG(_make_layers(_cfg["A"]), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
+        model.load_state_dict(model_zoo.load_url(model_urls["vgg11"]))
     return model
 
 
@@ -96,10 +136,10 @@ def vgg11_bn(pretrained=False, **kwargs):
         pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(_make_layers(_cfg['A'], batch_norm=True), **kwargs)
+        kwargs["init_weights"] = False
+    model = VGG(_make_layers(_cfg["A"], batch_norm=True), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
+        model.load_state_dict(model_zoo.load_url(model_urls["vgg11_bn"]))
     return model
 
 
@@ -109,10 +149,10 @@ def vgg13(pretrained=False, **kwargs):
         pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(_make_layers(_cfg['B']), **kwargs)
+        kwargs["init_weights"] = False
+    model = VGG(_make_layers(_cfg["B"]), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
+        model.load_state_dict(model_zoo.load_url(model_urls["vgg13"]))
     return model
 
 
@@ -122,10 +162,10 @@ def vgg13_bn(pretrained=False, **kwargs):
         pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(_make_layers(_cfg['B'], batch_norm=True), **kwargs)
+        kwargs["init_weights"] = False
+    model = VGG(_make_layers(_cfg["B"], batch_norm=True), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
+        model.load_state_dict(model_zoo.load_url(model_urls["vgg13_bn"]))
     return model
 
 
@@ -135,10 +175,10 @@ def vgg16(pretrained=False, **kwargs):
         pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(_make_layers(_cfg['D']), **kwargs)
+        kwargs["init_weights"] = False
+    model = VGG(_make_layers(_cfg["D"]), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg16']),strict=False)
+        model.load_state_dict(model_zoo.load_url(model_urls["vgg16"]), strict=False)
     return model
 
 
@@ -148,10 +188,10 @@ def vgg16_bn(pretrained=False, **kwargs):
         pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(_make_layers(_cfg['D'], batch_norm=True), **kwargs)
+        kwargs["init_weights"] = False
+    model = VGG(_make_layers(_cfg["D"], batch_norm=True), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
+        model.load_state_dict(model_zoo.load_url(model_urls["vgg16_bn"]))
     return model
 
 
@@ -161,10 +201,10 @@ def vgg19(pretrained=False, **kwargs):
         pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(_make_layers(_cfg['E']), **kwargs)
+        kwargs["init_weights"] = False
+    model = VGG(_make_layers(_cfg["E"]), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
+        model.load_state_dict(model_zoo.load_url(model_urls["vgg19"]))
     return model
 
 
@@ -174,8 +214,8 @@ def vgg19_bn(pretrained=False, **kwargs):
         pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     if pretrained:
-        kwargs['init_weights'] = False
-    model = VGG(_make_layers(_cfg['E'], batch_norm=True), **kwargs)
+        kwargs["init_weights"] = False
+    model = VGG(_make_layers(_cfg["E"], batch_norm=True), **kwargs)
     if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
+        model.load_state_dict(model_zoo.load_url(model_urls["vgg19_bn"]))
     return model
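For reference while reading the expanded `_cfg` tables above: `_make_layers` turns every `"M"` entry into a 2x2 max-pooling layer and every integer into a 3x3 convolution (optionally followed by batch normalisation) plus ReLU. The sketch below is editorial; `make_layers_sketch` is an invented name and the Sequential wrapper is an assumption made here for self-containment.

    import torch.nn as nn

    def make_layers_sketch(cfg, batch_norm=False):
        layers, in_channels = [], 3
        for v in cfg:
            if v == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    # Configuration "A" above therefore yields 8 conv+ReLU blocks and 5 pooling layers.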
diff --git a/bob/ip/binseg/modeling/driu.py b/bob/ip/binseg/modeling/driu.py
index 06454a17..5b4425c2 100644
--- a/bob/ip/binseg/modeling/driu.py
+++ b/bob/ip/binseg/modeling/driu.py
@@ -43,12 +43,7 @@ class DRIU(torch.nn.Module):
 
     def __init__(self, in_channels_list=None):
         super(DRIU, self).__init__()
-        (
-            in_conv_1_2_16,
-            in_upsample2,
-            in_upsample_4,
-            in_upsample_8,
-        ) = in_channels_list
+        (in_conv_1_2_16, in_upsample2, in_upsample_4, in_upsample_8,) = in_channels_list
 
         self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
         # Upsample layers
diff --git a/bob/ip/binseg/modeling/driubn.py b/bob/ip/binseg/modeling/driubn.py
index f9145011..245fdf17 100644
--- a/bob/ip/binseg/modeling/driubn.py
+++ b/bob/ip/binseg/modeling/driubn.py
@@ -5,24 +5,31 @@ import torch
 import torch.nn
 from collections import OrderedDict
 from bob.ip.binseg.modeling.backbones.vgg import vgg16_bn
-from bob.ip.binseg.modeling.make_layers import conv_with_kaiming_uniform,convtrans_with_kaiming_uniform, UpsampleCropBlock
+from bob.ip.binseg.modeling.make_layers import (
+    conv_with_kaiming_uniform,
+    convtrans_with_kaiming_uniform,
+    UpsampleCropBlock,
+)
+
 
 class ConcatFuseBlock(torch.nn.Module):
     """
     Takes in four feature maps with 16 channels each, concatenates them
     and applies a 1x1 convolution with 1 output channel.
     """
+
     def __init__(self):
         super().__init__()
         self.conv = torch.nn.Sequential(
-            conv_with_kaiming_uniform(4*16,1,1,1,0)
-            ,torch.nn.BatchNorm2d(1)
+            conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0), torch.nn.BatchNorm2d(1)
         )
-    def forward(self,x1,x2,x3,x4):
-        x_cat = torch.cat([x1,x2,x3,x4],dim=1)
+
+    def forward(self, x1, x2, x3, x4):
+        x_cat = torch.cat([x1, x2, x3, x4], dim=1)
         x = self.conv(x_cat)
         return x
 
+
 class DRIU(torch.nn.Module):
     """
     DRIU head module
@@ -34,6 +41,7 @@ class DRIU(torch.nn.Module):
     in_channels_list : list
         number of channels for each feature map that is returned from backbone
     """
+
     def __init__(self, in_channels_list=None):
         super(DRIU, self).__init__()
         in_conv_1_2_16, in_upsample2, in_upsample_4, in_upsample_8 = in_channels_list
@@ -47,7 +55,7 @@ class DRIU(torch.nn.Module):
         # Concat and Fuse
         self.concatfuse = ConcatFuseBlock()
 
-    def forward(self,x):
+    def forward(self, x):
         """
         Parameters
         ----------
@@ -62,12 +70,13 @@ class DRIU(torch.nn.Module):
         """
         hw = x[0]
         conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16
-        upsample2 = self.upsample2(x[2], hw) # side-multi2-up
-        upsample4 = self.upsample4(x[3], hw) # side-multi3-up
-        upsample8 = self.upsample8(x[4], hw) # side-multi4-up
+        upsample2 = self.upsample2(x[2], hw)  # side-multi2-up
+        upsample4 = self.upsample4(x[3], hw)  # side-multi3-up
+        upsample8 = self.upsample8(x[4], hw)  # side-multi4-up
         out = self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)
         return out
 
+
 def build_driu():
     """
     Adds backbone and head together
@@ -78,9 +87,11 @@ def build_driu():
     module : :py:class:`torch.nn.Module`
 
     """
-    backbone = vgg16_bn(pretrained=False, return_features = [5, 12, 19, 29])
+    backbone = vgg16_bn(pretrained=False, return_features=[5, 12, 19, 29])
     driu_head = DRIU([64, 128, 256, 512])
 
-    model = torch.nn.Sequential(OrderedDict([("backbone", backbone), ("head", driu_head)]))
+    model = torch.nn.Sequential(
+        OrderedDict([("backbone", backbone), ("head", driu_head)])
+    )
     model.name = "DRIUBN"
     return model
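
For orientation, the backbone/head composition pattern that build_driu() uses, shown with stand-in convolutions; the real backbone returns a list of feature maps, so this is an illustrative sketch only:

import torch
from collections import OrderedDict

backbone = torch.nn.Conv2d(3, 16, 3, padding=1)    # stand-in for vgg16_bn(...)
head = torch.nn.Conv2d(16, 1, 1)                   # stand-in for the DRIU head
model = torch.nn.Sequential(OrderedDict([("backbone", backbone), ("head", head)]))
model.name = "toy"                                 # plain attribute, as in build_driu()
out = model(torch.rand(1, 3, 64, 64))              # data flows backbone -> head
print(out.shape)                                   # torch.Size([1, 1, 64, 64])
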
diff --git a/bob/ip/binseg/modeling/driuod.py b/bob/ip/binseg/modeling/driuod.py
index ab543e9e..25e5b82d 100644
--- a/bob/ip/binseg/modeling/driuod.py
+++ b/bob/ip/binseg/modeling/driuod.py
@@ -5,22 +5,29 @@ import torch
 import torch.nn
 from collections import OrderedDict
 from bob.ip.binseg.modeling.backbones.vgg import vgg16
-from bob.ip.binseg.modeling.make_layers import conv_with_kaiming_uniform,convtrans_with_kaiming_uniform, UpsampleCropBlock
+from bob.ip.binseg.modeling.make_layers import (
+    conv_with_kaiming_uniform,
+    convtrans_with_kaiming_uniform,
+    UpsampleCropBlock,
+)
+
 
 class ConcatFuseBlock(torch.nn.Module):
     """
     Takes in four feature maps with 16 channels each, concatenates them
     and applies a 1x1 convolution with 1 output channel.
     """
+
     def __init__(self):
         super().__init__()
-        self.conv = conv_with_kaiming_uniform(4*16,1,1,1,0)
+        self.conv = conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0)
 
-    def forward(self,x1,x2,x3,x4):
-        x_cat = torch.cat([x1,x2,x3,x4],dim=1)
+    def forward(self, x1, x2, x3, x4):
+        x_cat = torch.cat([x1, x2, x3, x4], dim=1)
         x = self.conv(x_cat)
         return x
 
+
 class DRIUOD(torch.nn.Module):
     """
     DRIU head module
@@ -30,6 +37,7 @@ class DRIUOD(torch.nn.Module):
     in_channels_list : list
         number of channels for each feature map that is returned from backbone
     """
+
     def __init__(self, in_channels_list=None):
         super(DRIUOD, self).__init__()
         in_upsample2, in_upsample_4, in_upsample_8, in_upsample_16 = in_channels_list
@@ -40,11 +48,10 @@ class DRIUOD(torch.nn.Module):
         self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0)
         self.upsample16 = UpsampleCropBlock(in_upsample_16, 16, 32, 16, 0)
 
-
         # Concat and Fuse
         self.concatfuse = ConcatFuseBlock()
 
-    def forward(self,x):
+    def forward(self, x):
         """
         Parameters
         ----------
@@ -59,12 +66,13 @@ class DRIUOD(torch.nn.Module):
         """
         hw = x[0]
         upsample2 = self.upsample2(x[1], hw)  # side-multi2-up
-        upsample4 = self.upsample4(x[2], hw)   # side-multi3-up
-        upsample8 = self.upsample8(x[3], hw)   # side-multi4-up
+        upsample4 = self.upsample4(x[2], hw)  # side-multi3-up
+        upsample8 = self.upsample8(x[3], hw)  # side-multi4-up
         upsample16 = self.upsample16(x[4], hw)  # side-multi5-up
-        out = self.concatfuse(upsample2, upsample4, upsample8,upsample16)
+        out = self.concatfuse(upsample2, upsample4, upsample8, upsample16)
         return out
 
+
 def build_driuod():
     """
     Adds backbone and head together
@@ -74,9 +82,11 @@ def build_driuod():
     module : :py:class:`torch.nn.Module`
 
     """
-    backbone = vgg16(pretrained=False, return_features = [8, 14, 22,29])
-    driu_head = DRIUOD([128, 256, 512,512])
+    backbone = vgg16(pretrained=False, return_features=[8, 14, 22, 29])
+    driu_head = DRIUOD([128, 256, 512, 512])
 
-    model = torch.nn.Sequential(OrderedDict([("backbone", backbone), ("head", driu_head)]))
+    model = torch.nn.Sequential(
+        OrderedDict([("backbone", backbone), ("head", driu_head)])
+    )
     model.name = "DRIUOD"
     return model
diff --git a/bob/ip/binseg/modeling/driupix.py b/bob/ip/binseg/modeling/driupix.py
index e38768ea..3ad10aa7 100644
--- a/bob/ip/binseg/modeling/driupix.py
+++ b/bob/ip/binseg/modeling/driupix.py
@@ -5,22 +5,29 @@ import torch
 import torch.nn
 from collections import OrderedDict
 from bob.ip.binseg.modeling.backbones.vgg import vgg16
-from bob.ip.binseg.modeling.make_layers import conv_with_kaiming_uniform,convtrans_with_kaiming_uniform, UpsampleCropBlock
+from bob.ip.binseg.modeling.make_layers import (
+    conv_with_kaiming_uniform,
+    convtrans_with_kaiming_uniform,
+    UpsampleCropBlock,
+)
+
 
 class ConcatFuseBlock(torch.nn.Module):
     """
     Takes in four feature maps with 16 channels each, concatenates them
     and applies a 1x1 convolution with 1 output channel.
     """
+
     def __init__(self):
         super().__init__()
-        self.conv = conv_with_kaiming_uniform(4*16,1,1,1,0)
+        self.conv = conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0)
 
-    def forward(self,x1,x2,x3,x4):
-        x_cat = torch.cat([x1,x2,x3,x4],dim=1)
+    def forward(self, x1, x2, x3, x4):
+        x_cat = torch.cat([x1, x2, x3, x4], dim=1)
         x = self.conv(x_cat)
         return x
 
+
 class DRIUPIX(torch.nn.Module):
     """
     DRIUPIX head module. DRIU with pixelshuffle instead of ConvTrans2D
@@ -30,6 +37,7 @@ class DRIUPIX(torch.nn.Module):
     in_channels_list : list
         number of channels for each feature map that is returned from backbone
     """
+
     def __init__(self, in_channels_list=None):
         super(DRIUPIX, self).__init__()
         in_conv_1_2_16, in_upsample2, in_upsample_4, in_upsample_8 = in_channels_list
@@ -37,13 +45,17 @@ class DRIUPIX(torch.nn.Module):
         self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
         # Upsample layers
         self.upsample2 = UpsampleCropBlock(in_upsample2, 16, 4, 2, 0, pixelshuffle=True)
-        self.upsample4 = UpsampleCropBlock(in_upsample_4, 16, 8, 4, 0, pixelshuffle=True)
-        self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0, pixelshuffle=True)
+        self.upsample4 = UpsampleCropBlock(
+            in_upsample_4, 16, 8, 4, 0, pixelshuffle=True
+        )
+        self.upsample8 = UpsampleCropBlock(
+            in_upsample_8, 16, 16, 8, 0, pixelshuffle=True
+        )
 
         # Concat and Fuse
         self.concatfuse = ConcatFuseBlock()
 
-    def forward(self,x):
+    def forward(self, x):
         """
         Parameters
         ----------
@@ -58,12 +70,13 @@ class DRIUPIX(torch.nn.Module):
         """
         hw = x[0]
         conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16
-        upsample2 = self.upsample2(x[2], hw) # side-multi2-up
-        upsample4 = self.upsample4(x[3], hw) # side-multi3-up
-        upsample8 = self.upsample8(x[4], hw) # side-multi4-up
+        upsample2 = self.upsample2(x[2], hw)  # side-multi2-up
+        upsample4 = self.upsample4(x[3], hw)  # side-multi3-up
+        upsample8 = self.upsample8(x[4], hw)  # side-multi4-up
         out = self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)
         return out
 
+
 def build_driupix():
     """
     Adds backbone and head together
@@ -73,9 +86,11 @@ def build_driupix():
     module : :py:class:`torch.nn.Module`
 
     """
-    backbone = vgg16(pretrained=False, return_features = [3, 8, 14, 22])
+    backbone = vgg16(pretrained=False, return_features=[3, 8, 14, 22])
     driu_head = DRIUPIX([64, 128, 256, 512])
 
-    model = torch.nn.Sequential(OrderedDict([("backbone", backbone), ("head", driu_head)]))
+    model = torch.nn.Sequential(
+        OrderedDict([("backbone", backbone), ("head", driu_head)])
+    )
     model.name = "DRIUPIX"
     return model
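
The pixelshuffle=True branch above replaces the transposed convolution with PixelShuffle-based upsampling; a rough shape comparison using plain torch layers (sizes chosen arbitrarily, not the UpsampleCropBlock code):

import torch

x = torch.rand(1, 16, 10, 10)
# Transposed-convolution route (as in the default DRIU upsample4: kernel 8, stride 4).
deconv = torch.nn.ConvTranspose2d(16, 16, kernel_size=8, stride=4, padding=0)
# PixelShuffle route: 1x1 conv to r*r times the channels, then rearrange.
shuffle = torch.nn.Sequential(torch.nn.Conv2d(16, 16 * 4 * 4, 1), torch.nn.PixelShuffle(4))
print(deconv(x).shape)   # torch.Size([1, 16, 44, 44]) -- slightly larger, later center-cropped
print(shuffle(x).shape)  # torch.Size([1, 16, 40, 40])
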
diff --git a/bob/ip/binseg/modeling/hed.py b/bob/ip/binseg/modeling/hed.py
index 9be7fc86..02c2b957 100644
--- a/bob/ip/binseg/modeling/hed.py
+++ b/bob/ip/binseg/modeling/hed.py
@@ -5,22 +5,29 @@ import torch
 import torch.nn
 from collections import OrderedDict
 from bob.ip.binseg.modeling.backbones.vgg import vgg16
-from bob.ip.binseg.modeling.make_layers import conv_with_kaiming_uniform, convtrans_with_kaiming_uniform, UpsampleCropBlock
+from bob.ip.binseg.modeling.make_layers import (
+    conv_with_kaiming_uniform,
+    convtrans_with_kaiming_uniform,
+    UpsampleCropBlock,
+)
+
 
 class ConcatFuseBlock(torch.nn.Module):
     """
     Takes in five feature maps with one channel each, concatenates them
     and applies a 1x1 convolution with 1 output channel.
     """
+
     def __init__(self):
         super().__init__()
-        self.conv = conv_with_kaiming_uniform(5,1,1,1,0)
+        self.conv = conv_with_kaiming_uniform(5, 1, 1, 1, 0)
 
-    def forward(self,x1,x2,x3,x4,x5):
-        x_cat = torch.cat([x1,x2,x3,x4,x5],dim=1)
+    def forward(self, x1, x2, x3, x4, x5):
+        x_cat = torch.cat([x1, x2, x3, x4, x5], dim=1)
         x = self.conv(x_cat)
         return x
 
+
 class HED(torch.nn.Module):
     """
     HED head module
@@ -30,20 +37,27 @@ class HED(torch.nn.Module):
     in_channels_list : list
         number of channels for each feature map that is returned from backbone
     """
+
     def __init__(self, in_channels_list=None):
         super(HED, self).__init__()
-        in_conv_1_2_16, in_upsample2, in_upsample_4, in_upsample_8, in_upsample_16 = in_channels_list
+        (
+            in_conv_1_2_16,
+            in_upsample2,
+            in_upsample_4,
+            in_upsample_8,
+            in_upsample_16,
+        ) = in_channels_list
 
-        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16,1,3,1,1)
+        self.conv1_2_16 = torch.nn.Conv2d(in_conv_1_2_16, 1, 3, 1, 1)
         # Upsample
-        self.upsample2 = UpsampleCropBlock(in_upsample2,1,4,2,0)
-        self.upsample4 = UpsampleCropBlock(in_upsample_4,1,8,4,0)
-        self.upsample8 = UpsampleCropBlock(in_upsample_8,1,16,8,0)
-        self.upsample16 = UpsampleCropBlock(in_upsample_16,1,32,16,0)
+        self.upsample2 = UpsampleCropBlock(in_upsample2, 1, 4, 2, 0)
+        self.upsample4 = UpsampleCropBlock(in_upsample_4, 1, 8, 4, 0)
+        self.upsample8 = UpsampleCropBlock(in_upsample_8, 1, 16, 8, 0)
+        self.upsample16 = UpsampleCropBlock(in_upsample_16, 1, 32, 16, 0)
         # Concat and Fuse
         self.concatfuse = ConcatFuseBlock()
 
-    def forward(self,x):
+    def forward(self, x):
         """
         Parameters
         ----------
@@ -58,15 +72,18 @@ class HED(torch.nn.Module):
         """
         hw = x[0]
         conv1_2_16 = self.conv1_2_16(x[1])
-        upsample2 = self.upsample2(x[2],hw)
-        upsample4 = self.upsample4(x[3],hw)
-        upsample8 = self.upsample8(x[4],hw)
-        upsample16 = self.upsample16(x[5],hw)
-        concatfuse = self.concatfuse(conv1_2_16,upsample2,upsample4,upsample8,upsample16)
+        upsample2 = self.upsample2(x[2], hw)
+        upsample4 = self.upsample4(x[3], hw)
+        upsample8 = self.upsample8(x[4], hw)
+        upsample16 = self.upsample16(x[5], hw)
+        concatfuse = self.concatfuse(
+            conv1_2_16, upsample2, upsample4, upsample8, upsample16
+        )
 
-        out = [upsample2,upsample4,upsample8,upsample16,concatfuse]
+        out = [upsample2, upsample4, upsample8, upsample16, concatfuse]
         return out
 
+
 def build_hed():
     """
     Adds backbone and head together
@@ -75,9 +92,11 @@ def build_hed():
     -------
     module : :py:class:`torch.nn.Module`
     """
-    backbone = vgg16(pretrained=False, return_features = [3, 8, 14, 22, 29])
+    backbone = vgg16(pretrained=False, return_features=[3, 8, 14, 22, 29])
     hed_head = HED([64, 128, 256, 512, 512])
 
-    model = torch.nn.Sequential(OrderedDict([("backbone", backbone), ("head", hed_head)]))
+    model = torch.nn.Sequential(
+        OrderedDict([("backbone", backbone), ("head", hed_head)])
+    )
     model.name = "HED"
     return model
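
HED.forward() returns every side output plus the fused map; an illustrative deep-supervision loss (not the repository's loss classes) simply sums one BCE term per entry of that list:

import torch

side_outputs = [torch.randn(2, 1, 64, 64, requires_grad=True) for _ in range(5)]  # stand-ins for `out`
target = torch.randint(0, 2, (2, 1, 64, 64)).float()
bce = torch.nn.BCEWithLogitsLoss()
loss = sum(bce(o, target) for o in side_outputs)  # one term per side output + fused map
loss.backward()                                   # gradients reach every branch
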
diff --git a/bob/ip/binseg/modeling/losses.py b/bob/ip/binseg/modeling/losses.py
index 93235d03..e2aa27b8 100644
--- a/bob/ip/binseg/modeling/losses.py
+++ b/bob/ip/binseg/modeling/losses.py
@@ -32,9 +32,7 @@ class WeightedBCELogitsLoss(_Loss):
         reduction="mean",
         pos_weight=None,
     ):
-        super(WeightedBCELogitsLoss, self).__init__(
-            size_average, reduce, reduction
-        )
+        super(WeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction)
         self.register_buffer("weight", weight)
         self.register_buffer("pos_weight", pos_weight)
 
@@ -56,9 +54,7 @@ class WeightedBCELogitsLoss(_Loss):
             torch.sum(target, dim=[1, 2, 3]).float().reshape(n, 1)
         )  # torch.Size([n, 1])
         if hasattr(masks, "dtype"):
-            num_mask_neg = c * h * w - torch.sum(
-                masks, dim=[1, 2, 3]
-            ).float().reshape(
+            num_mask_neg = c * h * w - torch.sum(masks, dim=[1, 2, 3]).float().reshape(
                 n, 1
             )  # torch.Size([n, 1])
             num_neg = c * h * w - num_pos - num_mask_neg
@@ -97,9 +93,7 @@ class SoftJaccardBCELogitsLoss(_Loss):
         reduction="mean",
         pos_weight=None,
     ):
-        super(SoftJaccardBCELogitsLoss, self).__init__(
-            size_average, reduce, reduction
-        )
+        super(SoftJaccardBCELogitsLoss, self).__init__(size_average, reduce, reduction)
         self.alpha = alpha
 
     @weak_script_method
@@ -145,9 +139,7 @@ class HEDWeightedBCELogitsLoss(_Loss):
         reduction="mean",
         pos_weight=None,
     ):
-        super(HEDWeightedBCELogitsLoss, self).__init__(
-            size_average, reduce, reduction
-        )
+        super(HEDWeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction)
         self.register_buffer("weight", weight)
         self.register_buffer("pos_weight", pos_weight)
 
@@ -185,9 +177,7 @@ class HEDWeightedBCELogitsLoss(_Loss):
             numnegnumtotal = torch.ones_like(target) * (
                 num_neg / (num_pos + num_neg)
             ).unsqueeze(1).unsqueeze(2)
-            weight = torch.where(
-                (target <= 0.5), numposnumtotal, numnegnumtotal
-            )
+            weight = torch.where((target <= 0.5), numposnumtotal, numnegnumtotal)
             loss = torch.nn.functional.binary_cross_entropy_with_logits(
                 input, target, weight=weight, reduction=self.reduction
             )
@@ -278,9 +268,7 @@ class MixJacLoss(_Loss):
         self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()
 
     @weak_script_method
-    def forward(
-        self, input, target, unlabeled_input, unlabeled_traget, ramp_up_factor
-    ):
+    def forward(self, input, target, unlabeled_input, unlabeled_traget, ramp_up_factor):
         """
         Parameters
         ----------
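
The class-balancing rule visible in HEDWeightedBCELogitsLoss gives the rarer class the larger weight; a worked miniature with made-up pixel counts:

import torch

target = torch.tensor([[0.0, 0.0, 0.0, 1.0]])     # 3 negatives, 1 positive
num_pos = target.sum()                            # 1
num_neg = target.numel() - num_pos                # 3
w_for_neg = num_pos / (num_pos + num_neg)         # 0.25 -> weight assigned to negatives
w_for_pos = num_neg / (num_pos + num_neg)         # 0.75 -> weight assigned to positives
weight = torch.where(target <= 0.5, w_for_neg, w_for_pos)
print(weight)                                     # tensor([[0.2500, 0.2500, 0.2500, 0.7500]])
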
diff --git a/bob/ip/binseg/modeling/m2u.py b/bob/ip/binseg/modeling/m2u.py
index fa34c579..25bc0515 100644
--- a/bob/ip/binseg/modeling/m2u.py
+++ b/bob/ip/binseg/modeling/m2u.py
@@ -8,35 +8,46 @@ import torch
 import torch.nn
 from bob.ip.binseg.modeling.backbones.mobilenetv2 import MobileNetV2, InvertedResidual
 
+
 class DecoderBlock(torch.nn.Module):
     """
     Decoder block: upsample and concatenate with features maps from the encoder part
     """
-    def __init__(self,up_in_c,x_in_c,upsamplemode='bilinear',expand_ratio=0.15):
-        super().__init__()
-        self.upsample = torch.nn.Upsample(scale_factor=2,mode=upsamplemode,align_corners=False) # H, W -> 2H, 2W
-        self.ir1 = InvertedResidual(up_in_c+x_in_c,(x_in_c + up_in_c) // 2,stride=1,expand_ratio=expand_ratio)
 
-    def forward(self,up_in,x_in):
+    def __init__(self, up_in_c, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
+        super().__init__()
+        self.upsample = torch.nn.Upsample(
+            scale_factor=2, mode=upsamplemode, align_corners=False
+        )  # H, W -> 2H, 2W
+        self.ir1 = InvertedResidual(
+            up_in_c + x_in_c,
+            (x_in_c + up_in_c) // 2,
+            stride=1,
+            expand_ratio=expand_ratio,
+        )
+
+    def forward(self, up_in, x_in):
         up_out = self.upsample(up_in)
-        cat_x = torch.cat([up_out, x_in] , dim=1)
+        cat_x = torch.cat([up_out, x_in], dim=1)
         x = self.ir1(cat_x)
         return x
 
+
 class LastDecoderBlock(torch.nn.Module):
-    def __init__(self,x_in_c,upsamplemode='bilinear',expand_ratio=0.15):
+    def __init__(self, x_in_c, upsamplemode="bilinear", expand_ratio=0.15):
         super().__init__()
-        self.upsample = torch.nn.Upsample(scale_factor=2,mode=upsamplemode,align_corners=False) # H, W -> 2H, 2W
-        self.ir1 = InvertedResidual(x_in_c,1,stride=1,expand_ratio=expand_ratio)
+        self.upsample = torch.nn.Upsample(
+            scale_factor=2, mode=upsamplemode, align_corners=False
+        )  # H, W -> 2H, 2W
+        self.ir1 = InvertedResidual(x_in_c, 1, stride=1, expand_ratio=expand_ratio)
 
-    def forward(self,up_in,x_in):
+    def forward(self, up_in, x_in):
         up_out = self.upsample(up_in)
-        cat_x = torch.cat([up_out, x_in] , dim=1)
+        cat_x = torch.cat([up_out, x_in], dim=1)
         x = self.ir1(cat_x)
         return x
 
 
-
 class M2U(torch.nn.Module):
     """
     M2U-Net head module
@@ -46,14 +57,17 @@ class M2U(torch.nn.Module):
     in_channels_list : list
         number of channels for each feature map that is returned from backbone
     """
-    def __init__(self, in_channels_list=None,upsamplemode='bilinear',expand_ratio=0.15):
+
+    def __init__(
+        self, in_channels_list=None, upsamplemode="bilinear", expand_ratio=0.15
+    ):
         super(M2U, self).__init__()
 
         # Decoder
-        self.decode4 = DecoderBlock(96,32,upsamplemode,expand_ratio)
-        self.decode3 = DecoderBlock(64,24,upsamplemode,expand_ratio)
-        self.decode2 = DecoderBlock(44,16,upsamplemode,expand_ratio)
-        self.decode1 = LastDecoderBlock(33,upsamplemode,expand_ratio)
+        self.decode4 = DecoderBlock(96, 32, upsamplemode, expand_ratio)
+        self.decode3 = DecoderBlock(64, 24, upsamplemode, expand_ratio)
+        self.decode2 = DecoderBlock(44, 16, upsamplemode, expand_ratio)
+        self.decode1 = LastDecoderBlock(33, upsamplemode, expand_ratio)
 
         # initialize weights
         self._initialize_weights()
@@ -68,7 +82,7 @@ class M2U(torch.nn.Module):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
 
-    def forward(self,x):
+    def forward(self, x):
         """
         Parameters
         ----------
@@ -80,13 +94,14 @@ class M2U(torch.nn.Module):
         -------
         tensor : :py:class:`torch.Tensor`
         """
-        decode4 = self.decode4(x[5],x[4])    # 96, 32
-        decode3 = self.decode3(decode4,x[3]) # 64, 24
-        decode2 = self.decode2(decode3,x[2]) # 44, 16
-        decode1 = self.decode1(decode2,x[1]) # 30, 3
+        decode4 = self.decode4(x[5], x[4])  # 96, 32
+        decode3 = self.decode3(decode4, x[3])  # 64, 24
+        decode2 = self.decode2(decode3, x[2])  # 44, 16
+        decode1 = self.decode1(decode2, x[1])  # 30, 3
 
         return decode1
 
+
 def build_m2unet():
     """
     Adds backbone and head together
@@ -95,9 +110,11 @@ def build_m2unet():
     -------
     module : :py:class:`torch.nn.Module`
     """
-    backbone = MobileNetV2(return_features = [1, 3, 6, 13], m2u=True)
+    backbone = MobileNetV2(return_features=[1, 3, 6, 13], m2u=True)
     m2u_head = M2U(in_channels_list=[16, 24, 32, 96])
 
-    model = torch.nn.Sequential(OrderedDict([("backbone", backbone), ("head", m2u_head)]))
+    model = torch.nn.Sequential(
+        OrderedDict([("backbone", backbone), ("head", m2u_head)])
+    )
     model.name = "M2UNet"
     return model
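
The decoder channel counts annotated above (96/32, 64/24, 44/16, 30/3) follow from DecoderBlock halving the concatenated channels; a quick arithmetic check in plain Python (not model code):

def decoder_out_channels(up_in_c, x_in_c):
    # DecoderBlock concatenates both inputs, then halves the channel count.
    return (up_in_c + x_in_c) // 2

d4 = decoder_out_channels(96, 32)   # 64
d3 = decoder_out_channels(d4, 24)   # 44
d2 = decoder_out_channels(d3, 16)   # 30
# LastDecoderBlock(33, ...): 30 decoder channels + the 3-channel input image.
print(d4, d3, d2, d2 + 3)           # 64 44 30 33
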
diff --git a/bob/ip/binseg/modeling/make_layers.py b/bob/ip/binseg/modeling/make_layers.py
index 88103048..23704eae 100644
--- a/bob/ip/binseg/modeling/make_layers.py
+++ b/bob/ip/binseg/modeling/make_layers.py
@@ -6,7 +6,10 @@ import torch.nn
 from torch.nn import Conv2d
 from torch.nn import ConvTranspose2d
 
-def conv_with_kaiming_uniform(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
+
+def conv_with_kaiming_uniform(
+    in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
+):
     conv = Conv2d(
         in_channels,
         out_channels,
@@ -14,16 +17,18 @@ def conv_with_kaiming_uniform(in_channels, out_channels, kernel_size, stride=1,
         stride=stride,
         padding=padding,
         dilation=dilation,
-        bias= True
-        )
-        # Caffe2 implementation uses XavierFill, which in fact
-        # corresponds to kaiming_uniform_ in PyTorch
+        bias=True,
+    )
+    # Caffe2 implementation uses XavierFill, which in fact
+    # corresponds to kaiming_uniform_ in PyTorch
     torch.nn.init.kaiming_uniform_(conv.weight, a=1)
     torch.nn.init.constant_(conv.bias, 0)
     return conv
 
 
-def convtrans_with_kaiming_uniform(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
+def convtrans_with_kaiming_uniform(
+    in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
+):
     conv = ConvTranspose2d(
         in_channels,
         out_channels,
@@ -31,10 +36,10 @@ def convtrans_with_kaiming_uniform(in_channels, out_channels, kernel_size, strid
         stride=stride,
         padding=padding,
         dilation=dilation,
-        bias= True
-        )
-        # Caffe2 implementation uses XavierFill, which in fact
-        # corresponds to kaiming_uniform_ in PyTorch
+        bias=True,
+    )
+    # Caffe2 implementation uses XavierFill, which in fact
+    # corresponds to kaiming_uniform_ in PyTorch
     torch.nn.init.kaiming_uniform_(conv.weight, a=1)
     torch.nn.init.constant_(conv.bias, 0)
     return conv
@@ -63,15 +68,24 @@ class UpsampleCropBlock(torch.nn.Module):
 
     """
 
-    def __init__(self, in_channels, out_channels, up_kernel_size, up_stride, up_padding, pixelshuffle=False):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        up_kernel_size,
+        up_stride,
+        up_padding,
+        pixelshuffle=False,
+    ):
         super().__init__()
         # NOTE: Kaiming init, replace with torch.nn.Conv2d and torch.nn.ConvTranspose2d to get original DRIU impl.
         self.conv = conv_with_kaiming_uniform(in_channels, out_channels, 3, 1, 1)
         if pixelshuffle:
-            self.upconv = PixelShuffle_ICNR( out_channels, out_channels, scale = up_stride)
+            self.upconv = PixelShuffle_ICNR(out_channels, out_channels, scale=up_stride)
         else:
-            self.upconv = convtrans_with_kaiming_uniform(out_channels, out_channels, up_kernel_size, up_stride, up_padding)
-
+            self.upconv = convtrans_with_kaiming_uniform(
+                out_channels, out_channels, up_kernel_size, up_stride, up_padding
+            )
 
     def forward(self, x, input_res):
         """Forward pass of UpsampleBlock.
@@ -98,39 +112,40 @@ class UpsampleCropBlock(torch.nn.Module):
         # height
         up_h = x.shape[2]
         h_crop = up_h - img_h
-        h_s = h_crop//2
+        h_s = h_crop // 2
         h_e = up_h - (h_crop - h_s)
         # width
         up_w = x.shape[3]
-        w_crop = up_w-img_w
-        w_s = w_crop//2
+        w_crop = up_w - img_w
+        w_s = w_crop // 2
         w_e = up_w - (w_crop - w_s)
         # perform crop
         # needs explicit ranges for onnx export
-        x = x[:,:,h_s:h_e,w_s:w_e] # crop to input size
+        x = x[:, :, h_s:h_e, w_s:w_e]  # crop to input size
 
         return x
 
 
-
 def ifnone(a, b):
     "``a`` if ``a`` is not None, otherwise ``b``."
     return b if a is None else a
 
+
 def icnr(x, scale=2, init=torch.nn.init.kaiming_normal_):
     """https://docs.fast.ai/layers.html#PixelShuffle_ICNR
 
     ICNR init of ``x``, with ``scale`` and ``init`` function.
     """
 
-    ni,nf,h,w = x.shape
-    ni2 = int(ni/(scale**2))
-    k = init(torch.zeros([ni2,nf,h,w])).transpose(0, 1)
+    ni, nf, h, w = x.shape
+    ni2 = int(ni / (scale ** 2))
+    k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
     k = k.contiguous().view(ni2, nf, -1)
-    k = k.repeat(1, 1, scale**2)
-    k = k.contiguous().view([nf,ni,h,w]).transpose(0, 1)
+    k = k.repeat(1, 1, scale ** 2)
+    k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
     x.data.copy_(k)
 
+
 class PixelShuffle_ICNR(torch.nn.Module):
     """https://docs.fast.ai/layers.html#PixelShuffle_ICNR
 
@@ -138,47 +153,52 @@ class PixelShuffle_ICNR(torch.nn.Module):
     ``torch.nn.PixelShuffle``, ``icnr`` init, and ``weight_norm``.
     """
 
-    def __init__(self, ni:int, nf:int=None, scale:int=2):
+    def __init__(self, ni: int, nf: int = None, scale: int = 2):
         super().__init__()
         nf = ifnone(nf, ni)
-        self.conv = conv_with_kaiming_uniform(ni, nf*(scale**2), 1)
+        self.conv = conv_with_kaiming_uniform(ni, nf * (scale ** 2), 1)
         icnr(self.conv.weight)
         self.shuf = torch.nn.PixelShuffle(scale)
         # Blurring over (h*w) kernel
         # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
         # - https://arxiv.org/abs/1806.02658
-        self.pad = torch.nn.ReplicationPad2d((1,0,1,0))
+        self.pad = torch.nn.ReplicationPad2d((1, 0, 1, 0))
         self.blur = torch.nn.AvgPool2d(2, stride=1)
         self.relu = torch.nn.ReLU(inplace=True)
 
-    def forward(self,x):
+    def forward(self, x):
         x = self.shuf(self.relu(self.conv(x)))
         x = self.blur(self.pad(x))
         return x
 
+
 class UnetBlock(torch.nn.Module):
     def __init__(self, up_in_c, x_in_c, pixel_shuffle=False, middle_block=False):
         super().__init__()
 
         # middle block for VGG based U-Net
         if middle_block:
-            up_out_c =  up_in_c
+            up_out_c = up_in_c
         else:
-            up_out_c =  up_in_c // 2
+            up_out_c = up_in_c // 2
         cat_channels = x_in_c + up_out_c
         inner_channels = cat_channels // 2
 
         if pixel_shuffle:
-            self.upsample = PixelShuffle_ICNR( up_in_c, up_out_c )
+            self.upsample = PixelShuffle_ICNR(up_in_c, up_out_c)
         else:
-            self.upsample = convtrans_with_kaiming_uniform( up_in_c, up_out_c, 2, 2)
-        self.convtrans1 = convtrans_with_kaiming_uniform( cat_channels, inner_channels, 3, 1, 1)
-        self.convtrans2 = convtrans_with_kaiming_uniform( inner_channels, inner_channels, 3, 1, 1)
+            self.upsample = convtrans_with_kaiming_uniform(up_in_c, up_out_c, 2, 2)
+        self.convtrans1 = convtrans_with_kaiming_uniform(
+            cat_channels, inner_channels, 3, 1, 1
+        )
+        self.convtrans2 = convtrans_with_kaiming_uniform(
+            inner_channels, inner_channels, 3, 1, 1
+        )
         self.relu = torch.nn.ReLU(inplace=True)
 
     def forward(self, up_in, x_in):
         up_out = self.upsample(up_in)
-        cat_x = torch.cat([up_out, x_in] , dim=1)
+        cat_x = torch.cat([up_out, x_in], dim=1)
         x = self.relu(self.convtrans1(cat_x))
         x = self.relu(self.convtrans2(x))
         return x
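
The centre-crop indices computed in UpsampleCropBlock.forward() can be verified with plain numbers (values chosen only for illustration):

up_h, img_h = 70, 64                 # upsampled height vs. requested output height
h_crop = up_h - img_h                # 6 rows to remove
h_s = h_crop // 2                    # 3 rows trimmed from the top
h_e = up_h - (h_crop - h_s)          # keep rows 3..66 (67 exclusive)
print(h_s, h_e, h_e - h_s)           # 3 67 64
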
diff --git a/bob/ip/binseg/modeling/resunet.py b/bob/ip/binseg/modeling/resunet.py
index 5256bd72..7b262c7a 100644
--- a/bob/ip/binseg/modeling/resunet.py
+++ b/bob/ip/binseg/modeling/resunet.py
@@ -4,11 +4,15 @@
 import torch.nn as nn
 import torch
 from collections import OrderedDict
-from bob.ip.binseg.modeling.make_layers  import conv_with_kaiming_uniform, convtrans_with_kaiming_uniform, PixelShuffle_ICNR, UnetBlock
+from bob.ip.binseg.modeling.make_layers import (
+    conv_with_kaiming_uniform,
+    convtrans_with_kaiming_uniform,
+    PixelShuffle_ICNR,
+    UnetBlock,
+)
 from bob.ip.binseg.modeling.backbones.resnet import resnet50
 
 
-
 class ResUNet(nn.Module):
     """
     UNet head module for ResNet backbones
@@ -18,12 +22,13 @@ class ResUNet(nn.Module):
     in_channels_list : list
                         number of channels for each feature map that is returned from backbone
     """
+
     def __init__(self, in_channels_list=None, pixel_shuffle=False):
         super(ResUNet, self).__init__()
         # number of channels
         c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list
         # number of channels for last upsampling operation
-        c_decode0 = (c_decode1 + c_decode2//2)//2
+        c_decode0 = (c_decode1 + c_decode2 // 2) // 2
 
         # build layers
         self.decode4 = UnetBlock(c_decode5, c_decode4, pixel_shuffle)
@@ -36,7 +41,7 @@ class ResUNet(nn.Module):
             self.decode0 = convtrans_with_kaiming_uniform(c_decode0, c_decode0, 2, 2)
         self.final = conv_with_kaiming_uniform(c_decode0, 1, 1)
 
-    def forward(self,x):
+    def forward(self, x):
         """
         Parameters
         ----------
@@ -54,6 +59,7 @@ class ResUNet(nn.Module):
         out = self.final(decode0)
         return out
 
+
 def build_res50unet():
     """
     Adds backbone and head together
@@ -62,8 +68,8 @@ def build_res50unet():
     -------
     model : :py:class:`torch.nn.Module`
     """
-    backbone = resnet50(pretrained=False, return_features = [2, 4, 5, 6, 7])
-    unet_head  = ResUNet([64, 256, 512, 1024, 2048],pixel_shuffle=False)
+    backbone = resnet50(pretrained=False, return_features=[2, 4, 5, 6, 7])
+    unet_head = ResUNet([64, 256, 512, 1024, 2048], pixel_shuffle=False)
     model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", unet_head)]))
     model.name = "ResUNet"
     return model
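
The c_decode0 formula above, evaluated for the channel list that build_res50unet() passes in (arithmetic check only):

c_decode1, c_decode2 = 64, 256                 # first two entries of [64, 256, 512, 1024, 2048]
c_decode0 = (c_decode1 + c_decode2 // 2) // 2  # (64 + 128) // 2
print(c_decode0)                               # 96
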
diff --git a/bob/ip/binseg/modeling/unet.py b/bob/ip/binseg/modeling/unet.py
index b89ec5f1..7602d63a 100644
--- a/bob/ip/binseg/modeling/unet.py
+++ b/bob/ip/binseg/modeling/unet.py
@@ -4,11 +4,15 @@
 import torch.nn as nn
 import torch
 from collections import OrderedDict
-from bob.ip.binseg.modeling.make_layers  import conv_with_kaiming_uniform, convtrans_with_kaiming_uniform, PixelShuffle_ICNR, UnetBlock
+from bob.ip.binseg.modeling.make_layers import (
+    conv_with_kaiming_uniform,
+    convtrans_with_kaiming_uniform,
+    PixelShuffle_ICNR,
+    UnetBlock,
+)
 from bob.ip.binseg.modeling.backbones.vgg import vgg16
 
 
-
 class UNet(nn.Module):
     """
     UNet head module
@@ -18,6 +22,7 @@ class UNet(nn.Module):
     in_channels_list : list
                         number of channels for each feature map that is returned from backbone
     """
+
     def __init__(self, in_channels_list=None, pixel_shuffle=False):
         super(UNet, self).__init__()
         # number of channels
@@ -30,7 +35,7 @@ class UNet(nn.Module):
         self.decode1 = UnetBlock(c_decode2, c_decode1, pixel_shuffle)
         self.final = conv_with_kaiming_uniform(c_decode1, 1, 1)
 
-    def forward(self,x):
+    def forward(self, x):
         """
         Parameters
         ----------
@@ -47,6 +52,7 @@ class UNet(nn.Module):
         out = self.final(decode1)
         return out
 
+
 def build_unet():
     """
     Adds backbone and head together
@@ -56,7 +62,7 @@ def build_unet():
     module : :py:class:`torch.nn.Module`
     """
 
-    backbone = vgg16(pretrained=False, return_features = [3, 8, 14, 22, 29])
+    backbone = vgg16(pretrained=False, return_features=[3, 8, 14, 22, 29])
     unet_head = UNet([64, 128, 256, 512, 512], pixel_shuffle=False)
 
     model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", unet_head)]))
diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py
index 944a9995..9fc05e9a 100644
--- a/bob/ip/binseg/script/binseg.py
+++ b/bob/ip/binseg/script/binseg.py
@@ -18,8 +18,12 @@ import logging
 import torch
 
 import bob.extension
-from bob.extension.scripts.click_helper import (verbosity_option,
-    ConfigCommand, ResourceOption, AliasedGroup)
+from bob.extension.scripts.click_helper import (
+    verbosity_option,
+    ConfigCommand,
+    ResourceOption,
+    AliasedGroup,
+)
 
 from bob.ip.binseg.utils.checkpointer import DetectronCheckpointer
 from torch.utils.data import DataLoader
@@ -29,7 +33,7 @@ from bob.ip.binseg.engine.inferencer import do_inference
 from bob.ip.binseg.utils.plot import plot_overview
 from bob.ip.binseg.utils.click import OptionEatAll
 from bob.ip.binseg.utils.rsttable import create_overview_grid
-from bob.ip.binseg.utils.plot import metricsviz, overlay,savetransformedtest
+from bob.ip.binseg.utils.plot import metricsviz, overlay, savetransformedtest
 from bob.ip.binseg.utils.transformfolder import transformfolder as transfld
 from bob.ip.binseg.utils.evaluate import do_eval
 from bob.ip.binseg.engine.predicter import do_predict
@@ -37,121 +41,94 @@ from bob.ip.binseg.engine.predicter import do_predict
 logger = logging.getLogger(__name__)
 
 
-@with_plugins(pkg_resources.iter_entry_points('bob.ip.binseg.cli'))
+@with_plugins(pkg_resources.iter_entry_points("bob.ip.binseg.cli"))
 @click.group(cls=AliasedGroup)
 def binseg():
     """Binary 2D Fundus Image Segmentation Benchmark commands."""
     pass
 
+
 # Train
-@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
-@click.option(
-    '--output-path',
-    '-o',
-    required=True,
-    default="output",
-    cls=ResourceOption
-    )
-@click.option(
-    '--model',
-    '-m',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--dataset',
-    '-d',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--optimizer',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--criterion',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--scheduler',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--pretrained-backbone',
-    '-t',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--batch-size',
-    '-b',
-    required=True,
-    default=2,
-    cls=ResourceOption)
-@click.option(
-    '--epochs',
-    '-e',
-    help='Number of epochs used for training',
+@binseg.command(entry_point_group="bob.ip.binseg.config", cls=ConfigCommand)
+@click.option(
+    "--output-path", "-o", required=True, default="output", cls=ResourceOption
+)
+@click.option("--model", "-m", required=True, cls=ResourceOption)
+@click.option("--dataset", "-d", required=True, cls=ResourceOption)
+@click.option("--optimizer", required=True, cls=ResourceOption)
+@click.option("--criterion", required=True, cls=ResourceOption)
+@click.option("--scheduler", required=True, cls=ResourceOption)
+@click.option("--pretrained-backbone", "-t", required=True, cls=ResourceOption)
+@click.option("--batch-size", "-b", required=True, default=2, cls=ResourceOption)
+@click.option(
+    "--epochs",
+    "-e",
+    help="Number of epochs used for training",
     show_default=True,
     required=True,
     default=1000,
-    cls=ResourceOption)
+    cls=ResourceOption,
+)
 @click.option(
-    '--checkpoint-period',
-    '-p',
-    help='Number of epochs after which a checkpoint is saved',
+    "--checkpoint-period",
+    "-p",
+    help="Number of epochs after which a checkpoint is saved",
     show_default=True,
     required=True,
     default=100,
-    cls=ResourceOption)
+    cls=ResourceOption,
+)
 @click.option(
-    '--device',
-    '-d',
+    "--device",
+    "-d",
     help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
     show_default=True,
     required=True,
-    default='cpu',
-    cls=ResourceOption)
+    default="cpu",
+    cls=ResourceOption,
+)
 @click.option(
-    '--seed',
-    '-s',
-    help='torch random seed',
+    "--seed",
+    "-s",
+    help="torch random seed",
     show_default=True,
     required=False,
     default=42,
-    cls=ResourceOption)
-
+    cls=ResourceOption,
+)
 @verbosity_option(cls=ResourceOption)
-def train(model
-        ,optimizer
-        ,scheduler
-        ,output_path
-        ,epochs
-        ,pretrained_backbone
-        ,batch_size
-        ,criterion
-        ,dataset
-        ,checkpoint_period
-        ,device
-        ,seed
-        ,**kwargs):
+def train(
+    model,
+    optimizer,
+    scheduler,
+    output_path,
+    epochs,
+    pretrained_backbone,
+    batch_size,
+    criterion,
+    dataset,
+    checkpoint_period,
+    device,
+    seed,
+    **kwargs
+):
     """ Train a model """
 
-    if not os.path.exists(output_path): os.makedirs(output_path)
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
     torch.manual_seed(seed)
     # PyTorch dataloader
     data_loader = DataLoader(
-        dataset = dataset
-        ,batch_size = batch_size
-        ,shuffle= True
-        ,pin_memory = torch.cuda.is_available()
-        )
+        dataset=dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        pin_memory=torch.cuda.is_available(),
+    )
 
     # Checkpointer
-    checkpointer = DetectronCheckpointer(model, optimizer, scheduler,save_dir = output_path, save_to_disk=True)
+    checkpointer = DetectronCheckpointer(
+        model, optimizer, scheduler, save_dir=output_path, save_to_disk=True
+    )
     arguments = {}
     arguments["epoch"] = 0
     extra_checkpoint_data = checkpointer.load(pretrained_backbone)
@@ -161,125 +138,98 @@ def train(model
     # Train
     logger.info("Training for {} epochs".format(arguments["max_epoch"]))
     logger.info("Continuing from epoch {}".format(arguments["epoch"]))
-    do_train(model
-            , data_loader
-            , optimizer
-            , criterion
-            , scheduler
-            , checkpointer
-            , checkpoint_period
-            , device
-            , arguments
-            , output_path
-            )
+    do_train(
+        model,
+        data_loader,
+        optimizer,
+        criterion,
+        scheduler,
+        checkpointer,
+        checkpoint_period,
+        device,
+        arguments,
+        output_path,
+    )
 
 
 # Inference
-@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
+@binseg.command(entry_point_group="bob.ip.binseg.config", cls=ConfigCommand)
 @click.option(
-    '--output-path',
-    '-o',
-    required=True,
-    default="output",
-    cls=ResourceOption
-    )
+    "--output-path", "-o", required=True, default="output", cls=ResourceOption
+)
+@click.option("--model", "-m", required=True, cls=ResourceOption)
+@click.option("--dataset", "-d", required=True, cls=ResourceOption)
+@click.option("--batch-size", "-b", required=True, default=2, cls=ResourceOption)
 @click.option(
-    '--model',
-    '-m',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--dataset',
-    '-d',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--batch-size',
-    '-b',
-    required=True,
-    default=2,
-    cls=ResourceOption)
-@click.option(
-    '--device',
-    '-d',
+    "--device",
+    "-d",
     help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
     show_default=True,
     required=True,
-    default='cpu',
-    cls=ResourceOption)
+    default="cpu",
+    cls=ResourceOption,
+)
 @click.option(
-    '--weight',
-    '-w',
-    help='Path or URL to pretrained model',
+    "--weight",
+    "-w",
+    help="Path or URL to pretrained model",
     required=False,
     default=None,
-    cls=ResourceOption
-    )
+    cls=ResourceOption,
+)
 @verbosity_option(cls=ResourceOption)
-def test(model
-        ,output_path
-        ,device
-        ,batch_size
-        ,dataset
-        ,weight
-        , **kwargs):
+def test(model, output_path, device, batch_size, dataset, weight, **kwargs):
     """ Run inference and evalaute the model performance """
 
     # PyTorch dataloader
     data_loader = DataLoader(
-        dataset = dataset
-        ,batch_size = batch_size
-        ,shuffle= False
-        ,pin_memory = torch.cuda.is_available()
-        )
+        dataset=dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        pin_memory=torch.cuda.is_available(),
+    )
 
     # checkpointer, load last model in dir
-    checkpointer = DetectronCheckpointer(model, save_dir = output_path, save_to_disk=False)
+    checkpointer = DetectronCheckpointer(
+        model, save_dir=output_path, save_to_disk=False
+    )
     checkpointer.load(weight)
     do_inference(model, data_loader, device, output_path)
 
 
-
 # Plot comparison
-@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
+@binseg.command(entry_point_group="bob.ip.binseg.config", cls=ConfigCommand)
 @click.option(
-    '--output-path-list',
-    '-l',
+    "--output-path-list",
+    "-l",
     required=True,
-    help='Pass all output paths as arguments',
+    help="Pass all output paths as arguments",
     cls=OptionEatAll,
-    )
+)
 @click.option(
-    '--output-path',
-    '-o',
-    required=True,
-    )
+    "--output-path", "-o", required=True,
+)
 @click.option(
-    '--title',
-    '-t',
-    required=False,
-    )
+    "--title", "-t", required=False,
+)
 @verbosity_option(cls=ResourceOption)
 def compare(output_path_list, output_path, title, **kwargs):
     """ Compares multiple metrics files that are stored in the format mymodel/results/Metrics.csv """
     logger.debug("Output paths: {}".format(output_path_list))
-    logger.info('Plotting precision vs recall curves for {}'.format(output_path_list))
-    fig = plot_overview(output_path_list,title)
-    if not os.path.exists(output_path): os.makedirs(output_path)
-    fig_filename = os.path.join(output_path, 'precision_recall_comparison.pdf')
-    logger.info('saving {}'.format(fig_filename))
+    logger.info("Plotting precision vs recall curves for {}".format(output_path_list))
+    fig = plot_overview(output_path_list, title)
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+    fig_filename = os.path.join(output_path, "precision_recall_comparison.pdf")
+    logger.info("saving {}".format(fig_filename))
     fig.savefig(fig_filename)
 
 
 # Create grid table with results
-@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
+@binseg.command(entry_point_group="bob.ip.binseg.config", cls=ConfigCommand)
 @click.option(
-    '--output-path',
-    '-o',
-    required=True,
-    )
+    "--output-path", "-o", required=True,
+)
 @verbosity_option(cls=ResourceOption)
 def gridtable(output_path, **kwargs):
     """ Creates an overview table in grid rst format for all Metrics.csv in the output_path
@@ -289,23 +239,16 @@ def gridtable(output_path, **kwargs):
             ├── images
             └── results
     """
-    logger.info('Creating grid for all results in {}'.format(output_path))
+    logger.info("Creating grid for all results in {}".format(output_path))
     create_overview_grid(output_path)
 
 
 # Create metrics viz
-@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
-@click.option(
-    '--dataset',
-    '-d',
-    required=True,
-    cls=ResourceOption
-    )
+@binseg.command(entry_point_group="bob.ip.binseg.config", cls=ConfigCommand)
+@click.option("--dataset", "-d", required=True, cls=ResourceOption)
 @click.option(
-    '--output-path',
-    '-o',
-    required=True,
-    )
+    "--output-path", "-o", required=True,
+)
 @verbosity_option(cls=ResourceOption)
 def visualize(dataset, output_path, **kwargs):
     """ Creates the following visualizations of the probabilties output maps:
@@ -318,132 +261,105 @@ def visualize(dataset, output_path, **kwargs):
             ├── images
             └── results
     """
-    logger.info('Creating TP, FP, FN visualizations for {}'.format(output_path))
+    logger.info("Creating TP, FP, FN visualizations for {}".format(output_path))
     metricsviz(dataset=dataset, output_path=output_path)
-    logger.info('Creating overlay visualizations for {}'.format(output_path))
+    logger.info("Creating overlay visualizations for {}".format(output_path))
     overlay(dataset=dataset, output_path=output_path)
-    logger.info('Saving transformed test images {}'.format(output_path))
+    logger.info("Saving transformed test images {}".format(output_path))
     savetransformedtest(dataset=dataset, output_path=output_path)
 
 
 # SSLTrain
-@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
-@click.option(
-    '--output-path',
-    '-o',
-    required=True,
-    default="output",
-    cls=ResourceOption
-    )
-@click.option(
-    '--model',
-    '-m',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--dataset',
-    '-d',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--optimizer',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--criterion',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--scheduler',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--pretrained-backbone',
-    '-t',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--batch-size',
-    '-b',
-    required=True,
-    default=2,
-    cls=ResourceOption)
-@click.option(
-    '--epochs',
-    '-e',
-    help='Number of epochs used for training',
+@binseg.command(entry_point_group="bob.ip.binseg.config", cls=ConfigCommand)
+@click.option(
+    "--output-path", "-o", required=True, default="output", cls=ResourceOption
+)
+@click.option("--model", "-m", required=True, cls=ResourceOption)
+@click.option("--dataset", "-d", required=True, cls=ResourceOption)
+@click.option("--optimizer", required=True, cls=ResourceOption)
+@click.option("--criterion", required=True, cls=ResourceOption)
+@click.option("--scheduler", required=True, cls=ResourceOption)
+@click.option("--pretrained-backbone", "-t", required=True, cls=ResourceOption)
+@click.option("--batch-size", "-b", required=True, default=2, cls=ResourceOption)
+@click.option(
+    "--epochs",
+    "-e",
+    help="Number of epochs used for training",
     show_default=True,
     required=True,
     default=1000,
-    cls=ResourceOption)
+    cls=ResourceOption,
+)
 @click.option(
-    '--checkpoint-period',
-    '-p',
-    help='Number of epochs after which a checkpoint is saved',
+    "--checkpoint-period",
+    "-p",
+    help="Number of epochs after which a checkpoint is saved",
     show_default=True,
     required=True,
     default=100,
-    cls=ResourceOption)
+    cls=ResourceOption,
+)
 @click.option(
-    '--device',
-    '-d',
+    "--device",
+    "-d",
     help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
     show_default=True,
     required=True,
-    default='cpu',
-    cls=ResourceOption)
+    default="cpu",
+    cls=ResourceOption,
+)
 @click.option(
-    '--rampup',
-    '-r',
-    help='Ramp-up length in epochs',
+    "--rampup",
+    "-r",
+    help="Ramp-up length in epochs",
     show_default=True,
     required=True,
-    default='900',
-    cls=ResourceOption)
+    default="900",
+    cls=ResourceOption,
+)
 @click.option(
-    '--seed',
-    '-s',
-    help='torch random seed',
+    "--seed",
+    "-s",
+    help="torch random seed",
     show_default=True,
     required=False,
     default=42,
-    cls=ResourceOption)
-
+    cls=ResourceOption,
+)
 @verbosity_option(cls=ResourceOption)
-def ssltrain(model
-        ,optimizer
-        ,scheduler
-        ,output_path
-        ,epochs
-        ,pretrained_backbone
-        ,batch_size
-        ,criterion
-        ,dataset
-        ,checkpoint_period
-        ,device
-        ,rampup
-        ,seed
-        ,**kwargs):
+def ssltrain(
+    model,
+    optimizer,
+    scheduler,
+    output_path,
+    epochs,
+    pretrained_backbone,
+    batch_size,
+    criterion,
+    dataset,
+    checkpoint_period,
+    device,
+    rampup,
+    seed,
+    **kwargs
+):
     """ Train a model """
 
-    if not os.path.exists(output_path): os.makedirs(output_path)
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
     torch.manual_seed(seed)
     # PyTorch dataloader
     data_loader = DataLoader(
-        dataset = dataset
-        ,batch_size = batch_size
-        ,shuffle= True
-        ,pin_memory = torch.cuda.is_available()
-        )
+        dataset=dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        pin_memory=torch.cuda.is_available(),
+    )
 
     # Checkpointer
-    checkpointer = DetectronCheckpointer(model, optimizer, scheduler,save_dir = output_path, save_to_disk=True)
+    checkpointer = DetectronCheckpointer(
+        model, optimizer, scheduler, save_dir=output_path, save_to_disk=True
+    )
     arguments = {}
     arguments["epoch"] = 0
     extra_checkpoint_data = checkpointer.load(pretrained_backbone)
@@ -453,109 +369,77 @@ def ssltrain(model
     # Train
     logger.info("Training for {} epochs".format(arguments["max_epoch"]))
     logger.info("Continuing from epoch {}".format(arguments["epoch"]))
-    do_ssltrain(model
-            , data_loader
-            , optimizer
-            , criterion
-            , scheduler
-            , checkpointer
-            , checkpoint_period
-            , device
-            , arguments
-            , output_path
-            , rampup
-            )
-
-# Apply image transforms to a folder containing images
-@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
-@click.option(
-    '--source-path',
-    '-s',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--target-path',
-    '-t',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--transforms',
-    '-a',
-    required=True,
-    cls=ResourceOption
+    do_ssltrain(
+        model,
+        data_loader,
+        optimizer,
+        criterion,
+        scheduler,
+        checkpointer,
+        checkpoint_period,
+        device,
+        arguments,
+        output_path,
+        rampup,
     )
 
+
+# Apply image transforms to a folder containing images
+@binseg.command(entry_point_group="bob.ip.binseg.config", cls=ConfigCommand)
+@click.option("--source-path", "-s", required=True, cls=ResourceOption)
+@click.option("--target-path", "-t", required=True, cls=ResourceOption)
+@click.option("--transforms", "-a", required=True, cls=ResourceOption)
 @verbosity_option(cls=ResourceOption)
-def transformfolder(source_path ,target_path,transforms,**kwargs):
-    logger.info('Applying transforms to images in {} and saving them to {}'.format(source_path, target_path))
-    transfld(source_path,target_path,transforms)
+def transformfolder(source_path, target_path, transforms, **kwargs):
+    logger.info(
+        "Applying transforms to images in {} and saving them to {}".format(
+            source_path, target_path
+        )
+    )
+    transfld(source_path, target_path, transforms)
 
 
 # Run inference and create predictions only (no ground truth available)
-@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
-@click.option(
-    '--output-path',
-    '-o',
-    required=True,
-    default="output",
-    cls=ResourceOption
-    )
-@click.option(
-    '--model',
-    '-m',
-    required=True,
-    cls=ResourceOption
-    )
+@binseg.command(entry_point_group="bob.ip.binseg.config", cls=ConfigCommand)
 @click.option(
-    '--dataset',
-    '-d',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--batch-size',
-    '-b',
-    required=True,
-    default=2,
-    cls=ResourceOption)
+    "--output-path", "-o", required=True, default="output", cls=ResourceOption
+)
+@click.option("--model", "-m", required=True, cls=ResourceOption)
+@click.option("--dataset", "-d", required=True, cls=ResourceOption)
+@click.option("--batch-size", "-b", required=True, default=2, cls=ResourceOption)
 @click.option(
-    '--device',
-    '-d',
+    "--device",
+    "-d",
     help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
     show_default=True,
     required=True,
-    default='cpu',
-    cls=ResourceOption)
+    default="cpu",
+    cls=ResourceOption,
+)
 @click.option(
-    '--weight',
-    '-w',
-    help='Path or URL to pretrained model',
+    "--weight",
+    "-w",
+    help="Path or URL to pretrained model",
     required=False,
     default=None,
-    cls=ResourceOption
-    )
+    cls=ResourceOption,
+)
 @verbosity_option(cls=ResourceOption)
-def predict(model
-        ,output_path
-        ,device
-        ,batch_size
-        ,dataset
-        ,weight
-        , **kwargs):
+def predict(model, output_path, device, batch_size, dataset, weight, **kwargs):
     """ Run inference and evalaute the model performance """
 
     # PyTorch dataloader
     data_loader = DataLoader(
-        dataset = dataset
-        ,batch_size = batch_size
-        ,shuffle= False
-        ,pin_memory = torch.cuda.is_available()
-        )
+        dataset=dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        pin_memory=torch.cuda.is_available(),
+    )
 
     # checkpointer, load last model in dir
-    checkpointer = DetectronCheckpointer(model, save_dir = output_path, save_to_disk=False)
+    checkpointer = DetectronCheckpointer(
+        model, save_dir=output_path, save_to_disk=False
+    )
     checkpointer.load(weight)
     do_predict(model, data_loader, device, output_path)
 
@@ -563,68 +447,55 @@ def predict(model
     overlay(dataset=dataset, output_path=output_path)
 
 
-
 # Evaluate only. Runs evaluation on predicted probability maps (--prediction-folder)
-@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
+@binseg.command(entry_point_group="bob.ip.binseg.config", cls=ConfigCommand)
 @click.option(
-    '--output-path',
-    '-o',
-    required=True,
-    default="output",
-    cls=ResourceOption
-    )
+    "--output-path", "-o", required=True, default="output", cls=ResourceOption
+)
 @click.option(
-    '--prediction-folder',
-    '-p',
-    help = 'Path containing output probability maps',
+    "--prediction-folder",
+    "-p",
+    help="Path containing output probability maps",
     required=True,
-    cls=ResourceOption
-    )
+    cls=ResourceOption,
+)
 @click.option(
-    '--prediction-extension',
-    '-x',
-    help = 'Extension (e.g. ".png") for the prediction files',
+    "--prediction-extension",
+    "-x",
+    help='Extension (e.g. ".png") for the prediction files',
     default=".png",
     required=False,
-    cls=ResourceOption
-    )
-@click.option(
-    '--dataset',
-    '-d',
-    required=True,
-    cls=ResourceOption
-    )
-@click.option(
-    '--title',
-    required=False,
-    cls=ResourceOption
-    )
-@click.option(
-    '--legend',
-    cls=ResourceOption
-    )
-
+    cls=ResourceOption,
+)
+@click.option("--dataset", "-d", required=True, cls=ResourceOption)
+@click.option("--title", required=False, cls=ResourceOption)
+@click.option("--legend", cls=ResourceOption)
 @verbosity_option(cls=ResourceOption)
 def evalpred(
-        output_path
-        ,prediction_folder
-        ,prediction_extension
-        ,dataset
-        ,title
-        ,legend
-        , **kwargs):
+    output_path,
+    prediction_folder,
+    prediction_extension,
+    dataset,
+    title,
+    legend,
+    **kwargs
+):
     """ Run inference and evalaute the model performance """
 
     # PyTorch dataloader
     data_loader = DataLoader(
-        dataset = dataset
-        ,batch_size = 1
-        ,shuffle= False
-        ,pin_memory = torch.cuda.is_available()
-        )
+        dataset=dataset,
+        batch_size=1,
+        shuffle=False,
+        pin_memory=torch.cuda.is_available(),
+    )
 
     # Run eval
-    do_eval(prediction_folder, data_loader, output_folder = output_path, title=title, legend=legend, prediction_extension=prediction_extension)
-
-
-
+    do_eval(
+        prediction_folder,
+        data_loader,
+        output_folder=output_path,
+        title=title,
+        legend=legend,
+        prediction_extension=prediction_extension,
+    )
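
# A minimal, runnable sketch of the DataLoader configuration used by predict/evalpred above.
# The TensorDataset is a hypothetical stand-in for the configured bob.ip.binseg dataset; only
# the batch_size/shuffle/pin_memory choices mirror the options in the commands themselves.
import torch
from torch.utils.data import DataLoader, TensorDataset

toy_dataset = TensorDataset(
    torch.rand(4, 3, 64, 64),                  # four fake 64x64 RGB images
    torch.randint(0, 2, (4, 1, 64, 64)),       # four fake binary ground-truth maps
)
data_loader = DataLoader(
    dataset=toy_dataset,
    batch_size=1,                              # evalpred walks the dataset one image at a time
    shuffle=False,                             # keep the order so names line up with predictions
    pin_memory=torch.cuda.is_available(),      # faster host-to-GPU copies when CUDA is present
)
for image, target in data_loader:
    pass                                       # image: (1, 3, 64, 64), target: (1, 1, 64, 64)
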
diff --git a/bob/ip/binseg/test/test_basemetrics.py b/bob/ip/binseg/test/test_basemetrics.py
index bf478ac7..6cd71614 100644
--- a/bob/ip/binseg/test/test_basemetrics.py
+++ b/bob/ip/binseg/test/test_basemetrics.py
@@ -6,39 +6,44 @@ import numpy as np
 from bob.ip.binseg.utils.metric import base_metrics
 import random
 
+
 class Tester(unittest.TestCase):
     """
     Unit test for base metrics
     """
+
     def setUp(self):
         self.tp = random.randint(1, 100)
         self.fp = random.randint(1, 100)
         self.tn = random.randint(1, 100)
         self.fn = random.randint(1, 100)
-    
+
     def test_precision(self):
         precision = base_metrics(self.tp, self.fp, self.tn, self.fn)[0]
-        self.assertEqual((self.tp)/(self.tp + self.fp),precision)
+        self.assertEqual((self.tp) / (self.tp + self.fp), precision)
 
     def test_recall(self):
         recall = base_metrics(self.tp, self.fp, self.tn, self.fn)[1]
-        self.assertEqual((self.tp)/(self.tp + self.fn),recall)
+        self.assertEqual((self.tp) / (self.tp + self.fn), recall)
 
     def test_specificity(self):
         specificity = base_metrics(self.tp, self.fp, self.tn, self.fn)[2]
-        self.assertEqual((self.tn)/(self.tn + self.fp),specificity)
-    
+        self.assertEqual((self.tn) / (self.tn + self.fp), specificity)
+
     def test_accuracy(self):
         accuracy = base_metrics(self.tp, self.fp, self.tn, self.fn)[3]
-        self.assertEqual((self.tp + self.tn)/(self.tp + self.tn + self.fp + self.fn), accuracy)
+        self.assertEqual(
+            (self.tp + self.tn) / (self.tp + self.tn + self.fp + self.fn), accuracy
+        )
 
     def test_jaccard(self):
         jaccard = base_metrics(self.tp, self.fp, self.tn, self.fn)[4]
-        self.assertEqual(self.tp / (self.tp+self.fp+self.fn), jaccard)
+        self.assertEqual(self.tp / (self.tp + self.fp + self.fn), jaccard)
 
     def test_f1(self):
         f1 = base_metrics(self.tp, self.fp, self.tn, self.fn)[5]
-        self.assertEqual((2.0 * self.tp ) / (2.0 * self.tp + self.fp + self.fn ),f1)
-        
-if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+        self.assertEqual((2.0 * self.tp) / (2.0 * self.tp + self.fp + self.fn), f1)
+
+
+if __name__ == "__main__":
+    unittest.main()
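
# A worked numeric check of the same identities the test above asserts, assuming the documented
# base_metrics ordering (precision, recall, specificity, accuracy, jaccard, f1); the counts are
# made up:
tp, fp, tn, fn = 8, 2, 5, 5
precision = tp / (tp + fp)                     # 8 / 10  = 0.800
recall = tp / (tp + fn)                        # 8 / 13  ~ 0.615
specificity = tn / (tn + fp)                   # 5 / 7   ~ 0.714
accuracy = (tp + tn) / (tp + tn + fp + fn)     # 13 / 20 = 0.650
jaccard = tp / (tp + fp + fn)                  # 8 / 15  ~ 0.533
f1 = 2.0 * tp / (2.0 * tp + fp + fn)           # 16 / 23 ~ 0.696
# the F1-score is consistent with the harmonic mean of precision and recall
assert abs(f1 - 2 * precision * recall / (precision + recall)) < 1e-9
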
diff --git a/bob/ip/binseg/test/test_batchmetrics.py b/bob/ip/binseg/test/test_batchmetrics.py
index 4988cab6..25868691 100644
--- a/bob/ip/binseg/test/test_batchmetrics.py
+++ b/bob/ip/binseg/test/test_batchmetrics.py
@@ -9,31 +9,42 @@ import shutil, tempfile
 import logging
 import torch
 
+
 class Tester(unittest.TestCase):
     """
     Unit test for batch metrics
     """
+
     def setUp(self):
         self.tp = random.randint(1, 100)
         self.fp = random.randint(1, 100)
         self.tn = random.randint(1, 100)
         self.fn = random.randint(1, 100)
-        self.predictions = torch.rand(size=(2,1,420,420))
-        self.ground_truths = torch.randint(low=0, high=2, size=(2,1,420,420))
-        self.names = ['Bob','Tim'] 
+        self.predictions = torch.rand(size=(2, 1, 420, 420))
+        self.ground_truths = torch.randint(low=0, high=2, size=(2, 1, 420, 420))
+        self.names = ["Bob", "Tim"]
         self.output_folder = tempfile.mkdtemp()
         self.logger = logging.getLogger(__name__)
 
     def tearDown(self):
         # Remove the temporary folder after the test
         shutil.rmtree(self.output_folder)
-    
+
     def test_batch_metrics(self):
-        bm = batch_metrics(self.predictions, self.ground_truths, self.names, self.output_folder, self.logger)
-        self.assertEqual(len(bm),2*100)
+        bm = batch_metrics(
+            self.predictions,
+            self.ground_truths,
+            self.names,
+            self.output_folder,
+            self.logger,
+        )
+        self.assertEqual(len(bm), 2 * 100)
         for metric in bm:
             # check whether the f1 score agrees with precision and recall
-            self.assertAlmostEqual(metric[-1],2*(metric[-6]*metric[-5])/(metric[-6]+metric[-5]))
+            self.assertAlmostEqual(
+                metric[-1], 2 * (metric[-6] * metric[-5]) / (metric[-6] + metric[-5])
+            )
+
 
-if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+if __name__ == "__main__":
+    unittest.main()
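
# Each row yielded by batch_metrics has the layout later used for the DataFrame in
# utils/evaluate.py: [name, threshold, precision, recall, specificity, accuracy, jaccard,
# f1_score]. That is why the assertion above reads precision and recall from positions -6
# and -5 and the F1-score from position -1. A sketch with made-up numbers:
row = ["img_01.png", 0.50, 0.80, 0.70, 0.90, 0.85, 0.60, 0.7466666666666667]
precision, recall, f1 = row[-6], row[-5], row[-1]
assert abs(f1 - 2 * precision * recall / (precision + recall)) < 1e-9
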
diff --git a/bob/ip/binseg/test/test_checkpointer.py b/bob/ip/binseg/test/test_checkpointer.py
index 976e9d94..b8fb6159 100644
--- a/bob/ip/binseg/test/test_checkpointer.py
+++ b/bob/ip/binseg/test/test_checkpointer.py
@@ -39,9 +39,7 @@ class TestCheckpointer(unittest.TestCase):
         trained_model = self.create_model()
         fresh_model = self.create_model()
         with TemporaryDirectory() as f:
-            checkpointer = Checkpointer(
-                trained_model, save_dir=f, save_to_disk=True
-            )
+            checkpointer = Checkpointer(trained_model, save_dir=f, save_to_disk=True)
             checkpointer.save("checkpoint_file")
 
             # in the same folder
@@ -66,9 +64,7 @@ class TestCheckpointer(unittest.TestCase):
         trained_model = self.create_model()
         fresh_model = self.create_model()
         with TemporaryDirectory() as f:
-            checkpointer = Checkpointer(
-                trained_model, save_dir=f, save_to_disk=True
-            )
+            checkpointer = Checkpointer(trained_model, save_dir=f, save_to_disk=True)
             checkpointer.save("checkpoint_file")
 
             # on different folders
diff --git a/bob/ip/binseg/test/test_summary.py b/bob/ip/binseg/test/test_summary.py
index 7faabf79..aebcaace 100644
--- a/bob/ip/binseg/test/test_summary.py
+++ b/bob/ip/binseg/test/test_summary.py
@@ -11,36 +11,37 @@ from bob.ip.binseg.modeling.unet import build_unet
 from bob.ip.binseg.modeling.resunet import build_res50unet
 from bob.ip.binseg.utils.summary import summary
 
+
 class Tester(unittest.TestCase):
     """
     Unit test for model architectures
-    """    
+    """
+
     def test_summary_driu(self):
         model = build_driu()
         param = summary(model)
-        self.assertIsInstance(param,int)
-
+        self.assertIsInstance(param, int)
 
     def test__summary_driuod(self):
         model = build_driuod()
         param = summary(model)
-        self.assertIsInstance(param,int)
-
+        self.assertIsInstance(param, int)
 
     def test_summary_hed(self):
         model = build_hed()
         param = summary(model)
-        self.assertIsInstance(param,int)
+        self.assertIsInstance(param, int)
 
     def test_summary_unet(self):
         model = build_unet()
         param = summary(model)
-        self.assertIsInstance(param,int)
+        self.assertIsInstance(param, int)
 
     def test_summary_resunet(self):
         model = build_res50unet()
         param = summary(model)
-        self.assertIsInstance(param,int)
+        self.assertIsInstance(param, int)
+
 
-if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+if __name__ == "__main__":
+    unittest.main()
diff --git a/bob/ip/binseg/test/test_transforms.py b/bob/ip/binseg/test/test_transforms.py
index 479cd79c..e71716a0 100644
--- a/bob/ip/binseg/test/test_transforms.py
+++ b/bob/ip/binseg/test/test_transforms.py
@@ -6,15 +6,13 @@ import unittest
 import numpy as np
 from bob.ip.binseg.data.transforms import *
 
-transforms = Compose([
-                        RandomHFlip(prob=1)
-                        ,RandomHFlip(prob=1)
-                        ,RandomVFlip(prob=1)
-                        ,RandomVFlip(prob=1)
-                    ])
+transforms = Compose(
+    [RandomHFlip(prob=1), RandomHFlip(prob=1), RandomVFlip(prob=1), RandomVFlip(prob=1)]
+)
+
 
 def create_img():
-    t = torch.randn((3,42,24))
+    t = torch.randn((3, 42, 24))
     pil = VF.to_pil_image(t)
     return pil
 
@@ -23,14 +21,16 @@ class Tester(unittest.TestCase):
     """
     Unit test for random flips
     """
-    
+
     def test_flips(self):
-        transforms = Compose([
-                        RandomHFlip(prob=1)
-                        ,RandomHFlip(prob=1)
-                        ,RandomVFlip(prob=1)
-                        ,RandomVFlip(prob=1)
-                    ])
+        transforms = Compose(
+            [
+                RandomHFlip(prob=1),
+                RandomHFlip(prob=1),
+                RandomVFlip(prob=1),
+                RandomVFlip(prob=1),
+            ]
+        )
         img, gt, mask = [create_img() for i in range(3)]
         img_t, gt_t, mask_t = transforms(img, gt, mask)
         self.assertTrue(np.all(np.array(img_t) == np.array(img)))
@@ -41,9 +41,10 @@ class Tester(unittest.TestCase):
         transforms = ToTensor()
         img, gt, mask = [create_img() for i in range(3)]
         img_t, gt_t, mask_t = transforms(img, gt, mask)
-        self.assertEqual(str(img_t.dtype),"torch.float32")
-        self.assertEqual(str(gt_t.dtype),"torch.float32")
-        self.assertEqual(str(mask_t.dtype),"torch.float32")
+        self.assertEqual(str(img_t.dtype), "torch.float32")
+        self.assertEqual(str(gt_t.dtype), "torch.float32")
+        self.assertEqual(str(mask_t.dtype), "torch.float32")
+
 
-if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+if __name__ == "__main__":
+    unittest.main()
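
# The flip test above relies on an involution: applying the same flip twice returns the
# original image, so two horizontal plus two vertical flips are a no-op. A small sketch of
# that property with torchvision's functional API (independent of the package's own Compose,
# which operates on image/ground-truth/mask triples):
import numpy as np
import torch
import torchvision.transforms.functional as VF

img = VF.to_pil_image(torch.rand(3, 42, 24))
round_trip = VF.hflip(VF.hflip(VF.vflip(VF.vflip(img))))
assert np.array_equal(np.array(img), np.array(round_trip))  # pixels are unchanged
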
diff --git a/bob/ip/binseg/utils/checkpointer.py b/bob/ip/binseg/utils/checkpointer.py
index 1a79d908..4b375e9f 100644
--- a/bob/ip/binseg/utils/checkpointer.py
+++ b/bob/ip/binseg/utils/checkpointer.py
@@ -59,9 +59,7 @@ class Checkpointer:
             f = self.get_checkpoint_file()
         if not f:
             # no checkpoint could be found
-            self.logger.warn(
-                "No checkpoint found. Initializing model from scratch"
-            )
+            self.logger.warn("No checkpoint found. Initializing model from scratch")
             return {}
         self.logger.info("Loading checkpoint from {}".format(f))
         checkpoint = self._load_file(f)
diff --git a/bob/ip/binseg/utils/click.py b/bob/ip/binseg/utils/click.py
index 03fd5d30..792cebfd 100644
--- a/bob/ip/binseg/utils/click.py
+++ b/bob/ip/binseg/utils/click.py
@@ -3,6 +3,7 @@
 
 import click
 
+
 class OptionEatAll(click.Option):
     """
     Allows for ``*args`` and ``**kwargs`` to be passed to click
@@ -11,15 +12,14 @@ class OptionEatAll(click.Option):
     """
 
     def __init__(self, *args, **kwargs):
-        self.save_other_options = kwargs.pop('save_other_options', True)
-        nargs = kwargs.pop('nargs', -1)
-        assert nargs == -1, 'nargs, if set, must be -1 not {}'.format(nargs)
+        self.save_other_options = kwargs.pop("save_other_options", True)
+        nargs = kwargs.pop("nargs", -1)
+        assert nargs == -1, "nargs, if set, must be -1 not {}".format(nargs)
         super(OptionEatAll, self).__init__(*args, **kwargs)
         self._previous_parser_process = None
         self._eat_all_parser = None
 
     def add_to_parser(self, parser, ctx):
-
         def parser_process(value, state):
             # method to hook to the parser.process
             done = False
diff --git a/bob/ip/binseg/utils/evaluate.py b/bob/ip/binseg/utils/evaluate.py
index 99259f41..d891a1ce 100644
--- a/bob/ip/binseg/utils/evaluate.py
+++ b/bob/ip/binseg/utils/evaluate.py
@@ -13,7 +13,10 @@ import torchvision.transforms.functional as VF
 from tqdm import tqdm
 
 from bob.ip.binseg.utils.metric import SmoothedValue, base_metrics
-from bob.ip.binseg.utils.plot import precision_recall_f1iso, precision_recall_f1iso_confintval
+from bob.ip.binseg.utils.plot import (
+    precision_recall_f1iso,
+    precision_recall_f1iso_confintval,
+)
 from bob.ip.binseg.utils.summary import summary
 from PIL import Image
 from torchvision.transforms.functional import to_tensor
@@ -51,21 +54,23 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
         file_name = "{}.csv".format(names[j])
         logger.info("saving {}".format(file_name))
 
-        with open (os.path.join(output_folder,file_name), "w+") as outfile:
+        with open(os.path.join(output_folder, file_name), "w+") as outfile:
 
-            outfile.write("threshold, precision, recall, specificity, accuracy, jaccard, f1_score\n")
+            outfile.write(
+                "threshold, precision, recall, specificity, accuracy, jaccard, f1_score\n"
+            )
 
-            for threshold in np.arange(0.0,1.0,step_size):
+            for threshold in np.arange(0.0, 1.0, step_size):
                 # threshold
                 binary_pred = torch.gt(predictions[j], threshold).byte()
 
                 # equals and not-equals
-                equals = torch.eq(binary_pred, gts) # tensor
-                notequals = torch.ne(binary_pred, gts) # tensor
+                equals = torch.eq(binary_pred, gts)  # tensor
+                notequals = torch.ne(binary_pred, gts)  # tensor
 
                 # true positives
-                tp_tensor = (gts * binary_pred ) # tensor
-                tp_count = torch.sum(tp_tensor).item() # scalar
+                tp_tensor = gts * binary_pred  # tensor
+                tp_count = torch.sum(tp_tensor).item()  # scalar
 
                 # false positives
                 fp_tensor = torch.eq((binary_pred + tp_tensor), 1)
@@ -83,22 +88,24 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
                 metrics = base_metrics(tp_count, fp_count, tn_count, fn_count)
 
                 # write to disk
-                outfile.write("{:.2f},{:.5f},{:.5f},{:.5f},{:.5f},{:.5f},{:.5f} \n".format(threshold, *metrics))
-
-                batch_metrics.append([names[j],threshold, *metrics ])
+                outfile.write(
+                    "{:.2f},{:.5f},{:.5f},{:.5f},{:.5f},{:.5f},{:.5f} \n".format(
+                        threshold, *metrics
+                    )
+                )
 
+                batch_metrics.append([names[j], threshold, *metrics])
 
     return batch_metrics
 
 
-
 def do_eval(
     prediction_folder,
     data_loader,
-    output_folder = None,
-    title = '2nd human',
-    legend = '2nd human',
-    prediction_extension = None,
+    output_folder=None,
+    title="2nd human",
+    legend="2nd human",
+    prediction_extension=None,
 ):
 
     """
@@ -116,9 +123,8 @@ def do_eval(
     logger = logging.getLogger("bob.ip.binseg.engine.evaluate")
     logger.info("Start evaluation")
     logger.info("Prediction folder {}".format(prediction_folder))
-    results_subfolder = os.path.join(output_folder,'results')
-    os.makedirs(results_subfolder,exist_ok=True)
-
+    results_subfolder = os.path.join(output_folder, "results")
+    os.makedirs(results_subfolder, exist_ok=True)
 
     # Collect overall metrics
     metrics = []
@@ -129,66 +135,80 @@ def do_eval(
         ground_truths = samples[2]
 
         if prediction_extension is None:
-            pred_file = os.path.join(prediction_folder,names[0])
+            pred_file = os.path.join(prediction_folder, names[0])
         else:
-            pred_file = os.path.join(prediction_folder,os.path.splitext(names[0])[0] + '.png')
+            pred_file = os.path.join(
+                prediction_folder, os.path.splitext(names[0])[0] + ".png"
+            )
         probabilities = Image.open(pred_file)
-        probabilities = probabilities.convert(mode='L')
+        probabilities = probabilities.convert(mode="L")
         probabilities = to_tensor(probabilities)
 
-
-        b_metrics = batch_metrics(probabilities, ground_truths, names,results_subfolder, logger)
+        b_metrics = batch_metrics(
+            probabilities, ground_truths, names, results_subfolder, logger
+        )
         metrics.extend(b_metrics)
 
-
-
     # DataFrame
-    df_metrics = pd.DataFrame(metrics,columns= \
-                           ["name",
-                            "threshold",
-                            "precision",
-                            "recall",
-                            "specificity",
-                            "accuracy",
-                            "jaccard",
-                            "f1_score"])
+    df_metrics = pd.DataFrame(
+        metrics,
+        columns=[
+            "name",
+            "threshold",
+            "precision",
+            "recall",
+            "specificity",
+            "accuracy",
+            "jaccard",
+            "f1_score",
+        ],
+    )
 
     # Report and Averages
     metrics_file = "Metrics.csv"
     metrics_path = os.path.join(results_subfolder, metrics_file)
     logger.info("Saving average over all input images: {}".format(metrics_file))
 
-    avg_metrics = df_metrics.groupby('threshold').mean()
-    std_metrics = df_metrics.groupby('threshold').std()
+    avg_metrics = df_metrics.groupby("threshold").mean()
+    std_metrics = df_metrics.groupby("threshold").std()
 
     # Uncomment below for F1-score calculation based on average precision and recall instead of
     # F1-scores of individual images. This method is in line with Maninis et al. (2016)
-    #avg_metrics["f1_score"] =  (2* avg_metrics["precision"]*avg_metrics["recall"])/ \
+    # avg_metrics["f1_score"] =  (2* avg_metrics["precision"]*avg_metrics["recall"])/ \
     #    (avg_metrics["precision"]+avg_metrics["recall"])
 
-
     avg_metrics["std_pr"] = std_metrics["precision"]
-    avg_metrics["pr_upper"] = avg_metrics['precision'] + avg_metrics["std_pr"]
-    avg_metrics["pr_lower"] = avg_metrics['precision'] - avg_metrics["std_pr"]
+    avg_metrics["pr_upper"] = avg_metrics["precision"] + avg_metrics["std_pr"]
+    avg_metrics["pr_lower"] = avg_metrics["precision"] - avg_metrics["std_pr"]
     avg_metrics["std_re"] = std_metrics["recall"]
-    avg_metrics["re_upper"] = avg_metrics['recall'] + avg_metrics["std_re"]
-    avg_metrics["re_lower"] = avg_metrics['recall'] - avg_metrics["std_re"]
+    avg_metrics["re_upper"] = avg_metrics["recall"] + avg_metrics["std_re"]
+    avg_metrics["re_lower"] = avg_metrics["recall"] - avg_metrics["std_re"]
     avg_metrics["std_f1"] = std_metrics["f1_score"]
 
     avg_metrics.to_csv(metrics_path)
-    maxf1 = avg_metrics['f1_score'].max()
-    optimal_f1_threshold = avg_metrics['f1_score'].idxmax()
+    maxf1 = avg_metrics["f1_score"].max()
+    optimal_f1_threshold = avg_metrics["f1_score"].idxmax()
 
-    logger.info("Highest F1-score of {:.5f}, achieved at threshold {}".format(maxf1, optimal_f1_threshold))
+    logger.info(
+        "Highest F1-score of {:.5f}, achieved at threshold {}".format(
+            maxf1, optimal_f1_threshold
+        )
+    )
 
     # Plotting
-    #print(avg_metrics)
+    # print(avg_metrics)
     np_avg_metrics = avg_metrics.to_numpy().T
     fig_name = "precision_recall.pdf"
     logger.info("saving {}".format(fig_name))
-    fig = precision_recall_f1iso_confintval([np_avg_metrics[0]],[np_avg_metrics[1]],[np_avg_metrics[7]],[np_avg_metrics[8]],[np_avg_metrics[10]],[np_avg_metrics[11]], [legend ,None], title=title)
+    fig = precision_recall_f1iso_confintval(
+        [np_avg_metrics[0]],
+        [np_avg_metrics[1]],
+        [np_avg_metrics[7]],
+        [np_avg_metrics[8]],
+        [np_avg_metrics[10]],
+        [np_avg_metrics[11]],
+        [legend, None],
+        title=title,
+    )
     fig_filename = os.path.join(results_subfolder, fig_name)
     fig.savefig(fig_filename)
-
-
-
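
# A toy walk-through of the pixel-counting scheme batch_metrics applies at one threshold.
# The tp/fp lines mirror the tensor operations in the hunk above; the tn/fn lines are one
# straightforward way to obtain the remaining counts (those exact lines fall outside this
# hunk), and all values are made up:
import torch

probabilities = torch.tensor([[0.9, 0.2, 0.7], [0.1, 0.8, 0.4], [0.6, 0.3, 0.1]])
gts = torch.tensor([[1, 0, 1], [0, 1, 1], [0, 0, 0]], dtype=torch.uint8)

binary_pred = torch.gt(probabilities, 0.5).byte()            # threshold the probability map
tp_tensor = gts * binary_pred                                # foreground predicted and labelled
tp_count = torch.sum(tp_tensor).item()                       # 3
fp_tensor = torch.eq(binary_pred + tp_tensor, 1)             # predicted foreground, labelled background
fp_count = torch.sum(fp_tensor).item()                       # 1
tn_count = torch.sum((1 - gts) * (1 - binary_pred)).item()   # 4
fn_count = torch.sum(gts * (1 - binary_pred)).item()         # 1
assert tp_count + fp_count + tn_count + fn_count == gts.numel()
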
diff --git a/bob/ip/binseg/utils/metric.py b/bob/ip/binseg/utils/metric.py
index bcb91511..471cac63 100644
--- a/bob/ip/binseg/utils/metric.py
+++ b/bob/ip/binseg/utils/metric.py
@@ -27,6 +27,7 @@ class SmoothedValue:
         d = torch.tensor(list(self.deque))
         return d.mean().item()
 
+
 def base_metrics(tp, fp, tn, fn):
     """
     Calculates Precision, Recall (=Sensitivity), Specificity, Accuracy, Jaccard and F1-score (Dice)
@@ -54,11 +55,11 @@ def base_metrics(tp, fp, tn, fn):
     metrics : list
     
     """
-    precision = tp / (tp + fp + ( (tp+fp) == 0) )
-    recall = tp / (tp + fn + ( (tp+fn) == 0) )
-    specificity = tn / (fp + tn + ( (fp+tn) == 0) )
-    accuracy = (tp + tn) / (tp+fp+fn+tn)
-    jaccard = tp / (tp+fp+fn + ( (tp+fp+fn) == 0) )
-    f1_score = (2.0 * tp ) / (2.0 * tp + fp + fn + ( (2.0 * tp + fp + fn) == 0) )
-    #f1_score = (2.0 * precision * recall) / (precision + recall)
-    return [precision, recall, specificity, accuracy, jaccard, f1_score]
\ No newline at end of file
+    precision = tp / (tp + fp + ((tp + fp) == 0))
+    recall = tp / (tp + fn + ((tp + fn) == 0))
+    specificity = tn / (fp + tn + ((fp + tn) == 0))
+    accuracy = (tp + tn) / (tp + fp + fn + tn)
+    jaccard = tp / (tp + fp + fn + ((tp + fp + fn) == 0))
+    f1_score = (2.0 * tp) / (2.0 * tp + fp + fn + ((2.0 * tp + fp + fn) == 0))
+    # f1_score = (2.0 * precision * recall) / (precision + recall)
+    return [precision, recall, specificity, accuracy, jaccard, f1_score]
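
# The denominators above add the boolean ((x) == 0), which evaluates to 1 exactly when the
# denominator would otherwise be zero, so 0/0 becomes 0/1 = 0 instead of raising
# ZeroDivisionError. A sketch of the guard in isolation (safe_ratio is a hypothetical helper,
# not part of the package):
def safe_ratio(numerator, denominator):
    return numerator / (denominator + (denominator == 0))

print(safe_ratio(5, 10))  # 0.5 -- the guard term is False (0) and changes nothing
print(safe_ratio(0, 0))   # 0.0 -- the guard term is True (1) and avoids the division by zero
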
diff --git a/bob/ip/binseg/utils/model_serialization.py b/bob/ip/binseg/utils/model_serialization.py
index 84ff2491..016f085e 100644
--- a/bob/ip/binseg/utils/model_serialization.py
+++ b/bob/ip/binseg/utils/model_serialization.py
@@ -5,6 +5,7 @@ import logging
 
 import torch
 
+
 def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
     """
     Strategy: suppose that the models that we will create will have prefixes appended
@@ -75,4 +76,4 @@ def load_state_dict(model, loaded_state_dict):
     align_and_update_state_dicts(model_state_dict, loaded_state_dict)
 
     # use strict loading
-    model.load_state_dict(model_state_dict)
\ No newline at end of file
+    model.load_state_dict(model_state_dict)
diff --git a/bob/ip/binseg/utils/model_zoo.py b/bob/ip/binseg/utils/model_zoo.py
index 00c7f7c5..18052744 100644
--- a/bob/ip/binseg/utils/model_zoo.py
+++ b/bob/ip/binseg/utils/model_zoo.py
@@ -35,13 +35,14 @@ modelurls = {
     "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
     "resnet50_SIN_IN": "https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar",
     "mobilenetv2": "https://dl.dropboxusercontent.com/s/4nie4ygivq04p8y/mobilenet_v2.pth.tar",
-    }
+}
+
 
 def _download_url_to_file(url, dst, hash_prefix, progress):
     file_size = None
     u = urlopen(url)
     meta = u.info()
-    if hasattr(meta, 'getheaders'):
+    if hasattr(meta, "getheaders"):
         content_length = meta.getheaders("Content-Length")
     else:
         content_length = meta.get_all("Content-Length")
@@ -65,16 +66,21 @@ def _download_url_to_file(url, dst, hash_prefix, progress):
         f.close()
         if hash_prefix is not None:
             digest = sha256.hexdigest()
-            if digest[:len(hash_prefix)] != hash_prefix:
-                raise RuntimeError('invalid hash value (expected "{}", got "{}")'
-                                   .format(hash_prefix, digest))
+            if digest[: len(hash_prefix)] != hash_prefix:
+                raise RuntimeError(
+                    'invalid hash value (expected "{}", got "{}")'.format(
+                        hash_prefix, digest
+                    )
+                )
         shutil.move(f.name, dst)
     finally:
         f.close()
         if os.path.exists(f.name):
             os.remove(f.name)
 
-HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
+
+HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
+
 
 def cache_url(url, model_dir=None, progress=True):
     r"""Loads the Torch serialized object at the given URL.
diff --git a/bob/ip/binseg/utils/plot.py b/bob/ip/binseg/utils/plot.py
index de1d531d..ba85ba3d 100644
--- a/bob/ip/binseg/utils/plot.py
+++ b/bob/ip/binseg/utils/plot.py
@@ -6,10 +6,11 @@ import os
 import csv
 import pandas as pd
 import PIL
-from PIL import Image,ImageFont, ImageDraw
+from PIL import Image, ImageFont, ImageDraw
 import torchvision.transforms.functional as VF
 import torch
 
+
 def precision_recall_f1iso(precision, recall, names, title=None):
     """
     Author: Andre Anjos (andre.anjos@idiap.ch).
@@ -40,11 +41,13 @@ def precision_recall_f1iso(precision, recall, names, title=None):
         A matplotlib figure you can save or display
     """
     import matplotlib
-    matplotlib.use('agg')
+
+    matplotlib.use("agg")
     import matplotlib.pyplot as plt
     from itertools import cycle
+
     fig, ax1 = plt.subplots(1)
-    lines = ["-","--","-.",":"]
+    lines = ["-", "--", "-.", ":"]
     linecycler = cycle(lines)
     for p, r, n in zip(precision, recall, names):
         # Plots only from the point where recall reaches its maximum, otherwise, we
@@ -52,23 +55,29 @@ def precision_recall_f1iso(precision, recall, names, title=None):
         i = r.argmax()
         pi = p[i:]
         ri = r[i:]
-        valid = (pi+ri) > 0
-        f1 = 2 * (pi[valid]*ri[valid]) / (pi[valid]+ri[valid])
+        valid = (pi + ri) > 0
+        f1 = 2 * (pi[valid] * ri[valid]) / (pi[valid] + ri[valid])
         # optimal point along the curve
         argmax = f1.argmax()
         opi = pi[argmax]
         ori = ri[argmax]
         # Plot Recall/Precision as threshold changes
-        ax1.plot(ri[pi>0], pi[pi>0], next(linecycler), label='[F={:.4f}] {}'.format(f1.max(), n),)
-        ax1.plot(ori,opi, marker='o', linestyle=None, markersize=3, color='black')
-    ax1.grid(linestyle='--', linewidth=1, color='gray', alpha=0.2)
+        ax1.plot(
+            ri[pi > 0],
+            pi[pi > 0],
+            next(linecycler),
+            label="[F={:.4f}] {}".format(f1.max(), n),
+        )
+        ax1.plot(ori, opi, marker="o", linestyle=None, markersize=3, color="black")
+    ax1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
     if len(names) > 1:
-        plt.legend(loc='lower left', framealpha=0.5)
-    ax1.set_xlabel('Recall')
-    ax1.set_ylabel('Precision')
+        plt.legend(loc="lower left", framealpha=0.5)
+    ax1.set_xlabel("Recall")
+    ax1.set_ylabel("Precision")
     ax1.set_xlim([0.0, 1.0])
     ax1.set_ylim([0.0, 1.0])
-    if title is not None: ax1.set_title(title)
+    if title is not None:
+        ax1.set_title(title)
     # Annotates plot with F1-score iso-lines
     ax2 = ax1.twinx()
     f_scores = np.linspace(0.1, 0.9, num=9)
@@ -77,32 +86,35 @@ def precision_recall_f1iso(precision, recall, names, title=None):
     for f_score in f_scores:
         x = np.linspace(0.01, 1)
         y = f_score * x / (2 * x - f_score)
-        l, = plt.plot(x[y >= 0], y[y >= 0], color='green', alpha=0.1)
+        (l,) = plt.plot(x[y >= 0], y[y >= 0], color="green", alpha=0.1)
         tick_locs.append(y[-1])
-        tick_labels.append('%.1f' % f_score)
-    ax2.tick_params(axis='y', which='both', pad=0, right=False, left=False)
-    ax2.set_ylabel('iso-F', color='green', alpha=0.3)
+        tick_labels.append("%.1f" % f_score)
+    ax2.tick_params(axis="y", which="both", pad=0, right=False, left=False)
+    ax2.set_ylabel("iso-F", color="green", alpha=0.3)
     ax2.set_ylim([0.0, 1.0])
     ax2.yaxis.set_label_coords(1.015, 0.97)
-    ax2.set_yticks(tick_locs) #notice these are invisible
+    ax2.set_yticks(tick_locs)  # notice these are invisible
     for k in ax2.set_yticklabels(tick_labels):
-        k.set_color('green')
+        k.set_color("green")
         k.set_alpha(0.3)
         k.set_size(8)
     # we should see some of axes 1's spines
-    ax1.spines['right'].set_visible(False)
-    ax1.spines['top'].set_visible(False)
-    ax1.spines['left'].set_position(('data', -0.015))
-    ax1.spines['bottom'].set_position(('data', -0.015))
+    ax1.spines["right"].set_visible(False)
+    ax1.spines["top"].set_visible(False)
+    ax1.spines["left"].set_position(("data", -0.015))
+    ax1.spines["bottom"].set_position(("data", -0.015))
     # we shouldn't see any of axes 2's spines
-    ax2.spines['right'].set_visible(False)
-    ax2.spines['top'].set_visible(False)
-    ax2.spines['left'].set_visible(False)
-    ax2.spines['bottom'].set_visible(False)
+    ax2.spines["right"].set_visible(False)
+    ax2.spines["top"].set_visible(False)
+    ax2.spines["left"].set_visible(False)
+    ax2.spines["bottom"].set_visible(False)
     plt.tight_layout()
     return fig
 
-def precision_recall_f1iso_confintval(precision, recall, pr_upper, pr_lower, re_upper, re_lower, names, title=None):
+
+def precision_recall_f1iso_confintval(
+    precision, recall, pr_upper, pr_lower, re_upper, re_lower, names, title=None
+):
     """
     Author: Andre Anjos (andre.anjos@idiap.ch).
 
@@ -132,17 +144,30 @@ def precision_recall_f1iso_confintval(precision, recall, pr_upper, pr_lower, re_
         A matplotlib figure you can save or display
     """
     import matplotlib
-    matplotlib.use('agg')
+
+    matplotlib.use("agg")
     import matplotlib.pyplot as plt
     from itertools import cycle
+
     fig, ax1 = plt.subplots(1)
-    lines = ["-","--","-.",":"]
-    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
-              '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
-              '#bcbd22', '#17becf']
+    lines = ["-", "--", "-.", ":"]
+    colors = [
+        "#1f77b4",
+        "#ff7f0e",
+        "#2ca02c",
+        "#d62728",
+        "#9467bd",
+        "#8c564b",
+        "#e377c2",
+        "#7f7f7f",
+        "#bcbd22",
+        "#17becf",
+    ]
     colorcycler = cycle(colors)
     linecycler = cycle(lines)
-    for p, r, pu, pl, ru, rl, n in zip(precision, recall, pr_upper, pr_lower, re_upper, re_lower, names):
+    for p, r, pu, pl, ru, rl, n in zip(
+        precision, recall, pr_upper, pr_lower, re_upper, re_lower, names
+    ):
         # Plots only from the point where recall reaches its maximum, otherwise, we
         # don't see a curve...
         i = r.argmax()
@@ -152,39 +177,57 @@ def precision_recall_f1iso_confintval(precision, recall, pr_upper, pr_lower, re_
         pli = pl[i:]
         rui = ru[i:]
         rli = rl[i:]
-        valid = (pi+ri) > 0
-        f1 = 2 * (pi[valid]*ri[valid]) / (pi[valid]+ri[valid])
+        valid = (pi + ri) > 0
+        f1 = 2 * (pi[valid] * ri[valid]) / (pi[valid] + ri[valid])
         # optimal point along the curve
         argmax = f1.argmax()
         opi = pi[argmax]
         ori = ri[argmax]
         # Plot Recall/Precision as threshold changes
-        ax1.plot(ri[pi>0], pi[pi>0], next(linecycler), label='[F={:.4f}] {}'.format(f1.max(), n),)
-        ax1.plot(ori,opi, marker='o', linestyle=None, markersize=3, color='black')
+        ax1.plot(
+            ri[pi > 0],
+            pi[pi > 0],
+            next(linecycler),
+            label="[F={:.4f}] {}".format(f1.max(), n),
+        )
+        ax1.plot(ori, opi, marker="o", linestyle=None, markersize=3, color="black")
         # Plot confidence
         # Upper bound
-        #ax1.plot(r95ui[p95ui>0], p95ui[p95ui>0])
+        # ax1.plot(r95ui[p95ui>0], p95ui[p95ui>0])
         # Lower bound
-        #ax1.plot(r95li[p95li>0], p95li[p95li>0])
+        # ax1.plot(r95li[p95li>0], p95li[p95li>0])
         # create the limiting polygon
-        vert_x = np.concatenate((rui[pui>0], rli[pli>0][::-1]))
-        vert_y = np.concatenate((pui[pui>0], pli[pli>0][::-1]))
+        vert_x = np.concatenate((rui[pui > 0], rli[pli > 0][::-1]))
+        vert_y = np.concatenate((pui[pui > 0], pli[pli > 0][::-1]))
         # hacky workaround to plot 2nd human
         if np.isclose(np.mean(rui), rui[1], rtol=1e-05):
-            print('found human')
-            p = plt.Polygon(np.column_stack((vert_x, vert_y)), facecolor='none', alpha=.2, edgecolor=next(colorcycler),lw=2)
+            print("found human")
+            p = plt.Polygon(
+                np.column_stack((vert_x, vert_y)),
+                facecolor="none",
+                alpha=0.2,
+                edgecolor=next(colorcycler),
+                lw=2,
+            )
         else:
-            p = plt.Polygon(np.column_stack((vert_x, vert_y)), facecolor=next(colorcycler), alpha=.2, edgecolor='none',lw=.2)
+            p = plt.Polygon(
+                np.column_stack((vert_x, vert_y)),
+                facecolor=next(colorcycler),
+                alpha=0.2,
+                edgecolor="none",
+                lw=0.2,
+            )
         ax1.add_artist(p)
 
-    ax1.grid(linestyle='--', linewidth=1, color='gray', alpha=0.2)
+    ax1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
     if len(names) > 1:
-        plt.legend(loc='lower left', framealpha=0.5)
-    ax1.set_xlabel('Recall')
-    ax1.set_ylabel('Precision')
+        plt.legend(loc="lower left", framealpha=0.5)
+    ax1.set_xlabel("Recall")
+    ax1.set_ylabel("Precision")
     ax1.set_xlim([0.0, 1.0])
     ax1.set_ylim([0.0, 1.0])
-    if title is not None: ax1.set_title(title)
+    if title is not None:
+        ax1.set_title(title)
     # Annotates plot with F1-score iso-lines
     ax2 = ax1.twinx()
     f_scores = np.linspace(0.1, 0.9, num=9)
@@ -193,31 +236,32 @@ def precision_recall_f1iso_confintval(precision, recall, pr_upper, pr_lower, re_
     for f_score in f_scores:
         x = np.linspace(0.01, 1)
         y = f_score * x / (2 * x - f_score)
-        l, = plt.plot(x[y >= 0], y[y >= 0], color='green', alpha=0.1)
+        (l,) = plt.plot(x[y >= 0], y[y >= 0], color="green", alpha=0.1)
         tick_locs.append(y[-1])
-        tick_labels.append('%.1f' % f_score)
-    ax2.tick_params(axis='y', which='both', pad=0, right=False, left=False)
-    ax2.set_ylabel('iso-F', color='green', alpha=0.3)
+        tick_labels.append("%.1f" % f_score)
+    ax2.tick_params(axis="y", which="both", pad=0, right=False, left=False)
+    ax2.set_ylabel("iso-F", color="green", alpha=0.3)
     ax2.set_ylim([0.0, 1.0])
     ax2.yaxis.set_label_coords(1.015, 0.97)
-    ax2.set_yticks(tick_locs) #notice these are invisible
+    ax2.set_yticks(tick_locs)  # notice these are invisible
     for k in ax2.set_yticklabels(tick_labels):
-        k.set_color('green')
+        k.set_color("green")
         k.set_alpha(0.3)
         k.set_size(8)
     # we should see some of axes 1's spines
-    ax1.spines['right'].set_visible(False)
-    ax1.spines['top'].set_visible(False)
-    ax1.spines['left'].set_position(('data', -0.015))
-    ax1.spines['bottom'].set_position(('data', -0.015))
+    ax1.spines["right"].set_visible(False)
+    ax1.spines["top"].set_visible(False)
+    ax1.spines["left"].set_position(("data", -0.015))
+    ax1.spines["bottom"].set_position(("data", -0.015))
     # we shouldn't see any of axes 2's spines
-    ax2.spines['right'].set_visible(False)
-    ax2.spines['top'].set_visible(False)
-    ax2.spines['left'].set_visible(False)
-    ax2.spines['bottom'].set_visible(False)
+    ax2.spines["right"].set_visible(False)
+    ax2.spines["top"].set_visible(False)
+    ax2.spines["left"].set_visible(False)
+    ax2.spines["bottom"].set_visible(False)
     plt.tight_layout()
     return fig
 
+
 def loss_curve(df, title):
     """ Creates a loss curve given a Dataframe with column names:
 
@@ -232,15 +276,17 @@ def loss_curve(df, title):
     matplotlib.figure.Figure
     """
     import matplotlib
-    matplotlib.use('agg')
+
+    matplotlib.use("agg")
     import matplotlib.pyplot as plt
+
     ax1 = df.plot(y="median loss", grid=True)
     ax1.set_title(title)
-    ax1.set_ylabel('median loss')
-    ax1.grid(linestyle='--', linewidth=1, color='gray', alpha=0.2)
-    ax2 = df['lr'].plot(secondary_y=True,legend=True,grid=True,)
-    ax2.set_ylabel('lr')
-    ax1.set_xlabel('epoch')
+    ax1.set_ylabel("median loss")
+    ax1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
+    ax2 = df["lr"].plot(secondary_y=True, legend=True, grid=True,)
+    ax2.set_ylabel("lr")
+    ax1.set_xlabel("epoch")
     plt.tight_layout()
     fig = ax1.get_figure()
     return fig
@@ -260,7 +306,7 @@ def read_metricscsv(file):
     :py:class:`numpy.ndarray`
     :py:class:`numpy.ndarray`
     """
-    with open (file, "r") as infile:
+    with open(file, "r") as infile:
         metricsreader = csv.reader(infile)
         # skip header row
         next(metricsreader)
@@ -277,10 +323,17 @@ def read_metricscsv(file):
             pr_lower.append(float(row[9]))
             re_upper.append(float(row[11]))
             re_lower.append(float(row[12]))
-    return np.array(precision), np.array(recall), np.array(pr_upper), np.array(pr_lower), np.array(re_upper), np.array(re_lower)
+    return (
+        np.array(precision),
+        np.array(recall),
+        np.array(pr_upper),
+        np.array(pr_lower),
+        np.array(re_upper),
+        np.array(re_lower),
+    )
 
 
-def plot_overview(outputfolders,title):
+def plot_overview(outputfolders, title):
     """
     Plots comparison chart of all trained models
 
@@ -304,7 +357,7 @@ def plot_overview(outputfolders,title):
     params = []
     for folder in outputfolders:
         # metrics
-        metrics_path = os.path.join(folder,'results/Metrics.csv')
+        metrics_path = os.path.join(folder, "results/Metrics.csv")
         pr, re, pr_upper, pr_lower, re_upper, re_lower = read_metricscsv(metrics_path)
         precisions.append(pr)
         recalls.append(re)
@@ -312,19 +365,24 @@ def plot_overview(outputfolders,title):
         pr_lows.append(pr_lower)
         re_ups.append(re_upper)
         re_lows.append(re_lower)
-        modelname = folder.split('/')[-1]
-        name = '{} '.format(modelname)
+        modelname = folder.split("/")[-1]
+        name = "{} ".format(modelname)
         names.append(name)
-    #title = folder.split('/')[-4]
-    fig = precision_recall_f1iso_confintval(precisions,recalls, pr_ups, pr_lows, re_ups, re_lows, names,title)
+    # title = folder.split('/')[-4]
+    fig = precision_recall_f1iso_confintval(
+        precisions, recalls, pr_ups, pr_lows, re_ups, re_lows, names, title
+    )
     return fig
 
-def metricsviz(dataset
-                ,output_path
-                ,tp_color= (0,255,0) # (128,128,128) Gray
-                ,fp_color = (0, 0, 255) # (70, 240, 240) Cyan
-                ,fn_color = (255, 0, 0) # (245, 130, 48) Orange
-                ,overlayed=True):
+
+def metricsviz(
+    dataset,
+    output_path,
+    tp_color=(0, 255, 0),  # (128,128,128) Gray
+    fp_color=(0, 0, 255),  # (70, 240, 240) Cyan
+    fn_color=(255, 0, 0),  # (245, 130, 48) Orange
+    overlayed=True,
+):
     """ Visualizes true positives, false positives and false negatives
     Default colors TP: Green, FP: Blue, FN: Red (the Gray/Cyan/Orange alternatives are noted in the signature comments)
 
@@ -343,56 +401,59 @@ def metricsviz(dataset
 
     for sample in dataset:
         # get sample
-        name  = sample[0]
-        img = VF.to_pil_image(sample[1]) # PIL Image
-        gt = sample[2].byte() # byte tensor
+        name = sample[0]
+        img = VF.to_pil_image(sample[1])  # PIL Image
+        gt = sample[2].byte()  # byte tensor
 
         # read metrics
-        metrics = pd.read_csv(os.path.join(output_path,'results','Metrics.csv'))
-        optimal_threshold = metrics['threshold'][metrics['f1_score'].idxmax()]
+        metrics = pd.read_csv(os.path.join(output_path, "results", "Metrics.csv"))
+        optimal_threshold = metrics["threshold"][metrics["f1_score"].idxmax()]
 
         # read probability output
-        pred = Image.open(os.path.join(output_path,'images',name))
-        pred = pred.convert(mode='L')
+        pred = Image.open(os.path.join(output_path, "images", name))
+        pred = pred.convert(mode="L")
         pred = VF.to_tensor(pred)
         binary_pred = torch.gt(pred, optimal_threshold).byte()
 
         # calc metrics
         # equals and not-equals
-        equals = torch.eq(binary_pred, gt) # tensor
-        notequals = torch.ne(binary_pred, gt) # tensor
+        equals = torch.eq(binary_pred, gt)  # tensor
+        notequals = torch.ne(binary_pred, gt)  # tensor
         # true positives
-        tp_tensor = (gt * binary_pred ) # tensor
+        tp_tensor = gt * binary_pred  # tensor
         tp_pil = VF.to_pil_image(tp_tensor.float())
-        tp_pil_colored = PIL.ImageOps.colorize(tp_pil, (0,0,0), tp_color)
+        tp_pil_colored = PIL.ImageOps.colorize(tp_pil, (0, 0, 0), tp_color)
         # false positives
         fp_tensor = torch.eq((binary_pred + tp_tensor), 1)
         fp_pil = VF.to_pil_image(fp_tensor.float())
-        fp_pil_colored = PIL.ImageOps.colorize(fp_pil, (0,0,0), fp_color)
+        fp_pil_colored = PIL.ImageOps.colorize(fp_pil, (0, 0, 0), fp_color)
         # false negatives
         fn_tensor = notequals - fp_tensor
         fn_pil = VF.to_pil_image(fn_tensor.float())
-        fn_pil_colored = PIL.ImageOps.colorize(fn_pil, (0,0,0), fn_color)
+        fn_pil_colored = PIL.ImageOps.colorize(fn_pil, (0, 0, 0), fn_color)
 
         # paste together
-        tp_pil_colored.paste(fp_pil_colored,mask=fp_pil)
-        tp_pil_colored.paste(fn_pil_colored,mask=fn_pil)
+        tp_pil_colored.paste(fp_pil_colored, mask=fp_pil)
+        tp_pil_colored.paste(fn_pil_colored, mask=fn_pil)
 
         if overlayed:
             tp_pil_colored = PIL.Image.blend(img, tp_pil_colored, 0.4)
-            img_metrics = pd.read_csv(os.path.join(output_path,'results',name+'.csv'))
-            f1 = img_metrics[' f1_score'].max()
+            img_metrics = pd.read_csv(
+                os.path.join(output_path, "results", name + ".csv")
+            )
+            f1 = img_metrics[" f1_score"].max()
             # add f1-score
-            fnt_size = tp_pil_colored.size[1]//25
+            fnt_size = tp_pil_colored.size[1] // 25
             draw = ImageDraw.Draw(tp_pil_colored)
-            fnt = ImageFont.truetype('FreeMono.ttf', fnt_size)
-            draw.text((0, 0),"F1: {:.4f}".format(f1),(255,255,255),font=fnt)
+            fnt = ImageFont.truetype("FreeMono.ttf", fnt_size)
+            draw.text((0, 0), "F1: {:.4f}".format(f1), (255, 255, 255), font=fnt)
 
         # save to disk
-        overlayed_path = os.path.join(output_path,'tpfnfpviz')
+        overlayed_path = os.path.join(output_path, "tpfnfpviz")
         fullpath = os.path.join(overlayed_path, name)
         fulldir = os.path.dirname(fullpath)
-        if not os.path.exists(fulldir): os.makedirs(fulldir)
+        if not os.path.exists(fulldir):
+            os.makedirs(fulldir)
         tp_pil_colored.save(fullpath)
 
 
@@ -408,25 +469,26 @@ def overlay(dataset, output_path):
 
     for sample in dataset:
         # get sample
-        name  = sample[0]
-        img = VF.to_pil_image(sample[1]) # PIL Image
+        name = sample[0]
+        img = VF.to_pil_image(sample[1])  # PIL Image
 
         # read probability output
-        pred = Image.open(os.path.join(output_path,'images',name)).convert(mode='L')
+        pred = Image.open(os.path.join(output_path, "images", name)).convert(mode="L")
         # color and overlay
-        pred_green = PIL.ImageOps.colorize(pred, (0,0,0), (0,255,0))
+        pred_green = PIL.ImageOps.colorize(pred, (0, 0, 0), (0, 255, 0))
         overlayed = PIL.Image.blend(img, pred_green, 0.4)
 
         # add f1-score
-        #fnt_size = overlayed.size[1]//25
-        #draw = ImageDraw.Draw(overlayed)
-        #fnt = ImageFont.truetype('FreeMono.ttf', fnt_size)
-        #draw.text((0, 0),"F1: {:.4f}".format(f1),(255,255,255),font=fnt)
+        # fnt_size = overlayed.size[1]//25
+        # draw = ImageDraw.Draw(overlayed)
+        # fnt = ImageFont.truetype('FreeMono.ttf', fnt_size)
+        # draw.text((0, 0),"F1: {:.4f}".format(f1),(255,255,255),font=fnt)
         # save to disk
-        overlayed_path = os.path.join(output_path,'overlayed')
+        overlayed_path = os.path.join(output_path, "overlayed")
         fullpath = os.path.join(overlayed_path, name)
         fulldir = os.path.dirname(fullpath)
-        if not os.path.exists(fulldir): os.makedirs(fulldir)
+        if not os.path.exists(fulldir):
+            os.makedirs(fulldir)
         overlayed.save(fullpath)
 
 
@@ -443,12 +505,13 @@ def savetransformedtest(dataset, output_path):
 
     for sample in dataset:
         # get sample
-        name  = sample[0]
-        img = VF.to_pil_image(sample[1]) # PIL Image
+        name = sample[0]
+        img = VF.to_pil_image(sample[1])  # PIL Image
 
         # save to disk
-        testimg_path = os.path.join(output_path,'transformedtestimages')
+        testimg_path = os.path.join(output_path, "transformedtestimages")
         fullpath = os.path.join(testimg_path, name)
         fulldir = os.path.dirname(fullpath)
-        if not os.path.exists(fulldir): os.makedirs(fulldir)
+        if not os.path.exists(fulldir):
+            os.makedirs(fulldir)
         img.save(fullpath)
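
# The iso-F annotations in both plotting functions fix an F1 value and solve
# F = 2 * P * R / (P + R) for precision, giving P = F * R / (2 * R - F); that is the
# y = f_score * x / (2 * x - f_score) expression above. A quick numeric confirmation:
f_score, recall = 0.8, 0.7
precision = f_score * recall / (2 * recall - f_score)    # 0.56 / 0.6 ~ 0.933
f1_back = 2 * precision * recall / (precision + recall)
assert abs(f1_back - f_score) < 1e-9                     # the point lies on the F1 = 0.8 iso-line
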
diff --git a/bob/ip/binseg/utils/rsttable.py b/bob/ip/binseg/utils/rsttable.py
index fdc17982..c5329d8a 100644
--- a/bob/ip/binseg/utils/rsttable.py
+++ b/bob/ip/binseg/utils/rsttable.py
@@ -3,6 +3,7 @@ from tabulate import tabulate
 import os
 from pathlib import Path
 
+
 def get_paths(output_path, filename):
     """
     Parameters
@@ -17,39 +18,39 @@ def get_paths(output_path, filename):
         list of file paths
     """
     datadir = Path(output_path)
-    file_paths = sorted(list(datadir.glob('**/{}'.format(filename))))
+    file_paths = sorted(list(datadir.glob("**/{}".format(filename))))
     file_paths = [f.as_posix() for f in file_paths]
     return file_paths
 
 
 def create_overview_grid(output_path):
     """ Reads all Metrics.csv in a certain output path and pivots them to a rst grid table"""
-    filename = 'Metrics.csv'
-    metrics = get_paths(output_path,filename)
+    filename = "Metrics.csv"
+    metrics = get_paths(output_path, filename)
     f1s = []
     stds = []
     models = []
     databases = []
     for m in metrics:
         metrics = pd.read_csv(m)
-        maxf1 = metrics['f1_score'].max()
-        idmaxf1 = metrics['f1_score'].idxmax()
-        std = metrics['std_f1'][idmaxf1]
+        maxf1 = metrics["f1_score"].max()
+        idmaxf1 = metrics["f1_score"].idxmax()
+        std = metrics["std_f1"][idmaxf1]
         stds.append(std)
         f1s.append(maxf1)
-        model = m.split('/')[-3]
+        model = m.split("/")[-3]
         models.append(model)
-        database = m.split('/')[-4]
+        database = m.split("/")[-4]
         databases.append(database)
     df = pd.DataFrame()
-    df['database'] = databases
-    df['model'] = models
-    df['f1'] = f1s
-    df['std'] = stds
-    pivot = df.pivot(index='database',columns='model',values='f1')
-    pivot2 = df.pivot(index='database',columns='model',values='std')
+    df["database"] = databases
+    df["model"] = models
+    df["f1"] = f1s
+    df["std"] = stds
+    pivot = df.pivot(index="database", columns="model", values="f1")
+    pivot2 = df.pivot(index="database", columns="model", values="std")
 
-    with open (os.path.join(output_path,'Metrics_overview.rst'), "w+") as outfile:
-        outfile.write(tabulate(pivot,headers=pivot.columns, tablefmt="grid"))
-    with open (os.path.join(output_path,'Metrics_overview_std.rst'), "w+") as outfile:
-        outfile.write(tabulate(pivot2,headers=pivot2.columns, tablefmt="grid"))
\ No newline at end of file
+    with open(os.path.join(output_path, "Metrics_overview.rst"), "w+") as outfile:
+        outfile.write(tabulate(pivot, headers=pivot.columns, tablefmt="grid"))
+    with open(os.path.join(output_path, "Metrics_overview_std.rst"), "w+") as outfile:
+        outfile.write(tabulate(pivot2, headers=pivot2.columns, tablefmt="grid"))
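
# A self-contained sketch of the pivot-to-grid-table step above, with made-up F1 scores in
# place of the Metrics.csv files that create_overview_grid collects (the database and model
# names are only examples):
import pandas as pd
from tabulate import tabulate

df = pd.DataFrame(
    {
        "database": ["DRIVE", "DRIVE", "STARE", "STARE"],
        "model": ["driu", "unet", "driu", "unet"],
        "f1": [0.822, 0.825, 0.815, 0.818],
    }
)
pivot = df.pivot(index="database", columns="model", values="f1")
print(tabulate(pivot, headers=pivot.columns, tablefmt="grid"))
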
diff --git a/bob/ip/binseg/utils/summary.py b/bob/ip/binseg/utils/summary.py
index 127c5e66..17cf07c3 100644
--- a/bob/ip/binseg/utils/summary.py
+++ b/bob/ip/binseg/utils/summary.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488 
+# Adapted from https://github.com/pytorch/pytorch/issues/2001#issuecomment-405675488
 import sys
 import logging
 from functools import reduce
@@ -22,42 +22,43 @@ def summary(model, file=sys.stderr):
     int
         number of parameters
     """
+
     def repr(model):
         # We treat the extra repr like the sub-module, one item per line
         extra_lines = []
         extra_repr = model.extra_repr()
         # empty string will be split into list ['']
         if extra_repr:
-            extra_lines = extra_repr.split('\n')
+            extra_lines = extra_repr.split("\n")
         child_lines = []
         total_params = 0
         for key, module in model._modules.items():
             mod_str, num_params = repr(module)
             mod_str = _addindent(mod_str, 2)
-            child_lines.append('(' + key + '): ' + mod_str)
+            child_lines.append("(" + key + "): " + mod_str)
             total_params += num_params
         lines = extra_lines + child_lines
 
         for name, p in model._parameters.items():
-            if hasattr(p,'dtype'):
+            if hasattr(p, "dtype"):
                 total_params += reduce(lambda x, y: x * y, p.shape)
 
-        main_str = model._get_name() + '('
+        main_str = model._get_name() + "("
         if lines:
             # simple one-liner info, which most builtin Modules will use
             if len(extra_lines) == 1 and not child_lines:
                 main_str += extra_lines[0]
             else:
-                main_str += '\n  ' + '\n  '.join(lines) + '\n'
+                main_str += "\n  " + "\n  ".join(lines) + "\n"
 
-        main_str += ')'
+        main_str += ")"
         if file is sys.stderr:
-            main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
+            main_str += ", \033[92m{:,}\033[0m params".format(total_params)
         else:
-            main_str += ', {:,} params'.format(total_params)
+            main_str += ", {:,} params".format(total_params)
         return main_str, total_params
 
     string, count = repr(model)
     if file is not None:
         print(string, file=file)
-    return count
\ No newline at end of file
+    return count
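
# summary() walks the module tree and adds up the product of each parameter's shape, which is
# the same total torch reports through Tensor.numel(). A sketch of that equivalence on a small
# stand-in model (any torch.nn module would do):
from functools import reduce

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(), nn.Linear(8, 2))
total_by_shape = sum(reduce(lambda x, y: x * y, p.shape) for p in model.parameters())
total_by_numel = sum(p.numel() for p in model.parameters())
assert total_by_shape == total_by_numel  # 8*3*3*3 + 8 + 2*8 + 2 = 242 parameters either way
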
diff --git a/bob/ip/binseg/utils/transformfolder.py b/bob/ip/binseg/utils/transformfolder.py
index 9308d647..95c33539 100644
--- a/bob/ip/binseg/utils/transformfolder.py
+++ b/bob/ip/binseg/utils/transformfolder.py
@@ -1,9 +1,10 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-from pathlib import Path,PurePosixPath
+from pathlib import Path, PurePosixPath
 from PIL import Image
 from torchvision.transforms.functional import to_pil_image
 
+
 def transformfolder(source_path, target_path, transforms):
     """Applies a set of transfroms on an image folder 
     
@@ -18,10 +19,10 @@ def transformfolder(source_path, target_path, transforms):
     """
     source_path = Path(source_path)
     target_path = Path(target_path)
-    file_paths = sorted(list(source_path.glob('*?.*')))
+    file_paths = sorted(list(source_path.glob("*?.*")))
     for f in file_paths:
         timg_path = PurePosixPath(target_path).joinpath(f.name)
-        img = Image.open(f).convert(mode='1', dither=None)
-        img, _ = transforms(img,img)
+        img = Image.open(f).convert(mode="1", dither=None)
+        img, _ = transforms(img, img)
         img = to_pil_image(img)
-        img.save(str(timg_path))
\ No newline at end of file
+        img.save(str(timg_path))
-- 
GitLab