From f30a4f83205c74dd274b61e9a7e73404ba7dc91e Mon Sep 17 00:00:00 2001
From: Tim Laibacher <tim.laibacher@idiap.ch>
Date: Thu, 25 Apr 2019 16:40:35 +0200
Subject: [PATCH] Add AdaBound, train logging, HED, transforms unit tests,
 checkpoint inference. Refactor CLI options.

---
 .../binseg/configs/datasets/drive_default.py  |  17 ---
 bob/ip/binseg/configs/datasets/drivetest.py   |  18 +++
 bob/ip/binseg/configs/datasets/drivetrain.py  |  22 +++
 .../models/{driu_default.py => driu.py}       |   6 +-
 bob/ip/binseg/configs/models/driuadabound.py  |  40 +++++
 bob/ip/binseg/configs/models/hed.py           |  33 +++++
 bob/ip/binseg/data/transforms.py              |  69 +++++++++
 bob/ip/binseg/engine/adabound.py              |   2 +-
 bob/ip/binseg/engine/inference.py             |  54 -------
 bob/ip/binseg/engine/inferencer.py            |   4 +
 bob/ip/binseg/engine/train.py                 |  82 -----------
 bob/ip/binseg/engine/trainer.py               | 139 ++++++++++--------
 bob/ip/binseg/modeling/driu.py                |   1 -
 bob/ip/binseg/modeling/hed.py                 |   2 +-
 bob/ip/binseg/modeling/losses.py              |  61 ++++++--
 bob/ip/binseg/script/binseg.py                | 124 ++++++++++------
 bob/ip/binseg/test/test_transforms.py         |  49 ++++++
 bob/ip/binseg/utils/model_serialization.py    |   2 -
 bob/ip/binseg/utils/model_zoo.py              |   1 +
 bob/ip/binseg/utils/plot.py                   |  31 +++-
 setup.py                                      |   9 +-
 21 files changed, 485 insertions(+), 281 deletions(-)
 delete mode 100644 bob/ip/binseg/configs/datasets/drive_default.py
 create mode 100644 bob/ip/binseg/configs/datasets/drivetest.py
 create mode 100644 bob/ip/binseg/configs/datasets/drivetrain.py
 rename bob/ip/binseg/configs/models/{driu_default.py => driu.py} (86%)
 create mode 100644 bob/ip/binseg/configs/models/driuadabound.py
 create mode 100644 bob/ip/binseg/configs/models/hed.py
 delete mode 100644 bob/ip/binseg/engine/inference.py
 delete mode 100644 bob/ip/binseg/engine/train.py
 create mode 100644 bob/ip/binseg/test/test_transforms.py

diff --git a/bob/ip/binseg/configs/datasets/drive_default.py b/bob/ip/binseg/configs/datasets/drive_default.py
deleted file mode 100644
index bbf5788d..00000000
--- a/bob/ip/binseg/configs/datasets/drive_default.py
+++ /dev/null
@@ -1,17 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from bob.ip.binseg.data.transforms import ToTensor
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
-from torch.utils.data import DataLoader
-from bob.db.drive import Database as DRIVE
-import torch
-
-
-#### Config ####
-
-# bob.db.dataset init
-bobdb = DRIVE()
-
-# transforms 
-transforms = ToTensor()
diff --git a/bob/ip/binseg/configs/datasets/drivetest.py b/bob/ip/binseg/configs/datasets/drivetest.py
new file mode 100644
index 00000000..67d21cac
--- /dev/null
+++ b/bob/ip/binseg/configs/datasets/drivetest.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from bob.db.drive import Database as DRIVE
+from bob.ip.binseg.data.transforms import *
+from bob.ip.binseg.data.binsegdataset import BinSegDataset
+
+#### Config ####
+
+transforms = Compose([
+                        ToTensor()
+                    ])
+
+# bob.db.dataset init
+bobdb = DRIVE(protocol='default')
+
+# PyTorch dataset
+dataset = BinSegDataset(bobdb, split='test', transform=transforms)
\ No newline at end of file
diff --git a/bob/ip/binseg/configs/datasets/drivetrain.py b/bob/ip/binseg/configs/datasets/drivetrain.py
new file mode 100644
index 00000000..6662e1fc
--- /dev/null
+++ b/bob/ip/binseg/configs/datasets/drivetrain.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from bob.db.drive import Database as DRIVE
+from bob.ip.binseg.data.transforms import *
+from bob.ip.binseg.data.binsegdataset import BinSegDataset
+
+#### Config ####
+
+transforms = Compose([
+                        RandomHFlip()
+                        ,RandomVFlip()
+                        ,RandomRotation()
+                        ,ColorJitter()
+                        ,ToTensor()
+                    ])
+
+# bob.db.dataset init
+bobdb = DRIVE(protocol='default')
+
+# PyTorch dataset
+dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
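
Note: the two dataset configs above expose a ``dataset`` resource that the
refactored CLI (see bob/ip/binseg/script/binseg.py below) plugs straight into a
PyTorch DataLoader. A minimal sketch of that wiring, assuming the DRIVE
database is installed and configured; the batch size is illustrative, and the
sample unpacking follows trainer.py:

    from torch.utils.data import DataLoader
    from bob.ip.binseg.configs.datasets.drivetrain import dataset

    data_loader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)
    images, ground_truths, masks, names = next(iter(data_loader))
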
diff --git a/bob/ip/binseg/configs/models/driu_default.py b/bob/ip/binseg/configs/models/driu.py
similarity index 86%
rename from bob/ip/binseg/configs/models/driu_default.py
rename to bob/ip/binseg/configs/models/driu.py
index 2e9ef8d1..c801206b 100644
--- a/bob/ip/binseg/configs/models/driu_default.py
+++ b/bob/ip/binseg/configs/models/driu.py
@@ -4,12 +4,10 @@
 from torch.optim.lr_scheduler import MultiStepLR
 from bob.ip.binseg.modeling.driu import build_driu
 import torch.optim as optim
-from torch.nn import BCEWithLogitsLoss
+from bob.ip.binseg.modeling.losses import WeightedBCELogitsLoss
 from bob.ip.binseg.utils.model_zoo import modelurls
 
 ##### Config #####
-pretrained_weight = 'vgg16'
-
 lr = 0.001
 betas = (0.9, 0.999)
 eps = 1e-08
@@ -29,7 +27,7 @@ pretrained_backbone = modelurls['vgg16']
 optimizer = optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
     
 # criterion
-criterion = BCEWithLogitsLoss()
+criterion = WeightedBCELogitsLoss(reduction='mean')
 
 # scheduler
 scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
diff --git a/bob/ip/binseg/configs/models/driuadabound.py b/bob/ip/binseg/configs/models/driuadabound.py
new file mode 100644
index 00000000..b17e1fac
--- /dev/null
+++ b/bob/ip/binseg/configs/models/driuadabound.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from torch.optim.lr_scheduler import MultiStepLR
+from bob.ip.binseg.modeling.driu import build_driu
+from bob.ip.binseg.utils.model_zoo import modelurls
+from bob.ip.binseg.modeling.losses import WeightedBCELogitsLoss
+from bob.ip.binseg.engine.adabound import AdaBound
+# note: unlike driu.py, this config swaps Adam for AdaBound, so torch.optim
+# and BCEWithLogitsLoss are not imported here
+
+##### Config #####
+lr = 0.001
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+# AdaBound-specific: final (SGD-like) lr and bound convergence speed
+final_lr = 0.1
+gamma = 1e-3
+amsbound = False
+
+scheduler_milestones = [150]
+scheduler_gamma = 0.1
+
+# model
+model = build_driu()
+
+# pretrained backbone
+pretrained_backbone = modelurls['vgg16']
+
+# optimizer: AdaBound (Luo et al., "Adaptive Gradient Methods with Dynamic
+# Bound of Learning Rate"), implemented in bob.ip.binseg.engine.adabound
+optimizer = AdaBound(model.parameters(), lr=lr, betas=betas, final_lr=final_lr, gamma=gamma,
+                     eps=eps, weight_decay=weight_decay, amsbound=amsbound)
+    
+# criterion
+criterion = WeightedBCELogitsLoss(reduction='mean')
+
+# scheduler
+scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
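
Note: the config above relies on AdaBound's dynamic bounds: every Adam-style
step is clipped into a band around final_lr that tightens as training
progresses, so updates start Adam-like and converge towards plain SGD. A
sketch of the bound schedule used by the reference implementation (ignoring
the lr rescaling the optimizer applies internally):

    def adabound_bounds(step, final_lr=0.1, gamma=1e-3):
        # the band tightens around final_lr as the step count grows
        lower = final_lr * (1 - 1 / (gamma * step + 1))
        upper = final_lr * (1 + 1 / (gamma * step))
        return lower, upper
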
diff --git a/bob/ip/binseg/configs/models/hed.py b/bob/ip/binseg/configs/models/hed.py
new file mode 100644
index 00000000..0677ef59
--- /dev/null
+++ b/bob/ip/binseg/configs/models/hed.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from torch.optim.lr_scheduler import MultiStepLR
+from bob.ip.binseg.modeling.hed import build_hed
+import torch.optim as optim
+from bob.ip.binseg.modeling.losses import HEDWeightedBCELogitsLoss
+from bob.ip.binseg.utils.model_zoo import modelurls
+
+##### Config #####
+lr = 0.001
+betas = (0.9, 0.999)
+eps = 1e-08
+weight_decay = 0
+amsgrad = False
+
+scheduler_milestones = [150]
+scheduler_gamma = 0.1
+
+# model
+model = build_hed()
+
+# pretrained backbone
+pretrained_backbone = modelurls['vgg16']
+
+# optimizer
+optimizer = optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
+    
+# criterion
+criterion = HEDWeightedBCELogitsLoss(reduction='mean')
+
+# scheduler
+scheduler = MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)
diff --git a/bob/ip/binseg/data/transforms.py b/bob/ip/binseg/data/transforms.py
index 3229d7e0..4518668d 100644
--- a/bob/ip/binseg/data/transforms.py
+++ b/bob/ip/binseg/data/transforms.py
@@ -4,6 +4,8 @@
 import torchvision.transforms.functional as VF
 import random
 from PIL import Image
+from torchvision.transforms.transforms import Lambda
+from torchvision.transforms.transforms import Compose as TorchVisionCompose
 
 # Compose 
 
@@ -161,5 +163,72 @@ class RandomRotation:
         if random.random() < self.prob:
             degree = random.randint(*self.degree_range)
             return [VF.rotate(img, degree, resample = Image.BILINEAR) for img in args]
+        else:
+            return args
+
+class ColorJitter(object):
+    """ 
+    Randomly change the brightness, contrast and saturation of an image.
+    
+    Attributes
+    -----------
+
+        brightness : float
+                        How much to jitter brightness. brightness_factor
+                        is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
+        contrast : float
+                        How much to jitter contrast. contrast_factor
+                        is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
+        saturation : float
+                        How much to jitter saturation. saturation_factor
+                        is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
+        hue : float
+                How much to jitter hue. hue_factor is chosen uniformly from
+                [-hue, hue]. Should be >=0 and <= 0.5.
+
+    """
+    def __init__(self,prob=0.5, brightness=0.3, contrast=0.3, saturation=0.02, hue=0.02):
+        self.brightness = brightness
+        self.contrast = contrast
+        self.saturation = saturation
+        self.hue = hue
+        self.prob = prob
+
+    @staticmethod
+    def get_params(brightness, contrast, saturation, hue):
+        """Get a randomized transform to be applied on image.
+        Arguments are same as that of __init__.
+        Returns:
+            Transform which randomly adjusts brightness, contrast and
+            saturation in a random order.
+        """
+        transforms = []
+        if brightness > 0:
+            brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness)
+            transforms.append(Lambda(lambda img: VF.adjust_brightness(img, brightness_factor)))
+
+        if contrast > 0:
+            contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
+            transforms.append(Lambda(lambda img: VF.adjust_contrast(img, contrast_factor)))
+
+        if saturation > 0:
+            saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation)
+            transforms.append(Lambda(lambda img: VF.adjust_saturation(img, saturation_factor)))
+
+        if hue > 0:
+            hue_factor = random.uniform(-hue, hue)
+            transforms.append(Lambda(lambda img: VF.adjust_hue(img, hue_factor)))
+
+        random.shuffle(transforms)
+        transform = TorchVisionCompose(transforms)
+
+        return transform
+
+    def __call__(self, *args):
+        if random.random() < self.prob:
+            transform = self.get_params(self.brightness, self.contrast,
+                                        self.saturation, self.hue)
+            trans_img = transform(args[0])
+            return [trans_img, *args[1:]]
         else:
             return args
\ No newline at end of file
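
Note: ColorJitter keeps the multi-image calling convention of the other
transforms, but only jitters its first argument (the image), passing ground
truth and mask through untouched so geometric alignment is preserved. A quick
usage sketch (prob=1 forces the jitter to fire):

    from PIL import Image
    from bob.ip.binseg.data.transforms import ColorJitter

    img = Image.new('RGB', (64, 64), 'gray')
    gt = Image.new('1', (64, 64))
    mask = Image.new('1', (64, 64))
    img_j, gt_j, mask_j = ColorJitter(prob=1)(img, gt, mask)  # gt/mask untouched
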
diff --git a/bob/ip/binseg/engine/adabound.py b/bob/ip/binseg/engine/adabound.py
index 60f63ead..735d0399 100644
--- a/bob/ip/binseg/engine/adabound.py
+++ b/bob/ip/binseg/engine/adabound.py
@@ -246,4 +246,4 @@ class AdaBoundW(Optimizer):
                 else:
                     p.data.add_(-step_size)
 
-return loss
\ No newline at end of file
+        return loss
\ No newline at end of file
diff --git a/bob/ip/binseg/engine/inference.py b/bob/ip/binseg/engine/inference.py
deleted file mode 100644
index 4e2edd78..00000000
--- a/bob/ip/binseg/engine/inference.py
+++ /dev/null
@@ -1,54 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from bob.db.drive import Database as DRIVE
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
-from bob.ip.binseg.data.transforms import ToTensor
-from bob.ip.binseg.engine.inferencer import do_inference
-from bob.ip.binseg.modeling.driu import build_driu
-from torch.utils.data import DataLoader
-from bob.ip.binseg.utils.checkpointer import Checkpointer
-
-
-import logging
-logging.basicConfig(level=logging.INFO)
-
-logger = logging.getLogger("bob.ip.binseg.engine.inferencer")
-
-def inference():
-    # bob.db.dataset init
-    drive = DRIVE()
-    
-    # Transforms 
-    transforms = ToTensor()
-
-    # PyTorch dataset
-    bsdataset = BinSegDataset(drive,split='test', transform=transforms)
-    
-    # Build model
-    model = build_driu()
-    
-    # Dataloader
-    data_loader = DataLoader(
-        dataset = bsdataset
-        ,batch_size = 2
-        ,shuffle= False
-        ,pin_memory = False
-        )
-    
-    # checkpointer, load last model in dir
-    checkpointer = Checkpointer(model, save_dir = "./output_temp", save_to_disk=False)
-    checkpointer.load()
-
-    # device 
-    device = "cpu"
-    logger.info("Run inference and calculate evaluation metrics")
-    do_inference(model
-            , data_loader
-            , device
-            , "./output_temp"
-            )
-
-
-if __name__ == '__main__':
-    inference()
\ No newline at end of file
diff --git a/bob/ip/binseg/engine/inferencer.py b/bob/ip/binseg/engine/inferencer.py
index f1307372..bc1dc9e8 100644
--- a/bob/ip/binseg/engine/inferencer.py
+++ b/bob/ip/binseg/engine/inferencer.py
@@ -128,6 +128,10 @@ def do_inference(
             start_time = time.perf_counter()
 
             outputs = model(images)
+            # HED returns a list of side outputs used for loss computation;
+            # for inference keep only the last (fused) output
+            if isinstance(outputs, list):
+                outputs = outputs[-1]
             probabilities = sigmoid(outputs)
             
             batch_time = time.perf_counter() - start_time
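
Note: the check above mirrors the HED loss added below: the HED head yields a
list of tensors (side outputs plus the fused map), so inference keeps only the
last entry. A shape-level sketch, with dummy tensors standing in for real side
outputs (list length illustrative):

    import torch

    outputs = [torch.randn(2, 1, 64, 64) for _ in range(5)]  # sides + fused map
    if isinstance(outputs, list):
        outputs = outputs[-1]               # fused prediction only
    probabilities = torch.sigmoid(outputs)  # torch.Size([2, 1, 64, 64])
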
diff --git a/bob/ip/binseg/engine/train.py b/bob/ip/binseg/engine/train.py
deleted file mode 100644
index 3dee98ec..00000000
--- a/bob/ip/binseg/engine/train.py
+++ /dev/null
@@ -1,82 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from bob.db.drive import Database as DRIVE
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
-from bob.ip.binseg.data.transforms import ToTensor
-from bob.ip.binseg.engine.trainer import do_train
-from bob.ip.binseg.modeling.driu import build_driu
-from torch.utils.data import DataLoader
-import torch.optim as optim
-from torch.optim.lr_scheduler import MultiStepLR
-from bob.ip.binseg.utils.checkpointer import Checkpointer, DetectronCheckpointer
-from torch.nn import BCEWithLogitsLoss
-
-import logging
-logging.basicConfig(level=logging.DEBUG)
-
-logger = logging.getLogger("bob.ip.binseg.engine.train")
-
-def train():
-    # bob.db.dataset init
-    drive = DRIVE()
-    
-    # Transforms 
-    transforms = ToTensor()
-
-    # PyTorch dataset
-    bsdataset = BinSegDataset(drive,split='train', transform=transforms)
-    
-    # Build model
-    model = build_driu()
-    
-    # Dataloader
-    data_loader = DataLoader(
-        dataset = bsdataset
-        ,batch_size = 2
-        ,shuffle= True
-        ,pin_memory = False
-        )
-    
-    # optimizer
-    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
-    
-    # criterion
-    criterion = BCEWithLogitsLoss()
-
-    # scheduler
-    scheduler = MultiStepLR(optimizer, milestones=[150], gamma=0.1)
-
-    # checkpointer
-    checkpointer = DetectronCheckpointer(model, optimizer, scheduler,save_dir = "./output_temp", save_to_disk=True)
-
-    # checkpoint period
-    checkpoint_period = 2
-
-    # pretrained backbone
-    pretraind_backbone = model_urls['vgg16']
-
-    # device 
-    device = "cpu"
-    
-    # arguments 
-    arguments = {}
-    arguments["epoch"] = 0 
-    arguments["max_epoch"] = 6
-    extra_checkpoint_data = checkpointer.load(pretraind_backbone)
-    arguments.update(extra_checkpoint_data)
-    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
-    logger.info("Continuing from epoch {}".format(arguments["epoch"]))
-    do_train(model
-            , data_loader
-            , optimizer
-            , criterion
-            , scheduler
-            , checkpointer
-            , checkpoint_period
-            , device
-            , arguments)
-
-
-if __name__ == '__main__':
-    train()
\ No newline at end of file
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index ac75123c..c84a0733 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -6,8 +6,11 @@ import time
 import datetime
 from tqdm import tqdm
 import torch
+import os 
+import pandas as pd
 
 from bob.ip.binseg.utils.metric import SmoothedValue
+from bob.ip.binseg.utils.plot import loss_curve
 
 def do_train(
     model,
@@ -18,7 +21,8 @@ def do_train(
     checkpointer,
     checkpoint_period,
     device,
-    arguments
+    arguments,
+    output_folder
 ):
     """ Trains the model """
     logger = logging.getLogger("bob.ip.binseg.engine.trainer")
@@ -26,71 +30,88 @@ def do_train(
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
-    model.train().to(device)
-    # Total training timer
-    start_training_time = time.time()
-
-    for epoch in range(start_epoch, max_epoch):
-        scheduler.step()
-        losses = SmoothedValue(len(data_loader))
-        epoch = epoch + 1
-        arguments["epoch"] = epoch
-        start_epoch_time = time.time()
+    # Log to file
+    with open(os.path.join(output_folder, "{}_trainlog.csv".format(model.name)), "a+") as outfile:
         
-        for images, ground_truths, masks, _ in tqdm(data_loader):
+        model.train().to(device)
+        # Total training timer
+        start_training_time = time.time()
 
-            images = images.to(device)
-            ground_truths = ground_truths.to(device)
-            #masks = masks.to(device) 
+        for epoch in range(start_epoch, max_epoch):
+            scheduler.step()
+            losses = SmoothedValue(len(data_loader))
+            epoch = epoch + 1
+            arguments["epoch"] = epoch
+            start_epoch_time = time.time()
 
-            outputs = model(images)
-            loss = criterion(outputs, ground_truths)
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+            for images, ground_truths, masks, _ in tqdm(data_loader):
 
-            losses.update(loss)
-            logger.debug("batch loss: {}".format(loss.item()))
+                images = images.to(device)
+                ground_truths = ground_truths.to(device)
+                #masks = masks.to(device) 
 
-        if epoch % checkpoint_period == 0:
-            checkpointer.save("model_{:03d}".format(epoch), **arguments)
-        
-        if epoch == max_epoch:
-            checkpointer.save("model_final", **arguments)
-        
-        epoch_time = time.time() - start_epoch_time
-        
+                outputs = model(images)
+                loss = criterion(outputs, ground_truths)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
 
-        eta_seconds = epoch_time * (max_epoch - epoch)
-        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
-
-           
-        logger.info(("eta: {eta}, " 
-                    "epoch: {epoch}, "
-                    "avg. loss: {avg_loss:.6f}, "
-                    "median loss: {median_loss:.6f}, "
-                    "lr: {lr:.6f}, "
-                    "max mem: {memory:.0f}"
-                    ).format(
-                eta=eta_string,
-                epoch=epoch,
-                avg_loss=losses.avg,
-                median_loss=losses.median,
-                lr=optimizer.param_groups[0]["lr"],
-                memory = 0.0
-                # TODO: uncomment for CUDA
-                #memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
-                )
-            )
-        
+                losses.update(loss)
+                logger.debug("batch loss: {}".format(loss.item()))
 
-    total_training_time = time.time() - start_training_time
-    total_time_str = str(datetime.timedelta(seconds=total_training_time))
-    logger.info(
-        "Total training time: {} ({:.4f} s / epoch)".format(
-            total_time_str, total_training_time / (max_epoch)
-        )
-)
+            if epoch % checkpoint_period == 0:
+                checkpointer.save("model_{:03d}".format(epoch), **arguments)
 
+            if epoch == max_epoch:
+                checkpointer.save("model_final", **arguments)
 
+            epoch_time = time.time() - start_epoch_time
 
+
+            eta_seconds = epoch_time * (max_epoch - epoch)
+            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+
+            outfile.write(("{epoch}, "
+                        "{avg_loss:.6f}, "
+                        "{median_loss:.6f}, "
+                        "{lr:.6f}, "
+                        "{memory:.0f}"
+                        "\n"
+                        ).format(
+                    eta=eta_string,
+                    epoch=epoch,
+                    avg_loss=losses.avg,
+                    median_loss=losses.median,
+                    lr=optimizer.param_groups[0]["lr"],
+                    memory = (torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else .0,
+                    )
+                )  
+            logger.info(("eta: {eta}, " 
+                        "epoch: {epoch}, "
+                        "avg. loss: {avg_loss:.6f}, "
+                        "median loss: {median_loss:.6f}, "
+                        "lr: {lr:.6f}, "
+                        "max mem: {memory:.0f}"
+                        ).format(
+                    eta=eta_string,
+                    epoch=epoch,
+                    avg_loss=losses.avg,
+                    median_loss=losses.median,
+                    lr=optimizer.param_groups[0]["lr"],
+                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else 0.0
+                    )
+                )
+
+
+        total_training_time = time.time() - start_training_time
+        total_time_str = str(datetime.timedelta(seconds=total_training_time))
+        logger.info(
+            "Total training time: {} ({:.4f} s / epoch)".format(
+                total_time_str, total_training_time / (max_epoch)
+            ))
+        
+    log_plot_file = os.path.join(output_folder,"{}_trainlog.pdf".format(model.name))
+    logdf = pd.read_csv(os.path.join(output_folder,"{}_trainlog.csv".format(model.name)),header=None, names=["avg. loss", "median loss","lr","max memory"])
+    fig = loss_curve(logdf)
+    logger.info("saving {}".format(log_plot_file))
+    fig.savefig(log_plot_file)
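
Note: one subtlety in the log round-trip above: each CSV row starts with the
epoch, but read_csv is given only the four value names, so pandas promotes the
leading epoch column to the index, which is exactly what loss_curve plots
against. A sketch, assuming a model named "driu":

    import pandas as pd

    logdf = pd.read_csv("driu_trainlog.csv", header=None,
                        names=["avg. loss", "median loss", "lr", "max memory"])
    print(logdf.head())  # index = epoch, four value columns
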
diff --git a/bob/ip/binseg/modeling/driu.py b/bob/ip/binseg/modeling/driu.py
index e57d2425..b1478f63 100644
--- a/bob/ip/binseg/modeling/driu.py
+++ b/bob/ip/binseg/modeling/driu.py
@@ -53,7 +53,6 @@ class DRIU(nn.Module):
         return out
 
 def build_driu():
-    #backbone = vgg16(pretrained=False, return_features = [1, 4, 8, 12])
     backbone = vgg16(pretrained=False, return_features = [3, 8, 14, 22])
     driu_head = DRIU([64, 128, 256, 512])
 
diff --git a/bob/ip/binseg/modeling/hed.py b/bob/ip/binseg/modeling/hed.py
index f9026cd9..6a3e1d8c 100644
--- a/bob/ip/binseg/modeling/hed.py
+++ b/bob/ip/binseg/modeling/hed.py
@@ -56,7 +56,7 @@ class HED(nn.Module):
         return out
 
 def build_hed():
-    backbone = vgg16(pretrained=False, return_features = [1, 4, 8, 12, 16])
+    backbone = vgg16(pretrained=False, return_features = [3, 8, 14, 22, 29])
     hed_head = HED([64, 128, 256, 512, 512])
 
     model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", hed_head)]))
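
Note: the new return_features indices move the HED taps to the deepest layers
of each VGG16 convolution block (torchvision-style layer numbering), matching
the feature widths [64, 128, 256, 512, 512] passed to the head. Assuming the
vgg16 backbone mirrors torchvision's ``features`` stack, the selected layers
can be inspected with:

    from torchvision.models import vgg16

    features = vgg16(pretrained=False).features
    for idx in (3, 8, 14, 22, 29):
        print(idx, features[idx])
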
diff --git a/bob/ip/binseg/modeling/losses.py b/bob/ip/binseg/modeling/losses.py
index af18c101..e094dbe6 100644
--- a/bob/ip/binseg/modeling/losses.py
+++ b/bob/ip/binseg/modeling/losses.py
@@ -1,18 +1,49 @@
 import torch
+from torch.nn.modules.loss import _Loss
+from torch._jit_internal import weak_script_method
 
-# TODO: REWRITE AS loss class, similary to BCELoss!
-def weighted_cross_entropy_loss(preds, edges):
-    """ Calculate sum of weighted cross entropy loss.
-    https://github.com/xwjabc/hed/blob/master/hed.py
+class WeightedBCELogitsLoss(_Loss):
+    """ 
+    Calculate sum of weighted cross entropy loss. Use for binary classification.
     """
-    mask = (edges > 0.5).float()
-    b, c, h, w = mask.shape
-    num_pos = torch.sum(mask, dim=[1, 2, 3]).float()  # Shape: [b,].
-    num_neg = c * h * w - num_pos                     # Shape: [b,].
-    weight = torch.zeros_like(mask)
-    weight[edges > 0.5]  = num_neg / (num_pos + num_neg)
-    weight[edges <= 0.5] = num_pos / (num_pos + num_neg)
-    # Calculate loss.
-    losses = torch.nn.functional.binary_cross_entropy(preds.float(), edges.float(), weight=weight, reduction='none')
-    loss   = torch.sum(losses) / b
-    return loss
+    def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None):
+        super(WeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction)
+        self.register_buffer('weight', weight)
+        self.register_buffer('pos_weight', pos_weight)
+
+    @weak_script_method
+    def forward(self, input, target):
+        n, c, h, w = target.shape
+        num_pos = torch.sum(target, dim=[1, 2, 3]).float().reshape(n,1) # torch.Size([n, 1])
+        num_neg = c * h * w - num_pos  # torch.Size([n, 1])
+        numposnumtotal = torch.ones_like(target) * (num_pos / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2)
+        numnegnumtotal = torch.ones_like(target) * (num_neg / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2)
+        weight = torch.where(target <= 0.5, numposnumtotal, numnegnumtotal)
+
+        loss = torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, reduction=self.reduction)
+        return loss 
+
+class HEDWeightedBCELogitsLoss(_Loss):
+    """ 
+    Calculate sum of weighted cross entropy loss. Use for binary classification.
+    """
+    def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None):
+        super(HEDWeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction)
+        self.register_buffer('weight', weight)
+        self.register_buffer('pos_weight', pos_weight)
+
+    @weak_script_method
+    def forward(self, inputlist, target):
+        loss_over_all_inputs = []
+        for input in inputlist:
+            n, c, h, w = target.shape
+            num_pos = torch.sum(target, dim=[1, 2, 3]).float().reshape(n,1) # torch.Size([n, 1])
+            num_neg = c * h * w - num_pos  # torch.Size([n, 1])
+            numposnumtotal = torch.ones_like(target) * (num_pos / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2)
+            numnegnumtotal = torch.ones_like(target) * (num_neg / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2)
+            weight = torch.where(target <= 0.5, numposnumtotal, numnegnumtotal)
+
+            loss = torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, reduction=self.reduction)
+            loss_over_all_inputs.append(loss.unsqueeze(0))
+        final_loss = torch.cat(loss_over_all_inputs).mean()
+        return final_loss 
\ No newline at end of file
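
Note: both losses weight each pixel by the frequency of the opposite class in
the batch, so a sparse vessel foreground is not drowned out by background
pixels: with 10% positives, positive pixels get weight 0.9 and negatives 0.1.
A small numeric check of that computation:

    import torch

    target = torch.zeros(1, 1, 10, 10)
    target[0, 0, 0, :] = 1.0                        # 10% positive pixels
    n, c, h, w = target.shape
    num_pos = target.sum(dim=[1, 2, 3]).reshape(n, 1)
    num_neg = c * h * w - num_pos
    print((num_neg / (num_pos + num_neg)).item())   # 0.9 -> weight on positives
    print((num_pos / (num_pos + num_neg)).item())   # 0.1 -> weight on negatives
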
diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py
index b9d6f138..0db77074 100644
--- a/bob/ip/binseg/script/binseg.py
+++ b/bob/ip/binseg/script/binseg.py
@@ -9,6 +9,7 @@ import time
 import numpy
 import collections
 import pkg_resources
+import glob
 
 import click
 from click_plugins import with_plugins
@@ -21,7 +22,6 @@ from bob.extension.scripts.click_helper import (verbosity_option,
     ConfigCommand, ResourceOption, AliasedGroup)
 
 from bob.ip.binseg.utils.checkpointer import DetectronCheckpointer
-from bob.ip.binseg.data.binsegdataset import BinSegDataset
 from torch.utils.data import DataLoader
 from bob.ip.binseg.engine.trainer import do_train
 from bob.ip.binseg.engine.inferencer import do_inference
@@ -51,40 +51,31 @@ def binseg():
     cls=ResourceOption
     )
 @click.option(
-    '--optimizer',
+    '--dataset',
+    '-d',
     required=True,
     cls=ResourceOption
     )
 @click.option(
-    '--criterion',
+    '--optimizer',
     required=True,
     cls=ResourceOption
     )
 @click.option(
-    '--scheduler',
+    '--criterion',
     required=True,
     cls=ResourceOption
     )
 @click.option(
-    '--pretrained-backbone',
+    '--scheduler',
     required=True,
     cls=ResourceOption
     )
 @click.option(
-    '--bobdb',
+    '--pretrained-backbone',
     required=True,
     cls=ResourceOption
     )
-@click.option(
-    '--split',
-    '-s',
-    required=True,
-    default='train',
-    cls=ResourceOption)
-@click.option(
-    '--transforms',
-    required=True,
-    cls=ResourceOption)
 @click.option(
     '--batch-size',
     '-b',
@@ -123,22 +114,17 @@ def train(model
         ,output_path
         ,epochs
         ,pretrained_backbone
-        ,split
         ,batch_size
         ,criterion
-        ,bobdb
-        ,transforms
+        ,dataset
         ,checkpoint_period
         ,device
         ,**kwargs):
     if not os.path.exists(output_path): os.makedirs(output_path)
     
-    # PyTorch dataset
-    bsdataset = BinSegDataset(bobdb, split=split, transform=transforms)
-
     # PyTorch dataloader
     data_loader = DataLoader(
-        dataset = bsdataset
+        dataset = dataset
         ,batch_size = batch_size
         ,shuffle= True
         ,pin_memory = torch.cuda.is_available()
@@ -148,9 +134,9 @@ def train(model
     checkpointer = DetectronCheckpointer(model, optimizer, scheduler,save_dir = output_path, save_to_disk=True)
     arguments = {}
     arguments["epoch"] = 0 
-    arguments["max_epoch"] = epochs
     extra_checkpoint_data = checkpointer.load(pretrained_backbone)
     arguments.update(extra_checkpoint_data)
+    arguments["max_epoch"] = epochs
     
     # Train
     logger.info("Training for {} epochs".format(arguments["max_epoch"]))
@@ -164,8 +150,10 @@ def train(model
             , checkpoint_period
             , device
             , arguments
+            , output_path
             )
 
+
 # Inference
 @binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
 @click.option(
@@ -182,20 +170,69 @@ def train(model
     cls=ResourceOption
     )
 @click.option(
-    '--bobdb',
+    '--dataset',
+    '-d',
     required=True,
     cls=ResourceOption
     )
 @click.option(
-    '--transforms',
+    '--batch-size',
+    '-b',
     required=True,
+    default=2,
     cls=ResourceOption)
 @click.option(
-    '--split',
-    '-s',
+    '--device',
+    # no short flag here: '-d' is already taken by --dataset above
+    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
+    show_default=True,
     required=True,
-    default='test',
+    default='cpu',
     cls=ResourceOption)
+@verbosity_option(cls=ResourceOption)
+def test(model
+        ,output_path
+        ,device
+        ,batch_size
+        ,dataset
+        ,**kwargs):
+
+
+    # PyTorch dataloader
+    data_loader = DataLoader(
+        dataset = dataset
+        ,batch_size = batch_size
+        ,shuffle= False
+        ,pin_memory = torch.cuda.is_available()
+        )
+    
+    # checkpointer, load last model in dir
+    checkpointer = DetectronCheckpointer(model, save_dir = output_path, save_to_disk=False)
+    checkpointer.load()
+    do_inference(model, data_loader, device, output_path)
+
+
+# Inference all checkpoints
+@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
+@click.option(
+    '--output-path',
+    '-o',
+    required=True,
+    default="output",
+    cls=ResourceOption
+    )
+@click.option(
+    '--model',
+    '-m',
+    required=True,
+    cls=ResourceOption
+    )
+@click.option(
+    '--dataset',
+    '-d',
+    required=True,
+    cls=ResourceOption
+    )
 @click.option(
     '--batch-size',
     '-b',
@@ -210,29 +247,32 @@ def train(model
     required=True,
     default='cpu',
     cls=ResourceOption)
-
 @verbosity_option(cls=ResourceOption)
-def test(model
+def testcheckpoints(model
         ,output_path
         ,device
-        ,split
         ,batch_size
-        ,bobdb
-        ,transforms
+        ,dataset
         , **kwargs):
 
-    # PyTorch dataset
-    bsdataset = BinSegDataset(bobdb, split=split, transform=transforms)
 
     # PyTorch dataloader
     data_loader = DataLoader(
-        dataset = bsdataset
+        dataset = dataset
         ,batch_size = batch_size
         ,shuffle= False
         ,pin_memory = torch.cuda.is_available()
         )
-
-    # checkpointer, load last model in dir
-    checkpointer = DetectronCheckpointer(model, save_dir = output_path, save_to_disk=False)
-    checkpointer.load()
-    do_inference(model, data_loader, device, output_path)
\ No newline at end of file
+    
+    # list checkpoints
+    ckpts = glob.glob(os.path.join(output_path,"*.pth"))
+    # output
+    for checkpoint in ckpts:
+        ckpts_name = os.path.basename(checkpoint).split('.')[0]
+        logger.info("Testing checkpoint: {}".format(ckpts_name))
+        output_subfolder = os.path.join(output_path, ckpts_name)
+        if not os.path.exists(output_subfolder): os.makedirs(output_subfolder)
+        # checkpointer, load last model in dir
+        checkpointer = DetectronCheckpointer(model, save_dir = output_subfolder, save_to_disk=False)
+        checkpointer.load(checkpoint)
+        do_inference(model, data_loader, device, output_subfolder)
\ No newline at end of file
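
Note: with the console entry points from setup.py installed, the refactored
commands take model and dataset configs as named resources instead of the old
--bobdb/--split/--transforms triple. Illustrative invocations, assuming the
binseg group is mounted on the bob main command:

    bob binseg train DRIU DRIVETRAIN -o ./driu_drive -b 2 -vv
    bob binseg test DRIU DRIVETEST -o ./driu_drive
    bob binseg testcheckpoints DRIU DRIVETEST -o ./driu_drive
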
diff --git a/bob/ip/binseg/test/test_transforms.py b/bob/ip/binseg/test/test_transforms.py
new file mode 100644
index 00000000..479cd79c
--- /dev/null
+++ b/bob/ip/binseg/test/test_transforms.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import torch
+import unittest
+import numpy as np
+from bob.ip.binseg.data.transforms import *
+
+# VF (torchvision.transforms.functional) is pulled in by the star import
+# above; the explicit import keeps the dependency visible
+import torchvision.transforms.functional as VF
+
+# two flips along the same axis must be the identity -- this is what
+# Tester.test_flips checks below
+
+def create_img():
+    t = torch.randn((3,42,24))
+    pil = VF.to_pil_image(t)
+    return pil
+
+
+class Tester(unittest.TestCase):
+    """
+    Unit test for random flips
+    """
+    
+    def test_flips(self):
+        transforms = Compose([
+                        RandomHFlip(prob=1)
+                        ,RandomHFlip(prob=1)
+                        ,RandomVFlip(prob=1)
+                        ,RandomVFlip(prob=1)
+                    ])
+        img, gt, mask = [create_img() for i in range(3)]
+        img_t, gt_t, mask_t = transforms(img, gt, mask)
+        self.assertTrue(np.all(np.array(img_t) == np.array(img)))
+        self.assertTrue(np.all(np.array(gt_t) == np.array(gt)))
+        self.assertTrue(np.all(np.array(mask_t) == np.array(mask)))
+
+    def test_to_tensor(self):
+        transforms = ToTensor()
+        img, gt, mask = [create_img() for i in range(3)]
+        img_t, gt_t, mask_t = transforms(img, gt, mask)
+        self.assertEqual(str(img_t.dtype),"torch.float32")
+        self.assertEqual(str(gt_t.dtype),"torch.float32")
+        self.assertEqual(str(mask_t.dtype),"torch.float32")
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
diff --git a/bob/ip/binseg/utils/model_serialization.py b/bob/ip/binseg/utils/model_serialization.py
index cf98118e..3e7b565a 100644
--- a/bob/ip/binseg/utils/model_serialization.py
+++ b/bob/ip/binseg/utils/model_serialization.py
@@ -20,9 +20,7 @@ def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
     backbone[0].body.res2.conv1.weight to res2.conv1.weight.
     """
     current_keys = sorted(list(model_state_dict.keys()))
-    print(current_keys)
     loaded_keys = sorted(list(loaded_state_dict.keys()))
-    print(loaded_keys)
     # get a matrix of string matches, where each (i, j) entry correspond to the size of the
     # loaded_key string, if it matches
     match_matrix = [
diff --git a/bob/ip/binseg/utils/model_zoo.py b/bob/ip/binseg/utils/model_zoo.py
index 5077707e..e8f94abf 100644
--- a/bob/ip/binseg/utils/model_zoo.py
+++ b/bob/ip/binseg/utils/model_zoo.py
@@ -17,6 +17,7 @@ import warnings
 import zipfile
 from urllib.request import urlopen
 from urllib.parse import urlparse
+from tqdm import tqdm 
 
 modelurls = {
     "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
diff --git a/bob/ip/binseg/utils/plot.py b/bob/ip/binseg/utils/plot.py
index 6456188b..fb838e25 100644
--- a/bob/ip/binseg/utils/plot.py
+++ b/bob/ip/binseg/utils/plot.py
@@ -33,6 +33,8 @@ def precision_recall_f1iso(precision, recall, names, title=None, human_perf_bsds
       figure : matplotlib.figure.Figure
         A matplotlib figure you can save or display 
     ''' 
+    import matplotlib
+    matplotlib.use('agg')
     import matplotlib.pyplot as plt 
     fig, ax1 = plt.subplots(1)  
     for p, r, n in zip(precision, recall, names):   
@@ -86,4 +88,31 @@ def precision_recall_f1iso(precision, recall, names, title=None, human_perf_bsds
     ax2.spines['left'].set_visible(False)
     ax2.spines['bottom'].set_visible(False) 
     plt.tight_layout()  
-    return fig  
\ No newline at end of file
+    return fig  
+
+
+
+def loss_curve(df):
+    ''' Creates a loss curve plot
+
+    Arguments
+    ---------
+    df : :py:class:`pandas.DataFrame`
+        Expected columns: ["avg. loss", "median loss", "lr", "max memory"]
+
+    Returns
+    -------
+    fig : matplotlib.figure.Figure
+    '''
+    import matplotlib
+    matplotlib.use('agg')
+    import matplotlib.pyplot as plt 
+    ax1 = df.plot(y="median loss", grid=True)
+    ax1.set_ylabel('median loss')
+    ax1.grid(linestyle='--', linewidth=1, color='gray', alpha=0.2)
+    ax2 = df['lr'].plot(secondary_y=True, legend=True, grid=True)
+    ax2.set_ylabel('lr')
+    ax1.set_xlabel('epoch')
+    plt.tight_layout()  
+    fig = ax1.get_figure()
+    return fig
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 2b4c6c0c..2a2a92f5 100644
--- a/setup.py
+++ b/setup.py
@@ -47,12 +47,17 @@ setup(
          #bob hed sub-commands
         'bob.ip.binseg.cli': [
           'train = bob.ip.binseg.script.binseg:train',
+          'test = bob.ip.binseg.script.binseg:test',
+          'testcheckpoints = bob.ip.binseg.script.binseg:testcheckpoints',
         ],
 
          #bob hed train configurations
         'bob.ip.binseg.config': [
-          'DriuDefault = bob.ip.binseg.configs.models.driu_default',
-          'DriveDefault = bob.ip.binseg.configs.datasets.drive_default',
+          'DRIU = bob.ip.binseg.configs.models.driu',
+          'HED = bob.ip.binseg.configs.models.hed',
+          'DRIUADABOUND = bob.ip.binseg.configs.models.driuadabound',
+          'DRIVETRAIN = bob.ip.binseg.configs.datasets.drivetrain',
+          'DRIVETEST = bob.ip.binseg.configs.datasets.drivetest',
           ]
     },
 
-- 
GitLab