From 915f4e79007698ab6721d8d64fcc5f1454849bf7 Mon Sep 17 00:00:00 2001
From: Tim Laibacher <tim.laibacher@idiap.ch>
Date: Mon, 24 Jun 2019 16:53:08 +0200
Subject: [PATCH] Add valtrainer

---
 .../datasets/drivestarechasedb1hrf1024.py     |  10 ++
 bob/ip/binseg/configs/datasets/hrf1024.py     |  24 +++
 bob/ip/binseg/engine/valtrainer.py            | 151 ++++++++++++++++++
 bob/ip/binseg/script/binseg.py                | 151 +++++++++++++++++-
 bob/ip/binseg/utils/plot.py                   |   9 +-
 setup.py                                      |   2 +
 6 files changed, 341 insertions(+), 6 deletions(-)
 create mode 100644 bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py
 create mode 100644 bob/ip/binseg/configs/datasets/hrf1024.py
 create mode 100644 bob/ip/binseg/engine/valtrainer.py

diff --git a/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py b/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py
new file mode 100644
index 00000000..267fc45c
--- /dev/null
+++ b/bob/ip/binseg/configs/datasets/drivestarechasedb1hrf1024.py
@@ -0,0 +1,10 @@
+from bob.ip.binseg.configs.datasets.drive1024 import dataset as drive
+from bob.ip.binseg.configs.datasets.stare1024 import dataset as stare
+from bob.ip.binseg.configs.datasets.hrf1024 import dataset as hrf
+from bob.ip.binseg.configs.datasets.chasedb11024 import dataset as chase
+import torch
+
+#### Config ####
+
+# PyTorch dataset
+dataset = torch.utils.data.ConcatDataset([drive,stare,hrf,chase])
\ No newline at end of file
diff --git a/bob/ip/binseg/configs/datasets/hrf1024.py b/bob/ip/binseg/configs/datasets/hrf1024.py
new file mode 100644
index 00000000..2574901a
--- /dev/null
+++ b/bob/ip/binseg/configs/datasets/hrf1024.py
@@ -0,0 +1,24 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from bob.db.hrf import Database as HRF
+from bob.ip.binseg.data.transforms import *
+from bob.ip.binseg.data.binsegdataset import BinSegDataset
+
+#### Config ####
+
+transforms = Compose([  
+                        Pad((0,584,0,584))                    
+                        ,Resize((1024))
+                        ,RandomRotation()
+                        ,RandomHFlip()
+                        ,RandomVFlip()
+                        ,ColorJitter()
+                        ,ToTensor()
+                    ])
+
+# bob.db.dataset init
+bobdb = HRF(protocol = 'default')
+
+# PyTorch dataset
+dataset = BinSegDataset(bobdb, split='train', transform=transforms)
\ No newline at end of file
diff --git a/bob/ip/binseg/engine/valtrainer.py b/bob/ip/binseg/engine/valtrainer.py
new file mode 100644
index 00000000..838798ee
--- /dev/null
+++ b/bob/ip/binseg/engine/valtrainer.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os 
+import logging
+import time
+import datetime
+import torch
+import pandas as pd
+from tqdm import tqdm
+
+from bob.ip.binseg.utils.metric import SmoothedValue
+from bob.ip.binseg.utils.plot import loss_curve
+from bob.ip.binseg.engine.inferencer import do_inference
+
+
+def do_valtrain(
+    model,
+    data_loader,
+    optimizer,
+    criterion,
+    scheduler,
+    checkpointer,
+    checkpoint_period,
+    device,
+    arguments,
+    output_folder,
+    val_loader = None
+):
+    """ 
+    Train model and save to disk.
+    
+    Parameters
+    ----------
+    model : :py:class:`torch.nn.Module` 
+        Network (e.g. DRIU, HED, UNet)
+    data_loader : :py:class:`torch.utils.data.DataLoader`
+    optimizer : :py:mod:`torch.optim`
+    criterion : :py:class:`torch.nn.modules.loss._Loss`
+        loss function
+    scheduler : :py:mod:`torch.optim`
+        learning rate scheduler
+    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
+        checkpointer
+    checkpoint_period : int
+        save a checkpoint every n epochs
+    device : str  
+        device to use ``'cpu'`` or ``'cuda'``
+    arguments : dict
+        start end end epochs
+    output_folder : str 
+        output path
+    """
+    torch.manual_seed(42)
+    logger = logging.getLogger("bob.ip.binseg.engine.trainer")
+    logger.info("Start training")
+    start_epoch = arguments["epoch"]
+    max_epoch = arguments["max_epoch"]
+
+    # Logg to file
+    with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+", 1) as outfile:
+        
+        model.train().to(device)
+        # Total training timer
+        start_training_time = time.time()
+
+        for epoch in range(start_epoch, max_epoch):
+            scheduler.step()
+            losses = SmoothedValue(len(data_loader))
+            epoch = epoch + 1
+            arguments["epoch"] = epoch
+            
+            # Epoch time
+            start_epoch_time = time.time()
+
+            for samples in tqdm(data_loader):
+
+                images = samples[1].to(device)
+                ground_truths = samples[2].to(device)
+                masks = None
+                if len(samples) == 4:
+                    masks = samples[-1].to(device)
+                
+                outputs = model(images)
+                
+                loss = criterion(outputs, ground_truths, masks)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                losses.update(loss)
+                logger.debug("batch loss: {}".format(loss.item()))
+
+            if epoch % checkpoint_period == 0:
+                checkpointer.save("model_{:03d}".format(epoch), **arguments)
+                val_folder = os.path.join((output_folder,epoch))
+                do_inference(model,val_loader, device, val_folder)
+
+            if epoch == max_epoch:
+                checkpointer.save("model_final", **arguments)
+
+            epoch_time = time.time() - start_epoch_time
+
+
+            eta_seconds = epoch_time * (max_epoch - epoch)
+            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+
+            outfile.write(("{epoch}, "
+                        "{avg_loss:.6f}, "
+                        "{median_loss:.6f}, "
+                        "{lr:.6f}, "
+                        "{memory:.0f}"
+                        "\n"
+                        ).format(
+                    eta=eta_string,
+                    epoch=epoch,
+                    avg_loss=losses.avg,
+                    median_loss=losses.median,
+                    lr=optimizer.param_groups[0]["lr"],
+                    memory = (torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else .0,
+                    )
+                )  
+            logger.info(("eta: {eta}, " 
+                        "epoch: {epoch}, "
+                        "avg. loss: {avg_loss:.6f}, "
+                        "median loss: {median_loss:.6f}, "
+                        "lr: {lr:.6f}, "
+                        "max mem: {memory:.0f}"
+                        ).format(
+                    eta=eta_string,
+                    epoch=epoch,
+                    avg_loss=losses.avg,
+                    median_loss=losses.median,
+                    lr=optimizer.param_groups[0]["lr"],
+                    memory = (torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else .0
+                    )
+                )
+
+
+        total_training_time = time.time() - start_training_time
+        total_time_str = str(datetime.timedelta(seconds=total_training_time))
+        logger.info(
+            "Total training time: {} ({:.4f} s / epoch)".format(
+                total_time_str, total_training_time / (max_epoch)
+            ))
+        
+    log_plot_file = os.path.join(output_folder,"{}_trainlog.pdf".format(model.name))
+    logdf = pd.read_csv(os.path.join(output_folder,"{}_trainlog.csv".format(model.name)),header=None, names=["avg. loss", "median loss","lr","max memory"])
+    fig = loss_curve(logdf,output_folder)
+    logger.info("saving {}".format(log_plot_file))
+    fig.savefig(log_plot_file)
diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py
index 2a90d5ed..c439b93c 100644
--- a/bob/ip/binseg/script/binseg.py
+++ b/bob/ip/binseg/script/binseg.py
@@ -24,12 +24,14 @@ from bob.extension.scripts.click_helper import (verbosity_option,
 from bob.ip.binseg.utils.checkpointer import DetectronCheckpointer
 from torch.utils.data import DataLoader
 from bob.ip.binseg.engine.trainer import do_train
+from bob.ip.binseg.engine.valtrainer import do_valtrain
 from bob.ip.binseg.engine.inferencer import do_inference
 from bob.ip.binseg.utils.plot import plot_overview
 from bob.ip.binseg.utils.click import OptionEatAll
 from bob.ip.binseg.utils.pdfcreator import create_pdf, get_paths
 from bob.ip.binseg.utils.rsttable import create_overview_grid
 from bob.ip.binseg.utils.plot import metricsviz, overlay
+from torch.utils.data import SubsetRandomSampler
 
 logger = logging.getLogger(__name__)
 
@@ -308,12 +310,17 @@ def testcheckpoints(model
     '-o',
     required=True,
     )
+@click.option(
+    '--title',
+    '-t',
+    required=False,
+    )
 @verbosity_option(cls=ResourceOption)
-def compare(output_path_list, output_path, **kwargs):
+def compare(output_path_list, output_path, title, **kwargs):
     """ Compares multiple metrics files that are stored in the format mymodel/results/Metrics.csv """
     logger.debug("Output paths: {}".format(output_path_list))
     logger.info('Plotting precision vs recall curves for {}'.format(output_path_list))
-    fig = plot_overview(output_path_list)
+    fig = plot_overview(output_path_list,title)
     if not os.path.exists(output_path): os.makedirs(output_path)
     fig_filename = os.path.join(output_path, 'precision_recall_comparison.pdf')
     logger.info('saving {}'.format(fig_filename))
@@ -390,4 +397,142 @@ def visualize(dataset, output_path, **kwargs):
     logger.info('Creating TP, FP, FN visualizations for {}'.format(output_path))
     metricsviz(dataset=dataset, output_path=output_path)
     logger.info('Creating overlay visualizations for {}'.format(output_path))
-    overlay(dataset=dataset, output_path=output_path)
\ No newline at end of file
+    overlay(dataset=dataset, output_path=output_path)
+
+
+# Validation Train
+@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
+@click.option(
+    '--output-path',
+    '-o',
+    required=True,
+    default="output",
+    cls=ResourceOption
+    )
+@click.option(
+    '--model',
+    '-m',
+    required=True,
+    cls=ResourceOption
+    )
+@click.option(
+    '--dataset',
+    '-d',
+    required=True,
+    cls=ResourceOption
+    )
+@click.option(
+    '--optimizer',
+    required=True,
+    cls=ResourceOption
+    )
+@click.option(
+    '--criterion',
+    required=True,
+    cls=ResourceOption
+    )
+@click.option(
+    '--scheduler',
+    required=True,
+    cls=ResourceOption
+    )
+@click.option(
+    '--pretrained-backbone',
+    '-t',
+    required=True,
+    cls=ResourceOption
+    )
+@click.option(
+    '--batch-size',
+    '-b',
+    required=True,
+    default=2,
+    cls=ResourceOption)
+@click.option(
+    '--epochs',
+    '-e',
+    help='Number of epochs used for training',
+    show_default=True,
+    required=True,
+    default=6,
+    cls=ResourceOption)
+@click.option(
+    '--checkpoint-period',
+    '-p',
+    help='Number of epochs after which a checkpoint is saved',
+    show_default=True,
+    required=True,
+    default=2,
+    cls=ResourceOption)
+@click.option(
+    '--device',
+    '-d',
+    help='A string indicating the device to use (e.g. "cpu" or "cuda:0"',
+    show_default=True,
+    required=True,
+    default='cpu',
+    cls=ResourceOption)
+@click.option(
+    '--valsize',
+    '-a',
+    help='Size of validation set',
+    show_default=True,
+    required=True,
+    default=5,
+    cls=ResourceOption)
+@verbosity_option(cls=ResourceOption)
+def valtrain(model
+        ,optimizer
+        ,scheduler
+        ,output_path
+        ,epochs
+        ,pretrained_backbone
+        ,batch_size
+        ,criterion
+        ,dataset
+        ,checkpoint_period
+        ,device
+        ,valsize
+        ,**kwargs):
+    """ Train a model """
+    
+    if not os.path.exists(output_path): os.makedirs(output_path)
+    
+
+    # Validation and training set size
+    train_size = len(dataset) - valsize 
+    # PyTorch dataloader
+
+    indices = torch.randperm(len(dataset))
+    train_indices = indices[:len(indices)-valsize][:train_size or None]
+    valid_indices = indices[len(indices)-valsize:] if valsize else None
+
+    train_loader = torch.utils.data.DataLoader(dataset, pin_memory=torch.cuda.is_available(), batch_size=batch_size,
+                                                   sampler=SubsetRandomSampler(train_indices))
+
+    valid_loader = torch.utils.data.DataLoader(dataset, pin_memory=torch.cuda.is_available(), batch_size=batch_size,
+                                                   sampler=SubsetRandomSampler(valid_indices))
+
+    # Checkpointer
+    checkpointer = DetectronCheckpointer(model, optimizer, scheduler,save_dir = output_path, save_to_disk=True)
+    arguments = {}
+    arguments["epoch"] = 0 
+    extra_checkpoint_data = checkpointer.load(pretrained_backbone)
+    arguments.update(extra_checkpoint_data)
+    arguments["max_epoch"] = epochs
+    
+    # Train
+    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
+    logger.info("Continuing from epoch {}".format(arguments["epoch"]))
+    do_valtrain(model
+            , train_loader
+            , optimizer
+            , criterion
+            , scheduler
+            , checkpointer
+            , checkpoint_period
+            , device
+            , arguments
+            , output_path
+            , valid_loader
+            )
\ No newline at end of file
diff --git a/bob/ip/binseg/utils/plot.py b/bob/ip/binseg/utils/plot.py
index bedb7625..06dd666d 100644
--- a/bob/ip/binseg/utils/plot.py
+++ b/bob/ip/binseg/utils/plot.py
@@ -158,7 +158,7 @@ def read_metricscsv(file):
     return np.array(precision), np.array(recall)
 
 
-def plot_overview(outputfolders):
+def plot_overview(outputfolders,title):
     """
     Plots comparison chart of all trained models
     
@@ -166,6 +166,8 @@ def plot_overview(outputfolders):
     ----------
     outputfolder : list
         list containing output paths of all evaluated models (e.g. ``['DRIVE/model1', 'DRIVE/model2']``)
+    title : str
+        title of plot
     Returns
     -------
     matplotlib.figure.Figure
@@ -181,15 +183,16 @@ def plot_overview(outputfolders):
         precisions.append(pr)
         recalls.append(re)
         modelname = folder.split('/')[-1]
+        datasetname =  folder.split('/')[-2]
         # parameters
         summary_path = os.path.join(folder,'results/ModelSummary.txt')
         with open (summary_path, "r") as outfile:
           rows = outfile.readlines()
           lastrow = rows[-1]
           parameter = int(lastrow.split()[1].replace(',',''))
-        name = '[P={:.2f}M] {}'.format(parameter/100**3, modelname)
+        name = '[P={:.2f}M] {} {}'.format(parameter/100**3, modelname, datasetname)
         names.append(name)
-    title = folder.split('/')[-2]
+    #title = folder.split('/')[-4]
     fig = precision_recall_f1iso(precisions,recalls,names,title)
     return fig
 
diff --git a/setup.py b/setup.py
index 20671d7f..a7b02f03 100644
--- a/setup.py
+++ b/setup.py
@@ -78,10 +78,12 @@ setup(
           'DRIVE1024 = bob.ip.binseg.configs.datasets.drive1024',
           'DRIVE1168 = bob.ip.binseg.configs.datasets.drive1168',
           'DRIVETEST = bob.ip.binseg.configs.datasets.drivetest',
+          'DRIVESTARECHASEDB1HRF1024 = bob.ip.binseg.configs.datasets.drivestarechasedb1hrf1024',
           'DRIVESTARECHASEDB1IOSTAR1168 = bob.ip.binseg.configs.datasets.drivestarechasedb1iostar1168',
           'DRIVESTAREIOSTARHRF960 = bob.ip.binseg.configs.datasets.drivestareiostarhrf960',
           'HRF = bob.ip.binseg.configs.datasets.hrf',
           'HRF960 = bob.ip.binseg.configs.datasets.hrf960',
+          'HRF1024 = bob.ip.binseg.configs.datasets.hrf1024',
           'HRF1168 = bob.ip.binseg.configs.datasets.hrf1168',
           'HRFTEST = bob.ip.binseg.configs.datasets.hrftest',
           'IOSTAROD = bob.ip.binseg.configs.datasets.iostarod',
-- 
GitLab