From c6f07e8d212cd517887b242742d6808440f0c076 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Sat, 4 Apr 2020 17:56:01 +0200
Subject: [PATCH] [docs,scripts] Separate evaluation script; Improve trainer
 engine code; Reset network model names to match resource names

---
 bob/ip/binseg/engine/ssltrainer.py  |   3 +-
 bob/ip/binseg/engine/trainer.py     | 129 +++++++++++++---------------
 bob/ip/binseg/modeling/driu.py      |   2 +-
 bob/ip/binseg/modeling/driubn.py    |   2 +-
 bob/ip/binseg/modeling/driuod.py    |   2 +-
 bob/ip/binseg/modeling/driupix.py   |   2 +-
 bob/ip/binseg/modeling/hed.py       |   2 +-
 bob/ip/binseg/modeling/m2u.py       |   2 +-
 bob/ip/binseg/modeling/resunet.py   |   2 +-
 bob/ip/binseg/script/binseg.py      |  46 ----------
 bob/ip/binseg/script/evaluate.py    |  96 +++++++++++++++++++++
 bob/ip/binseg/script/train.py       |  37 +++++---
 bob/ip/binseg/utils/checkpointer.py |  20 ++---
 bob/ip/binseg/utils/plot.py         |  50 ++++++-----
 doc/cli.rst                         |  10 +++
 doc/evaluation.rst                  |  75 ++++------------
 doc/links.rst                       |  33 ++++---
 doc/models.rst                      |  65 ++++++++++++++
 doc/training.rst                    |  10 +--
 doc/usage.rst                       |   1 +
 setup.py                            |   2 +-
 21 files changed, 351 insertions(+), 240 deletions(-)
 create mode 100644 bob/ip/binseg/script/evaluate.py
 create mode 100644 doc/models.rst

diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 02ad551c..f03e01e4 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -206,7 +206,8 @@ def do_ssltrain(
         rampup epochs
 
     """
-    logger.info("Start training")
+    logger.info("Start SSL training")
+
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index a5eb759c..44ab76f1 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 
 import os
+import csv
 import time
 import datetime
 import torch
@@ -12,6 +13,7 @@ from bob.ip.binseg.utils.metric import SmoothedValue
 from bob.ip.binseg.utils.plot import loss_curve
 
 import logging
+
 logger = logging.getLogger(__name__)
 
 
@@ -28,37 +30,57 @@ def do_train(
     output_folder,
 ):
     """
-    Train model and save to disk.
+    Train models and save it to disk.
+
+    This method supports periodic checkpointing and the output of a
+    CSV-formatted log with the evolution of some figures during training.
+
 
     Parameters
     ----------
+
     model : :py:class:`torch.nn.Module`
         Network (e.g. DRIU, HED, UNet)
+
     data_loader : :py:class:`torch.utils.data.DataLoader`
+
     optimizer : :py:mod:`torch.optim`
+
     criterion : :py:class:`torch.nn.modules.loss._Loss`
         loss function
+
     scheduler : :py:mod:`torch.optim`
         learning rate scheduler
+
     checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
-        checkpointer
+        checkpointer implementation
+
     checkpoint_period : int
-        save a checkpoint every n epochs
+        save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
+        not save intermediary checkpoints
+
     device : str
-        device to use ``'cpu'`` or ``'cuda'``
+        device to use ``'cpu'`` or ``cuda:0``
+
     arguments : dict
         start end end epochs
+
     output_folder : str
         output path
     """
+
     logger.info("Start training")
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
-    # Logg to file
-    with open(
-        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)), "a+"
-    ) as outfile:
+    # Log to file
+    logfile_name = os.path.join(output_folder, "trainlog.csv")
+    logfile_fields = ("epoch", "total-time", "eta", "average-loss",
+            "median-loss", "learning-rate", "memory-megabytes")
+
+    with open(logfile_name, "w", newline="") as logfile:
+        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
+        logwriter.writeheader()
 
         model.train().to(device)
         for state in optimizer.state.values():
@@ -95,70 +117,43 @@ def do_train(
                 losses.update(loss)
                 logger.debug("batch loss: {}".format(loss.item()))
 
-            if epoch % checkpoint_period == 0:
+            if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save("model_{:03d}".format(epoch), **arguments)
 
             if epoch == max_epoch:
                 checkpointer.save("model_final", **arguments)
 
+            # computes ETA (estimated time-of-arrival; end of training) taking
+            # into consideration previous epoch performance
             epoch_time = time.time() - start_epoch_time
-
             eta_seconds = epoch_time * (max_epoch - epoch)
-            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
-
-            outfile.write(
-                (
-                    "{epoch}, "
-                    "{avg_loss:.6f}, "
-                    "{median_loss:.6f}, "
-                    "{lr:.6f}, "
-                    "{memory:.0f}"
-                    "\n"
-                ).format(
-                    eta=eta_string,
-                    epoch=epoch,
-                    avg_loss=losses.avg,
-                    median_loss=losses.median,
-                    lr=optimizer.param_groups[0]["lr"],
-                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
-                    if torch.cuda.is_available()
-                    else 0.0,
-                )
-            )
-            logger.info(
-                (
-                    "eta: {eta}, "
-                    "epoch: {epoch}, "
-                    "avg. loss: {avg_loss:.6f}, "
-                    "median loss: {median_loss:.6f}, "
-                    "lr: {lr:.6f}, "
-                    "max mem: {memory:.0f}"
-                ).format(
-                    eta=eta_string,
-                    epoch=epoch,
-                    avg_loss=losses.avg,
-                    median_loss=losses.median,
-                    lr=optimizer.param_groups[0]["lr"],
-                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
-                    if torch.cuda.is_available()
-                    else 0.0,
-                )
-            )
-
+            current_time = time.time() - start_training_time
+
+            logdata = (
+                    ("epoch", f"{epoch}"),
+                    ("total-time",
+                        f"{datetime.timedelta(seconds=int(current_time))}"),
+                    ("eta",
+                        f"{datetime.timedelta(seconds=int(eta_seconds))}"),
+                    ("average-loss", f"{losses.avg:.6f}"),
+                    ("median-loss", f"{losses.median:.6f}"),
+                    ("learning-rate",
+                        f"{optimizer.param_groups[0]['lr']:.6f}"),
+                    ("gpu-memory-megabytes",
+                        f"{torch.cuda.max_memory_allocated()/(1024.0*1024.0)}" \
+                        if torch.cuda.is_available() else "0.0"),
+                    )
+
+            logwriter.writerow(dict(k for k in logdata))
+            logger.info("|".join([f"{k}: {v}" for (k,v) in logdata]))
+
+        logger.info("End of training.")
         total_training_time = time.time() - start_training_time
-        total_time_str = str(datetime.timedelta(seconds=total_training_time))
-        logger.info(
-            "Total training time: {} ({:.4f} s / epoch)".format(
-                total_time_str, total_training_time / (max_epoch)
-            )
-        )
-
-    log_plot_file = os.path.join(output_folder, "{}_trainlog.pdf".format(model.name))
-    logdf = pd.read_csv(
-        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)),
-        header=None,
-        names=["avg. loss", "median loss", "lr", "max memory"],
-    )
-    fig = loss_curve(logdf, output_folder)
-    logger.info("saving {}".format(log_plot_file))
-    fig.savefig(log_plot_file)
+        logger.info(f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)")
+
+    # plots a version of the CSV trainlog into a PDF
+    logplot_name = os.path.join(output_folder, "trainlog.pdf")
+    logdf = pd.read_csv(logplot_name, header=0, names=logfile_fields)
+    fig = loss_curve(logdf, title="Loss Evolution")
+    logger.info(f"Saving {log_plot_file}")
+    fig.savefig(logplot_name)
diff --git a/bob/ip/binseg/modeling/driu.py b/bob/ip/binseg/modeling/driu.py
index 5b4425c2..c63dc843 100644
--- a/bob/ip/binseg/modeling/driu.py
+++ b/bob/ip/binseg/modeling/driu.py
@@ -96,5 +96,5 @@ def build_driu():
     model = torch.nn.Sequential(
         OrderedDict([("backbone", backbone), ("head", driu_head)])
     )
-    model.name = "DRIU"
+    model.name = "driu"
     return model
diff --git a/bob/ip/binseg/modeling/driubn.py b/bob/ip/binseg/modeling/driubn.py
index 245fdf17..fd834353 100644
--- a/bob/ip/binseg/modeling/driubn.py
+++ b/bob/ip/binseg/modeling/driubn.py
@@ -93,5 +93,5 @@ def build_driu():
     model = torch.nn.Sequential(
         OrderedDict([("backbone", backbone), ("head", driu_head)])
     )
-    model.name = "DRIUBN"
+    model.name = "driu-bn"
     return model
diff --git a/bob/ip/binseg/modeling/driuod.py b/bob/ip/binseg/modeling/driuod.py
index 25e5b82d..dbd26167 100644
--- a/bob/ip/binseg/modeling/driuod.py
+++ b/bob/ip/binseg/modeling/driuod.py
@@ -88,5 +88,5 @@ def build_driuod():
     model = torch.nn.Sequential(
         OrderedDict([("backbone", backbone), ("head", driu_head)])
     )
-    model.name = "DRIUOD"
+    model.name = "driu-od"
     return model
diff --git a/bob/ip/binseg/modeling/driupix.py b/bob/ip/binseg/modeling/driupix.py
index 3ad10aa7..eef95c9f 100644
--- a/bob/ip/binseg/modeling/driupix.py
+++ b/bob/ip/binseg/modeling/driupix.py
@@ -92,5 +92,5 @@ def build_driupix():
     model = torch.nn.Sequential(
         OrderedDict([("backbone", backbone), ("head", driu_head)])
     )
-    model.name = "DRIUPIX"
+    model.name = "driu-pix"
     return model
diff --git a/bob/ip/binseg/modeling/hed.py b/bob/ip/binseg/modeling/hed.py
index 02c2b957..db42515c 100644
--- a/bob/ip/binseg/modeling/hed.py
+++ b/bob/ip/binseg/modeling/hed.py
@@ -98,5 +98,5 @@ def build_hed():
     model = torch.nn.Sequential(
         OrderedDict([("backbone", backbone), ("head", hed_head)])
     )
-    model.name = "HED"
+    model.name = "hed"
     return model
diff --git a/bob/ip/binseg/modeling/m2u.py b/bob/ip/binseg/modeling/m2u.py
index 25bc0515..8861b965 100644
--- a/bob/ip/binseg/modeling/m2u.py
+++ b/bob/ip/binseg/modeling/m2u.py
@@ -116,5 +116,5 @@ def build_m2unet():
     model = torch.nn.Sequential(
         OrderedDict([("backbone", backbone), ("head", m2u_head)])
     )
-    model.name = "M2UNet"
+    model.name = "m2unet"
     return model
diff --git a/bob/ip/binseg/modeling/resunet.py b/bob/ip/binseg/modeling/resunet.py
index c27efeb3..cce8242e 100644
--- a/bob/ip/binseg/modeling/resunet.py
+++ b/bob/ip/binseg/modeling/resunet.py
@@ -70,5 +70,5 @@ def build_res50unet():
     backbone = resnet50(pretrained=False, return_features=[2, 4, 5, 6, 7])
     unet_head = ResUNet([64, 256, 512, 1024, 2048], pixel_shuffle=False)
     model = nn.Sequential(OrderedDict([("backbone", backbone), ("head", unet_head)]))
-    model.name = "ResUNet"
+    model.name = "resunet"
     return model
diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py
index 3db339a7..a3a6a27f 100644
--- a/bob/ip/binseg/script/binseg.py
+++ b/bob/ip/binseg/script/binseg.py
@@ -22,7 +22,6 @@ from bob.extension.scripts.click_helper import (
 
 from bob.ip.binseg.utils.checkpointer import DetectronCheckpointer
 from torch.utils.data import DataLoader
-from bob.ip.binseg.engine.inferencer import do_inference
 from bob.ip.binseg.utils.plot import plot_overview
 from bob.ip.binseg.utils.click import OptionEatAll
 from bob.ip.binseg.utils.rsttable import create_overview_grid
@@ -40,51 +39,6 @@ def binseg():
     """Binary 2D Image Segmentation Benchmark commands."""
 
 
-# Inference
-@binseg.command(entry_point_group="bob.ip.binseg.config", cls=ConfigCommand)
-@click.option(
-    "--output-path", "-o", required=True, default="output", cls=ResourceOption
-)
-@click.option("--model", "-m", required=True, cls=ResourceOption)
-@click.option("--dataset", "-d", required=True, cls=ResourceOption)
-@click.option("--batch-size", "-b", required=True, default=2, cls=ResourceOption)
-@click.option(
-    "--device",
-    "-d",
-    help='A string indicating the device to use (e.g. "cpu" or "cuda:0"',
-    show_default=True,
-    required=True,
-    default="cpu",
-    cls=ResourceOption,
-)
-@click.option(
-    "--weight",
-    "-w",
-    help="Path or URL to pretrained model",
-    required=False,
-    default=None,
-    cls=ResourceOption,
-)
-@verbosity_option(cls=ResourceOption)
-def test(model, output_path, device, batch_size, dataset, weight, **kwargs):
-    """ Run inference and evaluate the model performance """
-
-    # PyTorch dataloader
-    data_loader = DataLoader(
-        dataset=dataset,
-        batch_size=batch_size,
-        shuffle=False,
-        pin_memory=torch.cuda.is_available(),
-    )
-
-    # checkpointer, load last model in dir
-    checkpointer = DetectronCheckpointer(
-        model, save_dir=output_path, save_to_disk=False
-    )
-    checkpointer.load(weight)
-    do_inference(model, data_loader, device, output_path)
-
-
 # Plot comparison
 @binseg.command(entry_point_group="bob.ip.binseg.config", cls=ConfigCommand)
 @click.option(
diff --git a/bob/ip/binseg/script/evaluate.py b/bob/ip/binseg/script/evaluate.py
new file mode 100644
index 00000000..bbbe8f4d
--- /dev/null
+++ b/bob/ip/binseg/script/evaluate.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import os
+import pkg_resources
+
+import click
+from click_plugins import with_plugins
+
+import torch
+from torch.utils.data import DataLoader
+
+from bob.extension.scripts.click_helper import (
+    verbosity_option,
+    ConfigCommand,
+    ResourceOption,
+    AliasedGroup,
+)
+
+from ..utils.checkpointer import DetectronCheckpointer
+from ..engine.inferencer import do_inference
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+@click.command(
+    entry_point_group="bob.ip.binseg.config",
+    cls=ConfigCommand,
+    epilog="""Examples:
+
+\b
+    1. Evaluates a M2U-Net model on the DRIVE test set:
+
+       $ bob binseg evaluate -vv m2unet drive-test --weight=results/model_final.pth
+
+""",
+)
+@click.option(
+    "--model",
+    "-m",
+    help="A torch.nn.Module instance implementing the network to be evaluated",
+    required=True,
+    cls=ResourceOption,
+)
+@click.option(
+    "--dataset",
+    "-d",
+    help="A torch.utils.data.dataset.Dataset instance implementing a dataset to be used for evaluating the model, possibly including all pre-processing pipelines required",
+    required=True,
+    cls=ResourceOption,
+)
+@click.option(
+    "--batch-size",
+    "-b",
+    help="Number of samples in every batch (this parameter affects memory requirements for the network)",
+    required=True,
+    show_default=True,
+    default=1,
+    cls=ResourceOption,
+)
+@click.option(
+    "--device",
+    "-d",
+    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
+    show_default=True,
+    required=True,
+    default="cpu",
+    cls=ResourceOption,
+)
+@click.option(
+    "--weight",
+    "-w",
+    help="Path or URL to pretrained model file (.pth extension)",
+    required=True,
+    cls=ResourceOption,
+)
+@verbosity_option(cls=ResourceOption)
+def evaluate(model, output_path, device, batch_size, dataset, weight, **kwargs):
+    """Evaluates an FCN on a binary segmentation task.
+    """
+
+    # PyTorch dataloader
+    data_loader = DataLoader(
+        dataset=dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        pin_memory=torch.cuda.is_available(),
+    )
+
+    # checkpointer, load last model in dir
+    checkpointer = DetectronCheckpointer(
+        model, save_dir=output_path, save_to_disk=False
+    )
+    checkpointer.load(weight)
+    do_inference(model, data_loader, device, output_path)
diff --git a/bob/ip/binseg/script/train.py b/bob/ip/binseg/script/train.py
index 44534a13..5a2ce6f1 100644
--- a/bob/ip/binseg/script/train.py
+++ b/bob/ip/binseg/script/train.py
@@ -22,6 +22,7 @@ from ..engine.trainer import do_train
 from ..engine.ssltrainer import do_ssltrain
 
 import logging
+
 logger = logging.getLogger(__name__)
 
 
@@ -68,7 +69,9 @@ logger = logging.getLogger(__name__)
 @click.option(
     "--dataset",
     "-d",
-    help="A torch.utils.data.dataset.Dataset instance implementing a dataset to be used for training the model, possibly including all pre-processing pipelines required.",
+    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
+    "to be used for training the model, possibly including all pre-processing"
+    " pipelines required",
     required=True,
     cls=ResourceOption,
 )
@@ -80,27 +83,32 @@ logger = logging.getLogger(__name__)
 )
 @click.option(
     "--criterion",
-    help="A loss function to compute the FCN error for every sample respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
+    help="A loss function to compute the FCN error for every sample "
+    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
     required=True,
     cls=ResourceOption,
 )
 @click.option(
     "--scheduler",
-    help="A learning rate scheduler that drives changes in the learning rate depending on the FCN state (see torch.optim.lr_scheduler)",
+    help="A learning rate scheduler that drives changes in the learning "
+    "rate depending on the FCN state (see torch.optim.lr_scheduler)",
     required=True,
     cls=ResourceOption,
 )
 @click.option(
     "--pretrained-backbone",
     "-t",
-    help="URLs of a pre-trained model file that will be used to preset FCN weights (where relevant) before training starts.  (e.g. vgg-16)",
+    help="URLs of a pre-trained model file that will be used to preset "
+    "FCN weights (where relevant) before training starts "
+    "(e.g. vgg16, mobilenetv2)",
     required=True,
     cls=ResourceOption,
 )
 @click.option(
     "--batch-size",
     "-b",
-    help="Number of samples in every batch (notice that changing this parameter affects memory requirements for the network)",
+    help="Number of samples in every batch (this parameter affects "
+    "memory requirements for the network)",
     required=True,
     show_default=True,
     default=2,
@@ -118,10 +126,14 @@ logger = logging.getLogger(__name__)
 @click.option(
     "--checkpoint-period",
     "-p",
-    help="Number of epochs after which a checkpoint is saved",
+    help="Number of epochs after which a checkpoint is saved.  "
+    "A value of zero will disable check-pointing.  If checkpointing is "
+    "enabled and training stops, it is automatically resumed from the "
+    "last saved checkpoint if training is restarted with the same "
+    "configuration.",
     show_default=True,
     required=True,
-    default=100,
+    default=0,
     cls=ResourceOption,
 )
 @click.option(
@@ -156,7 +168,7 @@ logger = logging.getLogger(__name__)
     help="Ramp-up length in epochs (for SSL training only)",
     show_default=True,
     required=True,
-    default=900,
+    default=1000,
     cls=ResourceOption,
 )
 @verbosity_option(cls=ResourceOption)
@@ -180,7 +192,11 @@ def train(
 ):
     """Trains an FCN to perform binary segmentation using a supervised approach
 
-    Training is performed for a fixed number of steps (not configurable).
+    Training is performed for a configurable number of epochs, and generates at
+    least a final model (.pth file).  It may also generate a number of
+    intermediate checkpoints.  Checkpoints are model files (.pth files) that
+    are stored during the training and useful to resume the procedure in case
+    it stops abruptly.
     """
 
     if not os.path.exists(output_path):
@@ -199,6 +215,7 @@ def train(
     checkpointer = DetectronCheckpointer(
         model, optimizer, scheduler, save_dir=output_path, save_to_disk=True
     )
+
     arguments = {}
     arguments["epoch"] = 0
     extra_checkpoint_data = checkpointer.load(pretrained_backbone)
@@ -209,7 +226,6 @@ def train(
     logger.info("Continuing from epoch {}".format(arguments["epoch"]))
 
     if not ssl:
-        logger.info("Doing SUPERVISED training...")
         do_train(
             model,
             data_loader,
@@ -225,7 +241,6 @@ def train(
 
     else:
 
-        logger.info("Doing SEMI-SUPERVISED training...")
         do_ssltrain(
             model,
             data_loader,
diff --git a/bob/ip/binseg/utils/checkpointer.py b/bob/ip/binseg/utils/checkpointer.py
index 4b375e9f..8c0def2e 100644
--- a/bob/ip/binseg/utils/checkpointer.py
+++ b/bob/ip/binseg/utils/checkpointer.py
@@ -3,12 +3,14 @@
 
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 
-import logging
 import torch
 import os
 from bob.ip.binseg.utils.model_serialization import load_state_dict
 from bob.ip.binseg.utils.model_zoo import cache_url
 
+import logging
+logger = logging.getLogger(__name__)
+
 
 class Checkpointer:
     """Adapted from `maskrcnn-benchmark
@@ -22,16 +24,12 @@ class Checkpointer:
         scheduler=None,
         save_dir="",
         save_to_disk=None,
-        logger=None,
     ):
         self.model = model
         self.optimizer = optimizer
         self.scheduler = scheduler
         self.save_dir = save_dir
         self.save_to_disk = save_to_disk
-        if logger is None:
-            logger = logging.getLogger(__name__)
-        self.logger = logger
 
     def save(self, name, **kwargs):
         if not self.save_dir:
@@ -49,7 +47,7 @@ class Checkpointer:
         data.update(kwargs)
 
         save_file = os.path.join(self.save_dir, "{}.pth".format(name))
-        self.logger.info("Saving checkpoint to {}".format(save_file))
+        logger.info("Saving checkpoint to {}".format(save_file))
         torch.save(data, save_file)
         self.tag_last_checkpoint(save_file)
 
@@ -59,16 +57,16 @@ class Checkpointer:
             f = self.get_checkpoint_file()
         if not f:
             # no checkpoint could be found
-            self.logger.warn("No checkpoint found. Initializing model from scratch")
+            logger.warn("No checkpoint found. Initializing model from scratch")
             return {}
-        self.logger.info("Loading checkpoint from {}".format(f))
+        logger.info("Loading checkpoint from {}".format(f))
         checkpoint = self._load_file(f)
         self._load_model(checkpoint)
         if "optimizer" in checkpoint and self.optimizer:
-            self.logger.info("Loading optimizer from {}".format(f))
+            logger.info("Loading optimizer from {}".format(f))
             self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
         if "scheduler" in checkpoint and self.scheduler:
-            self.logger.info("Loading scheduler from {}".format(f))
+            logger.info("Loading scheduler from {}".format(f))
             self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
 
         # return any further checkpoint data
@@ -121,7 +119,7 @@ class DetectronCheckpointer(Checkpointer):
         if f.startswith("http"):
             # if the file is a url path, download it and cache it
             cached_f = cache_url(f)
-            self.logger.info("url {} cached in {}".format(f, cached_f))
+            logger.info("url {} cached in {}".format(f, cached_f))
             f = cached_f
         # load checkpoint
         loaded = super(DetectronCheckpointer, self)._load_file(f)
diff --git a/bob/ip/binseg/utils/plot.py b/bob/ip/binseg/utils/plot.py
index caf75c65..2873b715 100644
--- a/bob/ip/binseg/utils/plot.py
+++ b/bob/ip/binseg/utils/plot.py
@@ -1,15 +1,19 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import numpy as np
 import os
 import csv
+
+import numpy as np
 import pandas as pd
 import PIL
-from PIL import Image, ImageFont, ImageDraw
+
 import torchvision.transforms.functional as VF
 import torch
 
+import matplotlib
+matplotlib.use("agg")
+
 
 def precision_recall_f1iso(precision, recall, names, title=None):
     """
@@ -262,31 +266,35 @@ def precision_recall_f1iso_confintval(
     return fig
 
 
-def loss_curve(df, title):
-    """ Creates a loss curve given a Dataframe with column names:
-
-    ``['avg. loss', 'median loss','lr','max memory']``
+def loss_curve(df, title=None):
+    """Creates a loss curve in a Matplotlib figure.
 
     Parameters
     ----------
+
     df : :py:class:`pandas.DataFrame`
+        A dataframe containing, at least, "epoch", "median-loss" and
+        "learning-rate" columns, that will be plotted.
+
+    title : :py:class:`str`, Optional
+        Optional title, that will be set on the figure if passed
 
     Returns
     -------
-    matplotlib.figure.Figure
-    """
-    import matplotlib
 
-    matplotlib.use("agg")
+    figure : matplotlib.figure.Figure
+        A figure, that may be saved or displayed
+
+    """
     import matplotlib.pyplot as plt
 
-    ax1 = df.plot(y="median loss", grid=True)
-    ax1.set_title(title)
-    ax1.set_ylabel("median loss")
+    ax1 = df.plot(x="epoch", y="median-loss", grid=True)
+    if title is not None: ax1.set_title(title)
+    ax1.set_ylabel("Median Loss")
     ax1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
-    ax2 = df["lr"].plot(secondary_y=True, legend=True, grid=True,)
-    ax2.set_ylabel("lr")
-    ax1.set_xlabel("epoch")
+    ax2 = df["learning-rate"].plot(secondary_y=True, legend=True, grid=True,)
+    ax2.set_ylabel("Learning Rate")
+    ax1.set_xlabel("Epoch")
     plt.tight_layout()
     fig = ax1.get_figure()
     return fig
@@ -443,8 +451,8 @@ def metricsviz(
             f1 = img_metrics[" f1_score"].max()
             # add f1-score
             fnt_size = tp_pil_colored.size[1] // 25
-            draw = ImageDraw.Draw(tp_pil_colored)
-            fnt = ImageFont.truetype("FreeMono.ttf", fnt_size)
+            draw = PIL.ImageDraw.Draw(tp_pil_colored)
+            fnt = PIL.ImageFont.truetype("FreeMono.ttf", fnt_size)
             draw.text((0, 0), "F1: {:.4f}".format(f1), (255, 255, 255), font=fnt)
 
         # save to disk
@@ -472,15 +480,15 @@ def overlay(dataset, output_path):
         img = VF.to_pil_image(sample[1])  # PIL Image
 
         # read probability output
-        pred = Image.open(os.path.join(output_path, "images", name)).convert(mode="L")
+        pred = PIL.Image.open(os.path.join(output_path, "images", name)).convert(mode="L")
         # color and overlay
         pred_green = PIL.ImageOps.colorize(pred, (0, 0, 0), (0, 255, 0))
         overlayed = PIL.Image.blend(img, pred_green, 0.4)
 
         # add f1-score
         # fnt_size = overlayed.size[1]//25
-        # draw = ImageDraw.Draw(overlayed)
-        # fnt = ImageFont.truetype('FreeMono.ttf', fnt_size)
+        # draw = PIL.ImageDraw.Draw(overlayed)
+        # fnt = PIL.ImageFont.truetype('FreeMono.ttf', fnt_size)
         # draw.text((0, 0),"F1: {:.4f}".format(f1),(255,255,255),font=fnt)
         # save to disk
         overlayed_path = os.path.join(output_path, "overlayed")
diff --git a/doc/cli.rst b/doc/cli.rst
index 47bdfc80..d65aaf79 100644
--- a/doc/cli.rst
+++ b/doc/cli.rst
@@ -71,4 +71,14 @@ evaluation tests or for inference.
 .. command-output:: bob binseg train --help
 
 
+.. _bob.ip.binseg.cli.evaluate:
+
+FCN Performance Evaluation
+--------------------------
+
+Evaluation takes as input a PyTorch_ model and generates analysis information.
+
+.. command-output:: bob binseg evaluate --help
+
+
 .. include:: links.rst
diff --git a/doc/evaluation.rst b/doc/evaluation.rst
index feceb673..8bc91d0a 100644
--- a/doc/evaluation.rst
+++ b/doc/evaluation.rst
@@ -1,43 +1,45 @@
 .. -*- coding: utf-8 -*-
-.. _bob.ip.binseg.evaluation:
 
-==========
-Evaluation
-==========
+.. _bob.ip.binseg.eval:
 
-To evaluate trained models use use ``bob binseg test`` followed by
-the model config, the dataset config and the path to the pretrained
-model via the argument ``-w``.
+============
+ Evaluation
+============
 
-Alternatively point to the output folder used during training via
-the ``-o`` argument. The Checkpointer will load the model as indicated
+To evaluate trained models use our CLI interface. ``bob binseg evaluate``
+followed by the model and the dataset configuration, and the path to the
+pretrained model via the argument ``--weight``.
+
+Alternatively point to the output folder used during training via the
+``--output-path`` argument.   The Checkpointer will load the model as indicated
 in the file: ``last_checkpoint``.
 
-Use ``bob binseg test --help`` for more information.
+Use ``bob binseg evaluate --help`` for more information.
 
 E.g. run inference on model M2U-Net on the DRIVE test set:
 
 .. code-block:: bash
 
     # Point directly to saved model via -w argument:
-    bob binseg test M2UNet DRIVETEST -o /outputfolder/for/results -w /direct/path/to/weight/model_final.pth
+    bob binseg evaluate m2unet drive-test -o /outputfolder/for/results -w /direct/path/to/weight/model_final.pth
 
     # Use training output path (requries last_checkpoint file to be present)
     # The evaluation results will be stored in the same folder
-    bob binseg test M2UNet DRIVETEST -o /DRIVE/M2UNet/output
+    bob binseg test m2unet drive-test -o /outputfolder/for/results
+
 
 Outputs
 ========
 The inference run generates the following output files:
 
-.. code-block:: bash
+.. code-block:: text
 
     .
     ├── images  # the predicted probabilities as grayscale images in .png format
     ├── hdf5    # the predicted probabilties in hdf5 format
     ├── last_checkpoint  # text file that keeps track of the last checkpoint
-    ├── M2UNet_trainlog.csv # training log
-    ├── M2UNet_trainlog.pdf # training log plot
+    ├── trainlog.csv # training log
+    ├── trainlog.pdf # training log plot
     ├── model_*.pth # model checkpoints
     └── results
         ├── image*.jpg.csv # evaluation metrics for each image
@@ -46,6 +48,7 @@ The inference run generates the following output files:
         ├── precision_recall.pdf # precision vs recall plot
         └── Times.txt # inference times
 
+
 Inference Only Mode
 ====================
 
@@ -57,48 +60,6 @@ If you wish to run inference only on a folder containing images, use the
 
     bob binseg predict M2UNet /path/to/myinferencedatasetconfig.py -b 1 -d cpu -o /my/output/path -w /path/to/pretrained/weight/model_final.pth -vv
 
-Pretrained Models
-=================
-
-Due to storage limitations we only provide weights of a subset
-of all evaluated models:
-
-
-
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-|                    | DRIU               | M2UNet                                                                                                                         |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| DRIVE              | `DRIU_DRIVE.pth`_  | `M2UNet_DRIVE.pth <m2unet_drive.pth_>`_                                                                                        |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| COVD-DRIVE         |                    | `M2UNet_COVD-DRIVE.pth <https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-DRIVE.pth>`_               |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| COVD-DRIVE SSL     |                    | `M2UNet_COVD-DRIVE_SSL.pth <https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-DRIVE_SSL.pth>`_       |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| STARE              | DRIU_STARE.pth_    | `M2UNet_STARE.pth <https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_STARE.pth>`_                         |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| COVD-STARE         |                    | `M2UNet_COVD-STARE.pth <https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-STARE.pth>`_               |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| COVD-STARE SSL     |                    | `M2UNet_COVD-STARE_SSL.pth <https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-STARE_SSL.pth>`_       |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| CHASE_DB1          | DRIU_CHASEDB1.pth_ | `M2UNet_CHASEDB1.pth <https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_CHASEDB1.pth>`_                   |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| COVD-CHASE_DB1     |                    | `M2UNet_COVD-CHASEDB1.pth <https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-CHASEDB1.pth>`_         |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| COVD-CHASE_DB1 SSL |                    | `M2UNet_COVD-CHASEDB1_SSL.pth <https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-CHASEDB1_SSL.pth>`_ |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| IOSTARVESSEL       | DRIU_IOSTAR.pth_   | `M2UNet_IOSTARVESSEL.pth <https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_IOSTARVESSEL.pth>`_           |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| COVD-IOSTAR        |                    | `M2UNet_COVD-IOSTAR.pth <https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-IOSTAR.pth>`_             |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| COVD-IOSTAR SSL    |                    | `M2UNet_COVD-IOSTAR_SSL.pth <https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-IOSTAR_SSL.pth>`_     |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| HRF                | DRIU_HRF.pth_      | `M2UNet_HRF1168.pth <https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_HRF1168.pth>`_                     |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| COVD-HRF           |                    | `M2UNet_COVD-HRF.pth <https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-HRF.pth>`_                   |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-| COVD-HRF SSL       |                    | `M2UNet_COVD-HRF_SSL.pth <https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-HRF_SSL.pth>`_           |
-+--------------------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+
-
 
 
 To run evaluation of pretrained models pass url as ``-w`` argument. E.g.:
diff --git a/doc/links.rst b/doc/links.rst
index ab294a1e..ee753a18 100644
--- a/doc/links.rst
+++ b/doc/links.rst
@@ -26,25 +26,32 @@
 
 .. Pretrained models
 
-.. _driu_chasedb1.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/DRIU_CHASEDB1.pth
+.. DRIVE
 .. _driu_drive.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/DRIU_DRIVE.pth
-.. _driu_hrf.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/DRIU_HRF1168.pth
-.. _driu_stare.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/DRIU_STARE.pth
-.. _driu_iostar.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/DRIU_IOSTARVESSEL.pth
-
-.. _m2unet_chasedb1.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_CHASEDB1.pth
 .. _m2unet_drive.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_DRIVE.pth
-.. _m2unet_hrf.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_HRF1168.pth
-.. _m2unet_stare.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_STARE.pth
-.. _m2unet_iostar.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_IOSTARVESSEL.pth
 .. _m2unet_covd-drive.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-DRIVE.pth
 .. _m2unet_covd-drive_ssl.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-DRIVE_SSL.pth
+
+.. STARE
+.. _driu_stare.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/DRIU_STARE.pth
+.. _m2unet_stare.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_STARE.pth
 .. _m2unet_covd-stare.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-STARE.pth
 .. _m2unet_covd-stare_ssl.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-STARE_SSL.pth
-.. _m2unet_covd-chaesdb1.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-CHASEDB1.pth
-.. _m2unet_covd-chaesdb1_ssl.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-CHASEDB1_SSL.pth
-.. _m2unet_covd-hrf.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-HRF.pth
-.. _m2unet_covd-hrf_ssl.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-HRF_SSL.pth
+
+.. CHASE-DB1
+.. _driu_chasedb1.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/DRIU_CHASEDB1.pth
+.. _m2unet_chasedb1.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_CHASEDB1.pth
+.. _m2unet_covd-chasedb1.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-CHASEDB1.pth
+.. _m2unet_covd-chasedb1_ssl.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-CHASEDB1_SSL.pth
+
+.. IOSTAR
+.. _driu_iostar.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/DRIU_IOSTARVESSEL.pth
+.. _m2unet_iostar.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_IOSTARVESSEL.pth
 .. _m2unet_covd-iostar.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-IOSTAR.pth
 .. _m2unet_covd-iostar_ssl.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-IOSTAR_SSL.pth
 
+.. HRF
+.. _driu_hrf.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/DRIU_HRF1168.pth
+.. _m2unet_hrf.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_HRF1168.pth
+.. _m2unet_covd-hrf.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-HRF.pth
+.. _m2unet_covd-hrf_ssl.pth: https://www.idiap.ch/software/bob/data/bob/bob.ip.binseg/master/M2UNet_COVD-HRF_SSL.pth
diff --git a/doc/models.rst b/doc/models.rst
new file mode 100644
index 00000000..25e347f8
--- /dev/null
+++ b/doc/models.rst
@@ -0,0 +1,65 @@
+.. -*- coding: utf-8 -*-
+
+.. _bob.ip.binseg.models:
+
+===================
+ Pretrained Models
+===================
+
+We offer the following pre-trained models allowing inference and score
+reproduction of our results.  Due to storage limitations we only provide
+weights of a subset of all evaluated models.
+
+
+.. list-table::
+
+   * - **Datasets / Models**
+     - :py:mod:`driu <bob.ip.binseg.configs.models.driu>`
+     - :py:mod:`m2unet <bob.ip.binseg.configs.models.m2unet>`
+   * - :py:mod:`drive <bob.ip.binseg.configs.datasets.drive>`
+     - driu_drive.pth_
+     - m2unet_drive.pth_
+   * - :py:mod:`covd-drive <bob.ip.binseg.configs.datasets.starechasedb1iostarhrf544>`
+     -
+     - m2unet_covd-drive.pth_
+   * - :py:mod:`covd-drive-ssl <bob.ip.binseg.configs.datasets.starechasedb1iostarhrf544ssldrive>`
+     -
+     - m2unet_covd-drive_ssl.pth_
+   * - :py:mod:`stare <bob.ip.binseg.configs.datasets.stare>`
+     - driu_stare.pth_
+     - m2unet_stare.pth_
+   * - :py:mod:`covd-stare <bob.ip.binseg.configs.datasets.drivechasedb1iostarhrf608>`
+     -
+     - m2unet_covd-stare.pth_
+   * - :py:mod:`covd-stare-ssl <bob.ip.binseg.configs.datasets.drivechasedb1iostarhrf608sslstare>`
+     -
+     - m2unet_covd-stare_ssl.pth_
+   * - :py:mod:`chasedb1 <bob.ip.binseg.configs.datasets.chasedb1>`
+     - driu_chasedb1.pth_
+     - m2unet_chasedb1.pth_
+   * - :py:mod:`covd-chasedb1 <bob.ip.binseg.configs.datasets.drivestareiostarhrf960>`
+     -
+     - m2unet_covd-chasedb1.pth_
+   * - :py:mod:`covd-chasedb1-ssl <bob.ip.binseg.configs.datasets.drivestareiostarhrf960sslchase>`
+     -
+     - m2unet_covd-chasedb1_ssl.pth_
+   * - :py:mod:`iostar-vessel <bob.ip.binseg.configs.datasets.iostarvessel>`
+     - driu_iostar.pth_
+     - m2unet_iostar.pth_
+   * - :py:mod:`covd-iostar-vessel <bob.ip.binseg.configs.datasets.drivestarechasedb1hrf1024>`
+     -
+     - m2unet_covd-iostar.pth_
+   * - :py:mod:`covd-iostar-vessel-ssl <bob.ip.binseg.configs.datasets.drivestarechasedb1hrf1024ssliostar>`
+     -
+     - m2unet_covd-iostar_ssl.pth_
+   * - :py:mod:`hrf <bob.ip.binseg.configs.datasets.hrf1168>`
+     - driu_hrf.pth_
+     - m2unet_hrf.pth_
+   * - :py:mod:`covd-hrf <bob.ip.binseg.configs.datasets.drivestarechasedb1iostar1168>`
+     -
+     - m2unet_covd-hrf.pth_
+   * - :py:mod:`covd-hrf-ssl <bob.ip.binseg.configs.datasets.drivestarechasedb1iostar1168sslhrf>`
+     -
+     - m2unet_covd-hrf_ssl.pth_
+
+.. include:: links.rst
diff --git a/doc/training.rst b/doc/training.rst
index 8b0eaef4..db9daed6 100644
--- a/doc/training.rst
+++ b/doc/training.rst
@@ -130,11 +130,11 @@ card, for semi-supervised learning of COVD- systems.  Use it like this:
 .. list-table::
 
   * - **Models / Datasets**
-    - :py:mod:`covd-drive <bob.ip.binseg.configs.datasets.starechasedb1iostarhrf544ssldrive>`
-    - :py:mod:`covd-stare <bob.ip.binseg.configs.datasets.drivechasedb1iostarhrf608sslstare>`
-    - :py:mod:`covd-chasedb1 <bob.ip.binseg.configs.datasets.drivestareiostarhrf960sslchase>`
-    - :py:mod:`covd-iostar-vessel <bob.ip.binseg.configs.datasets.drivestarechasedb1hrf1024ssliostar>`
-    - :py:mod:`covd-hrf <bob.ip.binseg.configs.datasets.drivestarechasedb1iostar1168sslhrf>`
+    - :py:mod:`covd-drive-ssl <bob.ip.binseg.configs.datasets.starechasedb1iostarhrf544ssldrive>`
+    - :py:mod:`covd-stare-ssl <bob.ip.binseg.configs.datasets.drivechasedb1iostarhrf608sslstare>`
+    - :py:mod:`covd-chasedb1-ssl <bob.ip.binseg.configs.datasets.drivestareiostarhrf960sslchase>`
+    - :py:mod:`covd-iostar-vessel-ssl <bob.ip.binseg.configs.datasets.drivestarechasedb1hrf1024ssliostar>`
+    - :py:mod:`covd-hrf-ssl <bob.ip.binseg.configs.datasets.drivestarechasedb1iostar1168sslhrf>`
   * - :py:mod:`driu-ssl <bob.ip.binseg.configs.models.driussl>` / :py:mod:`driu-bn <bob.ip.binseg.configs.models.driubnssl>`
     - 4
     - 4
diff --git a/doc/usage.rst b/doc/usage.rst
index f5133af1..24c3fa1a 100644
--- a/doc/usage.rst
+++ b/doc/usage.rst
@@ -47,6 +47,7 @@ modifying one of our configuration resources.
    :maxdepth: 2
 
    training
+   models
    evaluation
    plotting
    visualization
diff --git a/setup.py b/setup.py
index bef1aec4..3fe6be37 100644
--- a/setup.py
+++ b/setup.py
@@ -35,10 +35,10 @@ setup(
             "evalpred = bob.ip.binseg.script.binseg:evalpred",
             "gridtable = bob.ip.binseg.script.binseg:testcheckpoints",
             "predict = bob.ip.binseg.script.binseg:predict",
-            "test = bob.ip.binseg.script.binseg:test",
             "visualize = bob.ip.binseg.script.binseg:visualize",
             "config = bob.ip.binseg.script.config:config",
             "train = bob.ip.binseg.script.train:train",
+            "evaluate = bob.ip.binseg.script.evaluate:evaluate",
         ],
         # bob train configurations
         "bob.ip.binseg.config": [
-- 
GitLab