From 0052b817eed3445d1cf5dce00abfda271f475133 Mon Sep 17 00:00:00 2001
From: Tim Laibacher <tim.laibacher@idiap.ch>
Date: Thu, 16 May 2019 13:31:59 +0200
Subject: [PATCH] Add metrics visualization. Improve documentation

---
 bob/ip/binseg/data/transforms.py    |  17 ++-
 bob/ip/binseg/engine/inferencer.py  |  11 +-
 bob/ip/binseg/engine/trainer.py     |   8 +-
 bob/ip/binseg/modeling/losses.py    |   8 +-
 bob/ip/binseg/script/binseg.py      |  41 ++++++-
 bob/ip/binseg/utils/checkpointer.py |   9 ++
 bob/ip/binseg/utils/plot.py         | 165 ++++++++++++++++++++++------
 doc/api.rst                         |  18 ++-
 doc/conf.py                         |   1 +
 doc/links.rst                       |   3 +-
 doc/references.rst                  |   3 +-
 setup.py                            |   1 +
 12 files changed, 223 insertions(+), 62 deletions(-)

diff --git a/bob/ip/binseg/data/transforms.py b/bob/ip/binseg/data/transforms.py
index 99e941b6..b81163fa 100644
--- a/bob/ip/binseg/data/transforms.py
+++ b/bob/ip/binseg/data/transforms.py
@@ -114,7 +114,7 @@ class Pad:
         return [VF.pad(img, self.padding, self.fill, padding_mode='constant') for img in args]
     
 class ToTensor:
-    """Converts PIL.Image to torch.tensor """
+    """Converts :py:class:`PIL.Image.Image` to :py:class:`torch.Tensor` """
     def __call__(self, *args):
         return [VF.to_tensor(img) for img in args]
 
@@ -191,16 +191,16 @@ class ColorJitter(object):
     ----------
     brightness : float 
         how much to jitter brightness. brightness_factor
-        is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
+        is chosen uniformly from ``[max(0, 1 - brightness), 1 + brightness]``.
     contrast : float
         how much to jitter contrast. contrast_factor
-        is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
+        is chosen uniformly from ``[max(0, 1 - contrast), 1 + contrast]``.
     saturation : float 
         how much to jitter saturation. saturation_factor
-        is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
+        is chosen uniformly from ``[max(0, 1 - saturation), 1 + saturation]``.
     hue : float 
         how much to jitter hue. hue_factor is chosen uniformly from
-        [-hue, hue]. Should be >=0 and <= 0.5
+        ``[-hue, hue]``. Should be >=0 and <= 0.5
     prob : float
         probability at which the operation is applied
     """
@@ -247,10 +247,9 @@ class ColorJitter(object):
 
 class RandomResizedCrop:
     """Crop to random size and aspect ratio.
-    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
-    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
-    is finally resized to given size.
-    This is popularly used to train the Inception networks.
+    A crop of random size of the original size and a random aspect ratio of 
+    the original aspect ratio is made. This crop is finally resized to 
+    given size. This is popularly used to train the Inception networks.
     
     Attributes
     ----------
diff --git a/bob/ip/binseg/engine/inferencer.py b/bob/ip/binseg/engine/inferencer.py
index c3d9c594..00f60d7c 100644
--- a/bob/ip/binseg/engine/inferencer.py
+++ b/bob/ip/binseg/engine/inferencer.py
@@ -31,13 +31,13 @@ def batch_metrics(predictions, ground_truths, names, output_folder, logger):
         list of file names 
     output_folder : str
         output path
-    logger : :py:class:logging
+    logger : :py:class:`logging.Logger`
         python logger
 
     Returns
     -------
     list 
-        list containing batch metrics (name, threshold, precision, recall, specificity, accuracy, jaccard, f1_score)
+        list containing batch metrics: ``[name, threshold, precision, recall, specificity, accuracy, jaccard, f1_score]``
     """
     step_size = 0.01
     batch_metrics = []
@@ -101,7 +101,7 @@ def save_probability_images(predictions, names, output_folder, logger):
         list of file names 
     output_folder : str
         output path
-    logger : :py:class:logging
+    logger : :py:class:`logging.Logger`
         python logger
     """
     images_subfolder = os.path.join(output_folder,'images') 
