#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

"""Compares two set of masks and prints some error metrics

This program requires that the masks for both databases (one containing the
ground-truth annotations and the other containing masks produced by an
automated method) are stored as HDF5 files with a ``mask`` object, which
must be boolean.


Usage: %(prog)s [-v...] [-n N] <ground-truth> <database>
       %(prog)s --help
       %(prog)s --version


Arguments:
  <ground-truth>  Path to a set of files that contain ground truth annotations
                  for the ROIs you wish to compare.
  <database>      Path to a similar set of files as in `<ground-truth>`, but
                  with ROIs calculated automatically. Every HDF5 file in this
                  directory will be matched to an equivalent file in the
                  `<ground-truth>` database and their masks will be compared.


Options:
  -h, --help          Shows this help message and exits
  -V, --version       Prints the version and exits
  -v, --verbose       Increases the output verbosity level
  -n N, --annotate=N  Print out the N worst cases in the database for each of
                      the metrics analyzed


Example:

  1. Just run for basic statistics:

     $ %(prog)s -vvv gt/ automatic/

  2. Identify the 5 worst samples in the database, for each analyzed metric:

     $ %(prog)s -vv -n 5 gt/ automatic/
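
Note:

  For reference, a compatible ``mask`` file can be written with the same
  ``bob.io.base`` API this program uses for reading (an illustrative sketch;
  here ``mask`` stands for any 2D boolean numpy array):

     import numpy
     import bob.io.base
     mask = numpy.zeros((10, 10), dtype='bool')  # replace with a real ROI
     f = bob.io.base.HDF5File('sample.hdf5', 'w')
     f.set('mask', mask)
     del f  # closes the file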

"""

import os
import sys
import fnmatch
import operator

import numpy

import bob.core
logger = bob.core.log.setup("bob.bio.vein")

import bob.io.base


def make_catalog(d):
  """Returns a catalog dictionary containing the file stems available in ``d``

  Parameters:

    d (str): A path representing a directory that will be scanned for .hdf5
      files


  Returns:

    list: A sorted list of stems, relative to ``d``, of the HDF5 files found
    in that directory. Each file is expected to contain two objects:
    ``image`` and ``mask``.
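
  Example (illustrative; the stems depend on your own directory layout):

    >>> make_catalog('gt/')  # doctest: +SKIP
    ['sub1/sample1.hdf5', 'sub1/sample2.hdf5']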

  """

  logger.info("Scanning directory `%s'..." % d)
  retval = []
  for path, dirs, files in os.walk(d):
    basedir = os.path.relpath(path, d)
    logger.debug("Scanning sub-directory `%s'..." % basedir)
    candidates = fnmatch.filter(files, '*.hdf5')
    if not candidates: continue
    logger.debug("Found %d files" % len(candidates))
    retval += [os.path.join(basedir, k) for k in candidates]
  logger.info("Found a total of %d files at `%s'" % (len(retval), d))
  return sorted(retval)


def sort_table(table, cols):
  """Sorts a table by multiple columns


  Parameters:

    table (:py:class:`list` of :py:class:`list`): A list of lists (or a
      tuple of tuples), in which each inner sequence represents a table row

    cols (list, tuple): Specifies the column numbers to sort by e.g. (1,0)
      would sort by column 1, then by column 0


  Returns:

    list: A list of lists, with the rows re-ordered by the given columns.
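
  Example (rows are ordered by column 0 first, with ties broken by column 1):

    >>> sort_table([[2, 'b'], [1, 'c'], [2, 'a']], (0, 1))
    [[1, 'c'], [2, 'a'], [2, 'b']]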

  """

  for col in reversed(cols):
    table = sorted(table, key=operator.itemgetter(col))
  return table


def mean_std_for_column(table, column):
  """Calculates the mean and standard deviation for the column in question


  Parameters:

    table (:py:class:`list` of :py:class:`list`): A list of lists (or a
      tuple of tuples), in which each inner sequence represents a table row

    column (int): The number of the column from where to extract the data
      for calculating the mean and the standard deviation.


  Returns:

    float: mean

    float: (unbiased) standard deviation
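
  Example (uses the unbiased estimator, ``ddof=1``):

    >>> mean_std_for_column([(0.5,), (0.7,), (0.9,)], 0)  # doctest: +SKIP
    (0.7, 0.2)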

  """

  z = numpy.array([k[column] for k in table])
  return z.mean(), z.std(ddof=1)


def main(user_input=None):

  if user_input is not None:
    argv = user_input
  else:
    argv = sys.argv[1:]

  import docopt
  import pkg_resources

  completions = dict(
      prog=os.path.basename(sys.argv[0]),
      version=pkg_resources.require('bob.bio.vein')[0].version
      )

  args = docopt.docopt(
      __doc__ % completions,
      argv=argv,
      version=completions['version'],
      )

  # Sets up logging
  verbosity = int(args['--verbose'])
  bob.core.log.set_verbosity_level(logger, verbosity)

  # Catalogs
  gt = make_catalog(args['<ground-truth>'])
  db = make_catalog(args['<database>'])

  if gt != db:
    raise RuntimeError("Ground-truth and database have different files!")

  # Calculate all metrics required
  from ..preprocessor import utils
  metrics = []
  for k in gt:
    logger.info("Evaluating metrics for `%s'..." % k)
    gt_file = os.path.join(args['<ground-truth>'], k)
    db_file = os.path.join(args['<database>'], k)
    gt_roi = bob.io.base.HDF5File(gt_file).read('mask')
    db_roi = bob.io.base.HDF5File(db_file).read('mask')
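    # each metrics row is (stem, jaccard, m1, m2): higher Jaccard index and
    # m1 values indicate better agreement with the ground-truth, while a
    # higher m2 value indicates worse agreement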
    metrics.append((
      k,
      utils.jaccard_index(gt_roi, db_roi),
      utils.intersect_ratio(gt_roi, db_roi),
      utils.intersect_ratio_of_complement(gt_roi, db_roi),
      ))

  # Print statistics
  names = (
      (1, 'Jaccard index'),
      (2, 'Intersection ratio (m1)'),
      (3, 'Intersection ratio of complement (m2)'),
      )
  print("Statistics:")
  for k, name in names:
    mean, std = mean_std_for_column(metrics, k)
    print(name + ': ' + '%.2e +- %.2e' % (mean, std))

  # Print worst cases, if the user asked for it
  if args['--annotate'] is not None:
    N = int(args['--annotate'])
    if N <= 0:
      raise docopt.DocoptExit("Argument to --annotate should be >0")

    print("Worst cases by metric:")
    for k, name in names:
      print(name + ':')

      if k in (1,2):
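        # Jaccard index and m1: higher is better, so the worst cases are the
        # N smallest values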
        worst = sort_table(metrics, (k,))[:N]
      else:
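        # m2: lower is better, so the worst cases are the N largest values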
        worst = reversed(sort_table(metrics, (k,))[-N:])

      for rank, row in enumerate(worst, 1):
        fname = os.path.join(args['<database>'], row[0])
        print('  %d. [%.2e] %s' % (rank, row[k], fname))
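

# Allows running this module directly (e.g. via ``python -m``; the relative
# import above prevents executing the file as a standalone script)
if __name__ == '__main__':
  main()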