Commit 56b8b301 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Added some documentation

parent 3604b8d7
......@@ -124,7 +124,9 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
query = query.filter(bob.db.cuhk_cufs.Protocol_File_Association.protocol.in_(protocols))
query = query.filter(bob.db.cuhk_cufs.Protocol_File_Association.purpose.in_(purposes))
if model_ids is not None:
if model_ids is not None and not 'probe' in purposes:
if type(model_ids) is not list and type(model_ids) is not tuple:
model_ids = [model_ids]
......@@ -134,7 +136,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
for m in model_ids:
model_aux.append(m.id)
model_ids = model_aux
query = query.filter(bob.db.cuhk_cufs.Client.id.in_(model_ids))
raw_files = query.all()
......@@ -148,7 +150,7 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
return files
def model_ids(self, protocol=None, groups=None):
def clients(self, protocol=None, groups=None):
#Checking inputs
groups = self.check_parameters_for_validity(groups, "group", GROUPS)
......@@ -165,7 +167,11 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
query = query.filter(bob.db.cuhk_cufs.Protocol_File_Association.group.in_(groups))
query = query.filter(bob.db.cuhk_cufs.Protocol_File_Association.protocol.in_(protocols))
return [c.id for c in query.all()]
return query.all()
def model_ids(self, protocol=None, groups=None):
return [c.id for c in self.clients(protocol=protocol, groups=groups)]
def groups(self, protocol = None, **kwargs):
......@@ -173,19 +179,54 @@ class Database(bob.db.verification.utils.SQLiteDatabase, bob.db.verification.uti
return GROUPS
####### score normalization methods
def tmodel_ids(self, groups = None, protocol = None, **kwargs):
"""This function returns the ids of the T-Norm models of the given groups for the given protocol."""
def zclients(self, protocol=None):
"""Returns a set of Z-Norm clients for the specific query by the user."""
return self.clients(protocol=protocol, groups="world")
def tclients(self, protocol=None):
"""Returns a set of T-Norm clients for the specific query by the user."""
return self.zclients(protocol=protocol)
def zobjects(self, protocol=None, groups=None):
"""Returns a set of Z-Norm objects for the specific query by the user."""
#Checking inputs
protocols = self.check_parameters_for_validity(protocol, "protocol", PROTOCOLS)
#You need to select only one protocol
if (len(protocols) > 1):
raise ValueError("Please, select only one of the following protocols {0}".format(protocols))
#Querying
query = self.query(bob.db.cuhk_cufs.File).join(bob.db.cuhk_cufs.Protocol_File_Association)
#filtering
query = query.filter(bob.db.cuhk_cufs.Protocol_File_Association.protocol.in_(protocols))
query = query.filter(bob.db.cuhk_cufs.Protocol_File_Association.group == "world")
###### THE MOST IMPORTANT THING IN THE METHOD
### IF THE PROTOCOL IS PHOTO --> SKETCH, THE T-OBJECTS ARE PHOTOS
### IF THE PROTOCOL IS SKETCH --> PHOTO, THE T-OBJECTS ARE SKETCHES
if "p2s" in protocol:
query = query.filter(bob.db.cuhk_cufs.File.modality == "photo")
else:
query = query.filter(bob.db.cuhk_cufs.File.modality == "sketch")
return []
return query.all()
def tobjects(self, protocol=None, model_ids=None, groups=None):
#No TObjects
return []
"""Returns a set of T-Norm objects for the specific query by the user."""
return self.zobjects(protocol=protocol)
def tmodel_ids(self, groups = None, protocol = None, **kwargs):
"""This function returns the ids of the T-Norm models of the given groups for the given protocol."""
return ["t_"+str(c.id) for c in self.tclients(protocol=protocol)]
def zobjects(self, protocol=None, groups=None):
#No TObjects
return []
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# Thu 12 Nov 2015 16:35:08 CET
#
# Copyright (C) 2011-2013 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the ipyplotied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import print_function
"""This script evaluates the given score files and computes EER, HTER.
It also is able to plot CMC and ROC curves."""
import bob.measure
import argparse
import numpy, math
import os
# matplotlib stuff
import matplotlib; matplotlib.use('pdf') #avoids TkInter threaded start
from matplotlib import pyplot
from matplotlib.backends.backend_pdf import PdfPages
# enable LaTeX interpreter
matplotlib.rc('text', usetex=True)
matplotlib.rc('font', family='serif')
matplotlib.rc('lines', linewidth = 4)
# increase the default font size
matplotlib.rc('font', size=18)
import bob.core
logger = bob.core.log.setup("bob.bio.base")
def command_line_arguments(command_line_parameters):
"""Parse the program options"""
# set up command line parser
parser = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-d', '--dev-files', required=True, nargs='+', help = "A list of score files of the development set.")
parser.add_argument('-s', '--directory', default = '.', help = "A directory, where to find the --dev-files and the --eval-files")
parser.add_argument('-c', '--criterion', choices = ('EER', 'HTER'), help = "If given, the threshold of the development set will be computed with this criterion.")
parser.add_argument('--cost', default=0.99, help='Cost for FAR in minDCF')
parser.add_argument('-r', '--rr', action = 'store_true', help = "If given, the Recognition Rate will be computed.")
parser.add_argument('-l', '--legends', nargs='+', help = "A list of legend strings used for ROC, CMC and DET plots; THE NUMBER OF PLOTS SHOULD BE MULTIPLE OF THE NUMBER OF LEGGENDS. IN THAT WAY, EACH SEGMENT WILL BE AVERAGED")
parser.add_argument('-F', '--legend-font-size', type=int, default=18, help = "Set the font size of the legends.")
parser.add_argument('-P', '--legend-position', type=int, help = "Set the font size of the legends.")
parser.add_argument('-R', '--roc', help = "If given, ROC curves will be plotted into the given pdf file.")
parser.add_argument('-D', '--det', help = "If given, DET curves will be plotted into the given pdf file.")
parser.add_argument('-C', '--cmc', help = "If given, CMC curves will be plotted into the given pdf file.")
parser.add_argument('--parser', default = '4column', choices = ('4column', '5column'), help="The style of the resulting score files. The default fits to the usual output of score files.")
# add verbose option
bob.core.log.add_command_line_option(parser)
# parse arguments
args = parser.parse_args(command_line_parameters)
# set verbosity level
bob.core.log.set_verbosity_level(logger, args.verbose)
# some sanity checks:
# update legends when they are not specified on command line
if args.legends is None:
args.legends = [f.replace('_', '-') for f in args.dev_files]
logger.warn("Legends are not specified; using legends estimated from --dev-files: %s", args.legends)
# check that the legends have the same length as the dev-files
if (len(args.dev_files) % len(args.legends)) != 0:
logger.error("The number of --dev-files (%d) is not multiple of --legends (%d) ", len(args.dev_files), len(args.legends))
return args
def _plot_roc(scores_input, colors, labels, title, fontsize=18, position=None):
if position is None: position = 4
figure = pyplot.figure()
logger.info("Computing CAR curves on the development " )
fars = [math.pow(10., i * 0.25) for i in range(-16,0)] + [1.]
frrs = [bob.measure.roc_for_far(scores[0], scores[1], fars) for scores in scores_input]
offset = 0
step = int(len(scores_input)/len(labels))
#For each group of labels
for i in range(len(labels)):
frrs_accumulator = numpy.zeros((step,frrs[0][0].shape[0]))
fars_accumulator = numpy.zeros((step,frrs[0][1].shape[0]))
for j in range(offset,offset+step):
frrs_accumulator[j-i*step,:] = frrs[j][0]
fars_accumulator[j-i*step,:] = frrs[j][1]
frr_average = numpy.mean(frrs_accumulator, axis=0)
far_average = numpy.mean(fars_accumulator, axis=0); far_std = numpy.std(fars_accumulator, axis=0)
pyplot.semilogx(frr_average*100, 100. - 100.0*far_average, color=colors[i], lw=2, ms=10, mew=1.5, label=labels[i])
pyplot.errorbar(frr_average*100, 100. - 100.0*far_average, far_std*100, lw=0.5, ms=10)
offset += step
# plot FAR and CAR for each algorithm
#for i in range(len(frrs)):
#pyplot.semilogx([100.0*f for f in frrs[i][0]], [100. - 100.0*f for f in frrs[i][1]], color=colors[i+1], lw=0.5, ls='--', ms=10, mew=1.5, label=str(i))
# finalize plot
pyplot.plot([0.1,0.1],[0,100], "--", color=(0.3,0.3,0.3))
pyplot.axis([frrs[0][0][0]*100,100,0,100])
pyplot.xticks((0.01, 0.1, 1, 10, 100), ('0.01', '0.1', '1', '10', '100'))
pyplot.xlabel('FAR (\%)')
pyplot.ylabel('CAR (\%)')
pyplot.grid(True, color=(0.6,0.6,0.6))
pyplot.legend(loc=position, prop = {'size':fontsize})
pyplot.title(title)
return figure
def _plot_det(scores_input, colors, labels, title, fontsize=18, position=None):
if position is None: position = 1
# open new page for current plot
figure = pyplot.figure(figsize=(8.2,8))
dets = [bob.measure.det(scores[0], scores[1], 1000) for scores in scores_input]
offset = 0
step = int(len(scores_input)/len(labels))
#For each group of labels
for i in range(len(labels)):
frrs_accumulator = numpy.zeros((step,dets[0][0].shape[0]))
fars_accumulator = numpy.zeros((step,dets[0][1].shape[0]))
for j in range(offset,offset+step):
frrs_accumulator[j,:] = dets[j][0]
fars_accumulator[j,:] = dets[j][1]
frr_average = numpy.mean(frrs_accumulator, axis=0)
far_average = numpy.mean(fars_accumulator, axis=0); far_std = numpy.std(fars_accumulator, axis=0)
pyplot.plot(frr_average, far_average, color=colors[i], lw=2, ms=10, mew=1.5, label=labels[i])
pyplot.errorbar(frr_average, far_average, far_std, lw=0.5, ms=10)
offset += step
# plot the DET curves
#for i in range(len(dets)):
#pyplot.plot(dets[i][0], dets[i][1], color=colors[i], lw=0.5, ls="--", ms=10, mew=1.5, label=str(i))
# change axes accordingly
det_list = [0.0002, 0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 0.7, 0.9, 0.95]
ticks = [bob.measure.ppndf(d) for d in det_list]
labels = [("%.5f" % (d*100)).rstrip('0').rstrip('.') for d in det_list]
pyplot.xticks(ticks, labels)
pyplot.yticks(ticks, labels)
pyplot.axis((ticks[0], ticks[-1], ticks[0], ticks[-1]))
pyplot.xlabel('FAR (\%)')
pyplot.ylabel('FRR (\%)')
pyplot.legend(loc=position, prop = {'size':fontsize})
pyplot.title(title)
return figure
def _plot_cmc(cmcs, colors, labels, title, fontsize=18, position=None):
if position is None: position = 4
# open new page for current plot
figure = pyplot.figure()
offset = 0
step = int(len(cmcs)/len(labels))
#For each group of labels
max_x = 0 #Maximum CMC size
for i in range(len(labels)):
#Computing the CMCs
cmc_curves = []
for j in range(offset,offset+step):
cmc_curves.append(bob.measure.cmc(cmcs[j]))
max_x = max(len(cmc_curves[j-offset]), max_x)
#Adding the padding with '1's
cmc_accumulator = numpy.zeros(shape=(step,max_x), dtype='float')
for j in range(step):
padding_diff = max_x-len(cmc_curves[j])
cmc_accumulator[j,:] = numpy.pad(cmc_curves[j],(0,padding_diff), 'constant',constant_values=(1))
#cmc_average += numpy.pad(cmc_curves[j],(0,padding_diff), 'constant',constant_values=(1))
cmc_std = numpy.std(cmc_accumulator, axis=0); cmc_std[-1]
cmc_average = numpy.mean(cmc_accumulator, axis=0)
pyplot.semilogx(range(1, cmc_average.shape[0]+1), cmc_average * 100, lw=2, ms=10, mew=1.5, label=labels[i])
pyplot.errorbar(range(1, cmc_average.shape[0]+1), cmc_average*100, cmc_std*100, lw=0.5, ms=10)
offset += step
# change axes accordingly
ticks = [int(t) for t in pyplot.xticks()[0]]
pyplot.xlabel('Rank')
pyplot.ylabel('Probability (\%)')
pyplot.xticks(ticks, [str(t) for t in ticks])
pyplot.axis([0, max_x, 90, 100])
pyplot.legend(loc=position, prop = {'size':fontsize})
pyplot.title(title)
pyplot.grid(True)
return figure
def main(command_line_parameters=None):
"""Reads score files, computes error measures and plots curves."""
args = command_line_arguments(command_line_parameters)
# get some colors for plotting
cmap = pyplot.cm.get_cmap(name='hsv')
colors = [cmap(i) for i in numpy.linspace(0, 1.0, len(args.dev_files)+1)]
if args.criterion or args.roc or args.det:
score_parser = {'4column' : bob.measure.load.split_four_column, '5column' : bob.measure.load.split_five_column}[args.parser]
# First, read the score files
logger.info("Loading %d score files of the development set", len(args.dev_files))
scores_dev = [score_parser(os.path.join(args.directory, f)) for f in args.dev_files]
if args.criterion:
logger.info("Computing %s on the development " % args.criterion )
for i in range(len(scores_dev)):
# compute threshold on development set
threshold = {'EER': bob.measure.eer_threshold, 'HTER' : bob.measure.min_hter_threshold} [args.criterion](scores_dev[i][0], scores_dev[i][1])
# apply threshold to development set
far, frr = bob.measure.farfrr(scores_dev[i][0], scores_dev[i][1], threshold)
print("The %s of the development set of '%s' is %2.3f%%" % (args.criterion, args.legends[i], (far + frr) * 50.)) # / 2 * 100%
if args.roc:
logger.info("Plotting ROC curves to file '%s'", args.roc)
try:
# create a multi-page PDF for the ROC curve
pdf = PdfPages(args.roc)
# create a separate figure for dev and eval
pdf.savefig(_plot_roc(scores_dev, colors, args.legends, "CUHK-CUFS ROC Curve between 5 splits", args.legend_font_size, args.legend_position))
#del frrs_dev
pdf.close()
except RuntimeError as e:
raise RuntimeError("During plotting of ROC curves, the following exception occured:\n%s\nUsually this happens when the label contains characters that LaTeX cannot parse." % e)
if args.det:
logger.info("Computing DET curves on the development ")
#dets_dev = [bob.measure.det(scores[0], scores[1], 1000) for scores in scores_dev]
logger.info("Plotting DET curves to file '%s'", args.det)
try:
# create a multi-page PDF for the ROC curve
pdf = PdfPages(args.det)
# create a separate figure for dev and eval
pdf.savefig(_plot_det(scores_dev, colors, args.legends, "CUHK-CUFS DET between 5 splits", args.legend_font_size, args.legend_position))
#del dets_dev
pdf.close()
except RuntimeError as e:
raise RuntimeError("During plotting of ROC curves, the following exception occured:\n%s\nUsually this happens when the label contains characters that LaTeX cannot parse." % e)
if args.cmc or args.rr:
logger.info("Loading CMC data on the development ")
cmc_parser = {'4column' : bob.measure.load.cmc_four_column, '5column' : bob.measure.load.cmc_five_column}[args.parser]
cmcs_dev = [cmc_parser(os.path.join(args.directory, f)) for f in args.dev_files]
if args.cmc:
logger.info("Plotting CMC curves to file '%s'", args.cmc)
try:
# create a multi-page PDF for the ROC curve
pdf = PdfPages(args.cmc)
# create a separate figure for dev and eval
pdf.savefig(_plot_cmc(cmcs_dev, colors, args.legends, "CUHK-CUFS CMC between 5 splits", args.legend_font_size, args.legend_position))
pdf.close()
except RuntimeError as e:
raise RuntimeError("During plotting of ROC curves, the following exception occured:\n%s\nUsually this happens when the label contains characters that LaTeX cannot parse." % e)
if args.rr:
logger.info("Computing recognition rate on the development ")
for i in range(len(cmcs_dev)):
rr = bob.measure.recognition_rate(cmcs_dev[i])
print("The Recognition Rate of the development set of '%s' is %2.3f%%" % (args.legends[i], rr * 100.))
......@@ -74,9 +74,13 @@ def test02_search_files_protocols():
assert len(bob.db.cuhk_cufs.Database().objects(protocol=p, groups="dev")) == dev
assert len(bob.db.cuhk_cufs.Database().objects(protocol=p, groups="dev", purposes="enroll")) == dev_enroll
assert len(bob.db.cuhk_cufs.Database().objects(protocol=p, groups="dev", purposes="probe")) == dev_probe
assert len(bob.db.cuhk_cufs.Database().objects(protocol=p, groups="eval")) == 0
p = "search_split1_p2s"
assert len(bob.db.cuhk_cufs.Database().objects(protocol=p, groups="dev", purposes="enroll", model_ids=[5])) == 1
assert len(bob.db.cuhk_cufs.Database().objects(protocol=p, groups="dev", purposes="probe", model_ids=[5])) == dev_probe
def test03_verification_arface_protocols():
......@@ -169,7 +173,32 @@ def test05_verification_cuhk_protocols():
def test06_strings():
def test06_search_clients_protocols():
world = 404
dev = 202
protocols = bob.db.cuhk_cufs.Database().protocols()
for p in protocols:
if "search" in p:
assert len(bob.db.cuhk_cufs.Database().model_ids(protocol=p, groups="world")) == world
assert len(bob.db.cuhk_cufs.Database().model_ids(protocol=p, groups="dev")) == dev
def test07_search_tobjects():
world = 404
protocols = bob.db.cuhk_cufs.Database().protocols()
for p in protocols:
if "search" in p:
assert len(bob.db.cuhk_cufs.Database().tobjects(protocol=p)) == world
assert len(bob.db.cuhk_cufs.Database().tclients(protocol=p)) == world
assert len(bob.db.cuhk_cufs.Database().tmodel_ids(protocol=p)) == world
def test08_strings():
db = bob.db.cuhk_cufs.Database()
......@@ -185,7 +214,7 @@ def test06_strings():
assert f.group == g
def test07_annotations():
def test09_annotations():
db = bob.db.cuhk_cufs.Database()
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Andre Anjos <andre.anjos@idiap.ch>
# Mon 13 Aug 2012 12:38:15 CEST
#
# Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
import os
import sys
import glob
import pkg_resources
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#sys.path.insert(0, os.path.abspath('.'))
# -- General configuration -----------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be extensions
# coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = [
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.pngmath',
'sphinx.ext.ifconfig',
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
]
# The viewcode extension appeared only on Sphinx >= 1.0.0
import sphinx
if sphinx.__version__ >= "1.0":
extensions.append('sphinx.ext.viewcode')
# Always includes todos
todo_include_todos = True
# If we are on OSX, the 'dvipng' path maybe different
dvipng_osx = '/opt/local/libexec/texlive/binaries/dvipng'
if os.path.exists(dvipng_osx): pngmath_dvipng = dvipng_osx
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix of source filenames.
source_suffix = '.rst'
# The encoding of source files.
#source_encoding = 'utf-8-sig'
# The master toctree document.
master_doc = 'index'
# General information about the project.
project = u'CUHK Face Sketch FERET Database (CUFS) (Bob API)'
import time
copyright = u'%s, Idiap Research Institute' % time.strftime('%Y')
# Grab the setup entry
distribution = pkg_resources.require('bob.db.cuhk_cufs')[0]
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = distribution.version
# The full version, including alpha/beta/rc tags.
release = distribution.version
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#language = None
# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
#today = ''
# Else, today_fmt is used as the format for a strftime call.
#today_fmt = '%B %d, %Y'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ['**/links.rst']
# The reST default role (used for this markup: `text`) to use for all documents.
#default_role = None
# If true, '()' will be appended to :func: etc. cross-reference text.
#add_function_parentheses = True
# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
#add_module_names = True
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
#show_authors = False
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# A list of ignored prefixes for module index sorting.
#modindex_common_prefix = []
# -- Options for HTML output ---------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
if sphinx.__version__ >= "1.0":
html_theme = 'nature'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#html_theme_options = {}
# Add any paths that contain custom themes here, relative to this directory.
#html_theme_path = []
# The name for this set of Sphinx documents. If None, it defaults to
# "<project> v<release> documentation".
#html_title = None
# A shorter title for the navigation bar. Default is the same as html_title.