@@ -125,10 +125,9 @@ def do_inference(
     
     Parameters
     ---------
-    model : :py:class:torch.nn.Module
+    model : :py:class:`torch.nn.Module`
         neural network model (e.g. DRIU, HED, UNet)
-    data_loader : py:class:torch.torch.utils.data.DataLoader
-        PyTorch DataLoader
+    data_loader : py:class:`torch.torch.utils.data.DataLoader`
     device : str
         device to use ``'cpu'`` or ``'cuda'``
     output_folder : str
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 03746e4f..a260ab0d 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -33,12 +33,12 @@ def do_train(
     model : :py:class:`torch.nn.Module` 
         Network (e.g. DRIU, HED, UNet)
     data_loader : :py:class:`torch.utils.data.DataLoader`
-    optimizer : :py:class.`torch.optim.Optimizer`
-    criterion : :py:class.`torch.nn.modules.loss._Loss`
+    optimizer : :py:mod:`torch.optim`
+    criterion : :py:class:`torch.nn.modules.loss._Loss`
         loss function
-    scheduler : :py:class.`torch.optim._LRScheduler`
+    scheduler : :py:mod:`torch.optim`
         learning rate scheduler
-    checkpointer : :py:class.`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
+    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
         checkpointer
     checkpoint_period : int
         save a checkpoint every n epochs
diff --git a/bob/ip/binseg/modeling/losses.py b/bob/ip/binseg/modeling/losses.py
index 3e32b090..5eeb7950 100644
--- a/bob/ip/binseg/modeling/losses.py
+++ b/bob/ip/binseg/modeling/losses.py
@@ -7,7 +7,7 @@ from torch._jit_internal import weak_script_method
 
 class WeightedBCELogitsLoss(_Loss):
     """ 
-    Implements Equation 1 in [DRIU16]_. Based on torch.nn.modules.loss.BCEWithLogitsLoss. 
+    Implements Equation 1 in [DRIU16]_. Based on ``torch.nn.modules.loss.BCEWithLogitsLoss``. 
     Calculate sum of weighted cross entropy loss.
     """
     def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None):
@@ -44,7 +44,7 @@ class WeightedBCELogitsLoss(_Loss):
 
 class SoftJaccardBCELogitsLoss(_Loss):
     """ 
-    Implements Equation 6 in [SAT17]_. Based on torch.nn.modules.loss.BCEWithLogitsLoss. 
+    Implements Equation 6 in [SAT17]_. Based on ``torch.nn.modules.loss.BCEWithLogitsLoss``. 
 
     Attributes
     ----------
@@ -82,7 +82,7 @@ class SoftJaccardBCELogitsLoss(_Loss):
 
 class HEDWeightedBCELogitsLoss(_Loss):
     """ 
-    Implements Equation 2 in [HED15]_. Based on torch.nn.modules.loss.BCEWithLogitsLoss. 
+    Implements Equation 2 in [HED15]_. Based on ``torch.nn.modules.loss.BCEWithLogitsLoss``. 
     Calculate sum of weighted cross entropy loss.
     """
     def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None):
@@ -123,7 +123,7 @@ class HEDWeightedBCELogitsLoss(_Loss):
 
 class HEDSoftJaccardBCELogitsLoss(_Loss):
     """ 
-    Implements Equation 6 in [SAT17]_ for the hed network. Based on torch.nn.modules.loss.BCEWithLogitsLoss. 
+    Implements Equation 6 in [SAT17]_ for the hed network. Based on ``torch.nn.modules.loss.BCEWithLogitsLoss``. 
 
     Attributes
     ----------
diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py
index e821930d..e56aa2fe 100644
--- a/bob/ip/binseg/script/binseg.py
+++ b/bob/ip/binseg/script/binseg.py
@@ -29,6 +29,7 @@ from bob.ip.binseg.utils.plot import plot_overview
 from bob.ip.binseg.utils.click import OptionEatAll
 from bob.ip.binseg.utils.pdfcreator import create_pdf, get_paths
 from bob.ip.binseg.utils.rsttable import create_overview_grid
+from bob.ip.binseg.utils.plot import metricsviz, overlay
 
 logger = logging.getLogger(__name__)
 
@@ -342,6 +343,42 @@ def pdfoverview(output_path, **kwargs):
 @verbosity_option(cls=ResourceOption)
 def gridtable(output_path, **kwargs):
     """ Creates an overview table in grid rst format for all Metrics.csv in the output_path
-    tree structure: ``outputpath/DATABASE/MODEL`` """
+    tree structure: 
+        ├── DATABASE
+        ├── MODEL
+            ├── images
+            └── results
+    """
+    logger.info('Creating grid for all results in {}'.format(output_path))
     create_overview_grid(output_path)
-    
\ No newline at end of file
+
+
+# Create metrics viz
+@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
+@click.option(
+    '--dataset',
+    '-d',
+    required=True,
+    cls=ResourceOption
+    )
+@click.option(
+    '--output-path',
+    '-o',
+    required=True,
+    )
+@verbosity_option(cls=ResourceOption)
+def visualize(dataset, output_path, **kwargs):
+    """ Creates the following visualizations of the probabilties output maps:
+    overlayed: test images overlayed with prediction probabilities vessel tree
+    tpfnfpviz: highlights true positives, false negatives and false positives
+
+    Required tree structure: 
+    ├── DATABASE
+        ├── MODEL
+            ├── images
+            └── results
+    """
+    logger.info('Creating TP, FP, FN visualizations for {}'.format(output_path))
+    metricsviz(dataset=dataset, output_path=output_path)
+    logger.info('Creating overlay visualizations for {}'.format(output_path))
+    overlay(dataset=dataset, output_path=output_path)
\ No newline at end of file
diff --git a/bob/ip/binseg/utils/checkpointer.py b/bob/ip/binseg/utils/checkpointer.py
index 494d990d..4da05227 100644
--- a/bob/ip/binseg/utils/checkpointer.py
+++ b/bob/ip/binseg/utils/checkpointer.py
@@ -11,6 +11,15 @@ from bob.ip.binseg.utils.model_serialization import load_state_dict
 from bob.ip.binseg.utils.model_zoo import cache_url
 
 class Checkpointer:
+    """Adapted from [MASKRCNNBENCHMARK_18]_
+    
+    Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+    
+    Returns
+    -------
+    [type]
+        [description]
+    """
     def __init__(
         self,
         model,
diff --git a/bob/ip/binseg/utils/plot.py b/bob/ip/binseg/utils/plot.py
index 0f1bb882..bedb7625 100644
--- a/bob/ip/binseg/utils/plot.py
+++ b/bob/ip/binseg/utils/plot.py
@@ -1,40 +1,44 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-# author='Andre Anjos',
-# author_email='andre.anjos@idiap.ch',
-
 import numpy as np
 import os
 import csv 
+import pandas as pd
+import PIL
+from PIL import Image
+import torchvision.transforms.functional as VF
+import torch
 
-def precision_recall_f1iso(precision, recall, names, title=None, human_perf_bsds500=False):
-    '''Creates a precision-recall plot of the given data.   
+def precision_recall_f1iso(precision, recall, names, title=None):
+    """
+    Author: Andre Anjos (andre.anjos@idiap.ch).
+    
+    Creates a precision-recall plot of the given data.   
     The plot will be annotated with F1-score iso-lines (in which the F1-score
     maintains the same value)   
+    
     Parameters
     ----------  
-      precision : :py:class:`np.ndarray` or :py:class:`list`
+    precision : :py:class:`numpy.ndarray` or :py:class:`list`
         A list of 1D np arrays containing the Y coordinates of the plot, or
         the precision, or a 2D np array in which the rows correspond to each
         of the system's precision coordinates.  
-      recall : :py:class:`np.ndarray` or :py:class:`list`
+    recall : :py:class:`numpy.ndarray` or :py:class:`list`
         A list of 1D np arrays containing the X coordinates of the plot, or
         the recall, or a 2D np array in which the rows correspond to each
         of the system's recall coordinates. 
-      names : :py:class:`list`
+    names : :py:class:`list`
         An iterable over the names of each of the systems along the rows of
         ``precision`` and ``recall``      
-      title : :py:class:`str`, optional
+    title : :py:class:`str`, optional
         A title for the plot. If not set, omits the title   
-      human_perf_bsds500 : :py:class:`bool`, optional
-        Whether to display the human performance on the BSDS-500 dataset - it is
-        a fixed point on precision=0.897659 and recall=0.700762.    
+
     Returns
     ------- 
-      figure : matplotlib.figure.Figure
+    matplotlib.figure.Figure
         A matplotlib figure you can save or display 
-    ''' 
+    """ 
     import matplotlib
     matplotlib.use('agg')
     import matplotlib.pyplot as plt 
@@ -74,8 +78,6 @@ def precision_recall_f1iso(precision, recall, names, title=None, human_perf_bsds
         x = np.linspace(0.01, 1)
         y = f_score * x / (2 * x - f_score)
         l, = plt.plot(x[y >= 0], y[y >= 0], color='green', alpha=0.1)
-        if human_perf_bsds500:
-            plt.plot(0.700762, 0.897659, 'go', markersize=5, label='[F=0.800] Human')
         tick_locs.append(y[-1])
         tick_labels.append('%.1f' % f_score)  
     ax2.tick_params(axis='y', which='both', pad=0, right=False, left=False)
@@ -103,17 +105,18 @@ def precision_recall_f1iso(precision, recall, names, title=None, human_perf_bsds
 
 
 def loss_curve(df, title):
-    ''' Creates a loss curve
-    Dataframe with column names:
-    ["avg. loss", "median loss","lr","max memory"]
-    Arguments
-    ---------
-    df : :py:class.`pandas.DataFrame`
+    """ Creates a loss curve given a Dataframe with column names:
+
+    ``['avg. loss', 'median loss','lr','max memory']``
+    
+    Parameters
+    ----------
+    df : :py:class:`pandas.DataFrame`
     
     Returns
     -------
-    fig : matplotlib.figure.Figure
-    ''' 
+    matplotlib.figure.Figure
+    """   
     import matplotlib
     matplotlib.use('agg')
     import matplotlib.pyplot as plt 
@@ -134,14 +137,14 @@ def read_metricscsv(file):
     Read precision and recall from csv file
     
     Parameters
-    ---------
-    file: str
-           path to file
+    ----------
+    file : str
+        path to file
     
     Returns
     -------
-        precision : :py:class:`np.ndarray`
-        recall : :py:class:`np.ndarray`
+    :py:class:`numpy.ndarray`
+    :py:class:`numpy.ndarray`
     """
     with open (file, "r") as infile:
         metricsreader = csv.reader(infile)
@@ -158,13 +161,14 @@ def read_metricscsv(file):
 def plot_overview(outputfolders):
     """
     Plots comparison chart of all trained models
-    Arguments
-    ---------
+    
+    Parameters
+    ----------
     outputfolder : list
-                    list containing output paths of all evaluated models (e.g. ['DRIVE/model1', 'DRIVE/model2'])
+        list containing output paths of all evaluated models (e.g. ``['DRIVE/model1', 'DRIVE/model2']``)
     Returns
     -------
-    fig : matplotlib.figure.Figure
+    matplotlib.figure.Figure
     """
     precisions = []
     recalls = []
@@ -189,4 +193,97 @@ def plot_overview(outputfolders):
     fig = precision_recall_f1iso(precisions,recalls,names,title)
     return fig
 
-  
+def metricsviz(dataset
+                ,output_path
+                ,tp_color= (128,128,128)
+                ,fp_color = (70, 240, 240)
+                ,fn_color = (245, 130, 48)
+                ):
+    """ Visualizes true positives, false positives and false negatives
+    Default colors TP: Gray, FP: Cyan, FN: Orange
+    
+    Parameters
+    ----------
+    dataset : :py:class:`torch.utils.data.Dataset`
+    output_path : str
+        path where results and probability output images are stored. E.g. ``'DRIVE/MODEL'``
+    tp_color : tuple
+        RGB values, by default (128,128,128)
+    fp_color : tuple
+        RGB values, by default (70, 240, 240)
+    fn_color : tuple
+        RGB values, by default (245, 130, 48)
+    """
+
+    for sample in dataset:
+        # get sample
+        name  = sample[0]
+        img = VF.to_pil_image(sample[1]) # PIL Image
+        gt = sample[2].byte() # byte tensor
+        
+        # read metrics 
+        metrics = pd.read_csv(os.path.join(output_path,'results','Metrics.csv'))
+        optimal_threshold = metrics['threshold'][metrics['f1_score'].idxmax()]
+        
+        # read probability output 
+        pred = Image.open(os.path.join(output_path,'images',name))
+        pred = VF.to_tensor(pred)
+        binary_pred = torch.gt(pred, optimal_threshold).byte()
+        
+        # calc metrics
+        # equals and not-equals
+        equals = torch.eq(binary_pred, gt) # tensor
+        notequals = torch.ne(binary_pred, gt) # tensor      
+        # true positives 
+        tp_tensor = (gt * binary_pred ) # tensor
+        tp_pil = VF.to_pil_image(tp_tensor.float())
+        tp_pil_colored = PIL.ImageOps.colorize(tp_pil, (0,0,0), tp_color)
+        # false positives 
+        fp_tensor = torch.eq((binary_pred + tp_tensor), 1) 
+        fp_pil = VF.to_pil_image(fp_tensor.float())
+        fp_pil_colored = PIL.ImageOps.colorize(fp_pil, (0,0,0), fp_color)
+        # false negatives
+        fn_tensor = notequals - fp_tensor
+        fn_pil = VF.to_pil_image(fn_tensor.float())
+        fn_pil_colored = PIL.ImageOps.colorize(fn_pil, (0,0,0), fn_color)
+
+        # paste together
+        tp_pil_colored.paste(fp_pil_colored,mask=fp_pil)
+        tp_pil_colored.paste(fn_pil_colored,mask=fn_pil)
+
+        # save to disk 
+        overlayed_path = os.path.join(output_path,'tpfnfpviz')
+        if not os.path.exists(overlayed_path): os.makedirs(overlayed_path)
+        tp_pil_colored.save(os.path.join(overlayed_path,name))
+
+
+def overlay(dataset, output_path):
+    """Overlays prediction probabilities vessel tree with original test image.
+    
+    Parameters
+    ----------
+    dataset : :py:class:`torch.utils.data.Dataset`
+    output_path : str
+        path where results and probability output images are stored. E.g. ``'DRIVE/MODEL'``
+    """
+
+    for sample in dataset:
+        # get sample
+        name  = sample[0]
+        img = VF.to_pil_image(sample[1]) # PIL Image
+        gt = sample[2].byte() # byte tensor
+        
+        # read metrics 
+        metrics = pd.read_csv(os.path.join(output_path,'results','Metrics.csv'))
+        optimal_threshold = metrics['threshold'][metrics['f1_score'].idxmax()]
+        
+        # read probability output 
+        pred = Image.open(os.path.join(output_path,'images',name))
+        # color and overlay
+        pred_green = PIL.ImageOps.colorize(pred, (0,0,0), (0,255,0))
+        overlayed = PIL.Image.blend(img, pred_green, 0.4)
+
+        # save to disk
+        overlayed_path = os.path.join(output_path,'overlayed')
+        if not os.path.exists(overlayed_path): os.makedirs(overlayed_path)
+        overlayed.save(os.path.join(overlayed_path,name))
\ No newline at end of file
diff --git a/doc/api.rst b/doc/api.rst
index 103d9622..e2b6182d 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -16,7 +16,8 @@ PyTorch Dataset
 Transforms
 ----------
 .. note:: 
-    All transforms work with PIL.Image.Image objects.
+    All transforms work with :py:class:`PIL.Image.Image` objects. We make heavy use of the
+    `torchvision package`_
 
 .. automodule:: bob.ip.binseg.data.transforms
 
@@ -24,5 +25,20 @@ Losses
 ------
 .. automodule:: bob.ip.binseg.modeling.losses
 
+Training
+--------
+.. automodule:: bob.ip.binseg.engine.trainer
+
+Checkpointer
+------------
+.. automodule:: bob.ip.binseg.utils.checkpointer
+
+Inference and Evaluation
+------------------------
+.. automodule:: bob.ip.binseg.engine.inferencer
+
+Plotting
+--------
+.. automodule:: bob.ip.binseg.utils.plot
 
 .. include:: links.rst
diff --git a/doc/conf.py b/doc/conf.py
index 4c2c5a75..d64d7794 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -235,6 +235,7 @@ else:
 
 intersphinx_mapping['torch'] = ('https://pytorch.org/docs/stable/', None)
 intersphinx_mapping['PIL'] = ('http://pillow.readthedocs.io/en/stable', None)
+intersphinx_mapping['pandas'] = ('https://pandas.pydata.org/pandas-docs/stable/',None)
 # We want to remove all private (i.e. _. or __.__) members
 # that are not in the list of accepted functions
 accepted_private_functions = ['__array__']
diff --git a/doc/links.rst b/doc/links.rst
index d22f71b4..d8a696c4 100644
--- a/doc/links.rst
+++ b/doc/links.rst
@@ -5,4 +5,5 @@
 .. _idiap: http://www.idiap.ch
 .. _bob: http://www.idiap.ch/software/bob
 .. _installation: https://www.idiap.ch/software/bob/install
-.. _mailing list: https://www.idiap.ch/software/bob/discuss
\ No newline at end of file
+.. _mailing list: https://www.idiap.ch/software/bob/discuss
+.. _torchvision package: https://github.com/pytorch/vision
\ No newline at end of file
diff --git a/doc/references.rst b/doc/references.rst
index 55500588..9422bbc1 100644
--- a/doc/references.rst
+++ b/doc/references.rst
@@ -6,4 +6,5 @@ References
 
 .. [HED15] *Saining Xie and Zhuowen Tu*, **Holistically-Nested Edge Detection**, in: Proceedings of IEEE International Conference on Computer Vision, 2015
 .. [SAT17] *Alexey Shvets, Vladimir Iglovikov, Alexander Rakhlin and Alexandr A. Kalinin** , in:  17th IEEE International Conference on Machine Learning and Applications (ICMLA), 2017
-.. [DRIU16] *K.K. Maninis, J. Pont-Tuset, P. Arbeláez, and L. Van Gool**, in: Medical Image Computing and Computer-Assisted Intervention (MICCAI), 2016
\ No newline at end of file
+.. [DRIU16] *K.K. Maninis, J. Pont-Tuset, P. Arbeláez, and L. Van Gool**, in: Medical Image Computing and Computer-Assisted Intervention (MICCAI), 2016
+.. [MASKRCNNBENCHMARK_18] **Francisco Massa and Ross Girshick**, in https://github.com/facebookresearch/maskrcnn-benchmark
\ No newline at end of file
diff --git a/setup.py b/setup.py
index ac1795b8..2b70354f 100644
--- a/setup.py
+++ b/setup.py
@@ -52,6 +52,7 @@ setup(
           'testcheckpoints = bob.ip.binseg.script.binseg:testcheckpoints',
           'pdfoverview = bob.ip.binseg.script.binseg:testcheckpoints',
           'gridtable = bob.ip.binseg.script.binseg:testcheckpoints',
+          'visualize = bob.ip.binseg.script.binseg:visualize',
         ],
 
          #bob hed train configurations
-- 
GitLab