From eb0079a5eeea2dc0086415fc44d1b1c5bfc5f615 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.anjos@idiap.ch> Date: Fri, 1 Sep 2017 10:39:06 +0200 Subject: [PATCH] Add mlp trainer for watershed masker --- bob/bio/vein/script/markdet.py | 291 ++++++++++++++++++++++++++++++++ bob/bio/vein/script/validate.py | 172 +++++++++++++++++++ 2 files changed, 463 insertions(+) create mode 100644 bob/bio/vein/script/markdet.py create mode 100644 bob/bio/vein/script/validate.py diff --git a/bob/bio/vein/script/markdet.py b/bob/bio/vein/script/markdet.py new file mode 100644 index 0000000..a59b48c --- /dev/null +++ b/bob/bio/vein/script/markdet.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : + + +"""Trains a new MLP to perform pre-watershed marker detection + +Usage: %(prog)s [-v...] [--samples=N] [--model=PATH] [--points=N] [--hidden=N] + [--batch=N] [--iterations=N] <database> <protocol> <group> + %(prog)s --help + %(prog)s --version + + +Arguments: + + <database> Name of the database to use for creating the model (options are: + "fv3d") + <protocol> Name of the protocol to use for creating the model (options + depend on the database chosen) + <group> Name of the group to use on the database/protocol with the + samples to use for training the model (options are: "train", + "dev" or "eval") + +Options: + + -h, --help Shows this help message and exits + -V, --version Prints the version and exits + -v, --verbose Increases the output verbosity level. Using "-vv" + allows the program to output informational messages as + it goes along. + -m PATH, --model=PATH Path to the generated model file [default: model.hdf5] + -s N, --samples=N Maximum number of samples to use for training. If not + set, use all samples + -p N, --points=N Maximum number of samples to use for plotting + ground-truth and classification errors. The more + points, the less responsive the plot becomes + [default: 1000] + -H N, --hidden=N Number of neurons on the hidden layer of the + multi-layer perceptron [default: 5] + -b N, --batch=N Number of samples to use for every batch [default: 1] + -i N, --iterations=N Number of iterations to train the neural net for + [default: 2000] + + +Examples: + + Trains on the 3D Fingervein database: + + $ %(prog)s -vv fv3d central dev + + Saves the model to a different file, use only 100 samples: + + $ %(prog)s -vv -s 100 --model=/path/to/saved-model.hdf5 fv3d central dev + +""" + + +import os +import sys +import schema +import docopt +import numpy +import skimage + + +def validate(args): + '''Validates command-line arguments, returns parsed values + + This function uses :py:mod:`schema` for validating :py:mod:`docopt` + arguments. Logging level is not checked by this procedure (actually, it is + ignored) and must be previously setup as some of the elements here may use + logging for outputing information. + + + Parameters: + + args (dict): Dictionary of arguments as defined by the help message and + returned by :py:mod:`docopt` + + + Returns + + dict: Validate dictionary with the same keys as the input and with values + possibly transformed by the validation procedure + + + Raises: + + schema.SchemaError: in case one of the checked options does not validate. + + ''' + + from .validate import check_model_does_not_exist + + sch = schema.Schema({ + '--model': check_model_does_not_exist, + '--samples': schema.Or(schema.Use(int), None), + '--points': schema.Use(int), + '--hidden': schema.Use(int), + '--batch': schema.Use(int), + '--iterations': schema.Use(int), + '<database>': lambda n: n in ('fv3d',), + '<protocol>': lambda n: n in ('central',), + '<group>': lambda n: n in ('dev',), + str: object, #ignores strings we don't care about + }, ignore_extra_keys=True) + + return sch.validate(args) + + +def main(user_input=None): + + if user_input is not None: + argv = user_input + else: + argv = sys.argv[1:] + + import pkg_resources + + completions = dict( + prog=os.path.basename(sys.argv[0]), + version=pkg_resources.require('bob.bio.vein')[0].version + ) + + args = docopt.docopt( + __doc__ % completions, + argv=argv, + version=completions['version'], + ) + + try: + from .validate import setup_logger + logger = setup_logger('bob.bio.vein', args['--verbose']) + args = validate(args) + except schema.SchemaError as e: + sys.exit(e) + + from ..configurations.fv3d import database as db + database_replacement = "%s/.bob_bio_databases.txt" % os.environ["HOME"] + db.replace_directories(database_replacement) + objects = db.objects(protocol=args['<protocol>'], groups=args['<group>']) + + from ..preprocessor.utils import poly_to_mask + features = None + target = None + for k, sample in enumerate(objects): + + if args['--samples'] is not None and k >= args['--samples']: break + path = sample.make_path(directory=db.original_directory, + extension=db.original_extension) + logger.info('Loading sample %d/%d (%s)...', k, len(objects), path) + image = sample.load(directory=db.original_directory, + extension=db.original_extension) + if not (hasattr(image, 'metadata') and 'roi' in image.metadata): + logger.info('Skipping sample (no ROI)') + continue + + # copy() required by skimage.util.shape.view_as_windows() + image = image.copy().astype('float64') / 255. + windows = skimage.util.shape.view_as_windows(image, (3,3)) + + if features is None and target is None: + features = numpy.zeros( + (args['--samples']*windows.shape[0]*windows.shape[1], + windows.shape[2]*windows.shape[3]+2), dtype='float64') + target = numpy.zeros(args['--samples']*windows.shape[0]*windows.shape[1], + dtype='bool') + + mask = poly_to_mask(image.shape, image.metadata['roi']) + mask = mask[1:-1, 1:-1] + for y in range(windows.shape[0]): + for x in range(windows.shape[1]): + idx = (k*windows.shape[0]*windows.shape[1]) + (y*windows.shape[1]) + x + features[idx,:-2] = windows[y,x].flatten() + features[idx,-2] = y+1 + features[idx,-1] = x+1 + target[idx] = mask[y,x] + + # normalize w.r.t. dimensions + features[:,-2] /= image.shape[0] + features[:,-1] /= image.shape[1] + + target_float = target.astype('float64') + target_float[~target] = -1.0 + target_float = target_float.reshape(len(target), 1) + positives = features[target] + negatives = features[~target] + logger.info('There are %d samples on input dataset', len(target)) + logger.info(' %d are negatives', len(negatives)) + logger.info(' %d are positives', len(positives)) + + import bob.learn.mlp + + # by default, machine uses hyperbolic tangent output + machine = bob.learn.mlp.Machine((features.shape[1], args['--hidden'], 1)) + machine.randomize() #initialize weights randomly + loss = bob.learn.mlp.SquareError(machine.output_activation) + train_biases = True + trainer = bob.learn.mlp.RProp(args['--batch'], loss, machine, train_biases) + trainer.reset() + shuffler = bob.learn.mlp.DataShuffler([negatives, positives], + [[-1.0], [+1.0]]) + + # start cost + output = machine(features) + cost = loss.f(output, target_float) + logger.info('[initial] MSE = %g', cost.mean()) + + # trains the network until the error is near zero + for i in range(args['--iterations']): + try: + _feats, _tgts = shuffler.draw(args['--batch']) + trainer.train(machine, _feats, _tgts) + logger.info('[%d] MSE = %g', i, trainer.cost(_tgts)) + except KeyboardInterrupt: + print() #avoids the ^C line + logger.info('Gracefully stopping training before limit (%d iterations)', + args['--batch'] + break + + # describe errors + neg_output = machine(negatives) + pos_output = machine(positives) + neg_errors = neg_output >= 0 + pos_errors = pos_output < 0 + hter_train = ((sum(neg_errors) / float(len(negatives))) + \ + (sum(pos_errors)) / float(len(positives))) / 2.0 + logger.info('Training set HTER: %.2f%%', hter_train) + logger.info(' Errors on negatives: %d / %d', sum(neg_errors), len(negatives)) + logger.info(' Errors on positives: %d / %d', sum(pos_errors), len(positives)) + + threshold = 0.8 + neg_errors = neg_output >= threshold + pos_errors = pos_output < -threshold + hter_train = ((sum(neg_errors) / float(len(negatives))) + \ + (sum(pos_errors)) / float(len(positives))) / 2.0 + logger.info('Training set HTER (threshold=%g): %.2f%%', threshold, hter_train) + logger.info(' Errors on negatives: %d / %d', sum(neg_errors), len(negatives)) + logger.info(' Errors on positives: %d / %d', sum(pos_errors), len(positives)) + # plot separation threshold + import matplotlib.pyplot as plt + from mpl_toolkits.mplot3d import Axes3D + + # only plot N random samples otherwise it makes it too slow + N = numpy.random.randint(min(len(negatives), len(positives)), + size=min(len(negatives), len(positives), args['--points'])) + + fig = plt.figure() + + ax = fig.add_subplot(211, projection='3d') + ax.scatter(image.shape[1]*negatives[N,-1], image.shape[0]*negatives[N,-2], + 255*negatives[N,4], label='negatives', color='blue', marker='.') + ax.scatter(image.shape[1]*positives[N,-1], image.shape[0]*positives[N,-2], + 255*positives[N,4], label='positives', color='red', marker='.') + ax.set_xlabel('Width') + ax.set_xlim(0, image.shape[1]) + ax.set_ylabel('Height') + ax.set_ylim(0, image.shape[0]) + ax.set_zlabel('Intensity') + ax.set_zlim(0, 255) + ax.legend() + ax.grid() + ax.set_title('Ground Truth') + plt.tight_layout() + + ax = fig.add_subplot(212, projection='3d') + neg_plot = negatives[neg_output[:,0]>=threshold] + pos_plot = positives[pos_output[:,0]<-threshold] + N = numpy.random.randint(min(len(neg_plot), len(pos_plot)), + size=min(len(neg_plot), len(pos_plot), args['--points'])) + ax.scatter(image.shape[1]*neg_plot[N,-1], image.shape[0]*neg_plot[N,-2], + 255*neg_plot[N,4], label='negatives', color='red', marker='.') + ax.scatter(image.shape[1]*pos_plot[N,-1], image.shape[0]*pos_plot[N,-2], + 255*pos_plot[N,4], label='positives', color='blue', marker='.') + ax.set_xlabel('Width') + ax.set_xlim(0, image.shape[1]) + ax.set_ylabel('Height') + ax.set_ylim(0, image.shape[0]) + ax.set_zlabel('Intensity') + ax.set_zlim(0, 255) + ax.legend() + ax.grid() + ax.set_title('Classifier Errors') + plt.tight_layout() + + print('Close plot window to save model and end program...') + plt.show() + import bob.io.base + h5f = bob.io.base.HDF5File(args['--model'], 'w') + machine.save(h5f) + del h5f + logger.info('Saved MLP model to %s', args['--model']) diff --git a/bob/bio/vein/script/validate.py b/bob/bio/vein/script/validate.py new file mode 100644 index 0000000..93d0c5b --- /dev/null +++ b/bob/bio/vein/script/validate.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : + + +'''Utilities for command-line option validation''' + + +import os +import glob +import schema +import logging +logger = logging.getLogger(__name__) + + +def setup_logger(name, level): + '''Sets up and checks a verbosity level respects min and max boundaries + + + Parameters: + + name (str): The name of the logger to setup + + v (int): A value indicating the verbosity that must be set + + + Returns: + + logging.Logger: A standard Python logger that can be used to log messages + + + Raises: + + schema.SchemaError: If the verbosity level exceeds the maximum allowed of 4 + + ''' + + import bob.core + logger = bob.core.log.setup(name) + + if not (0 <= level < 4): + raise schema.SchemaError("there can be only up to 3 -v's in a command-line") + + # Sets-up logging + bob.core.log.set_verbosity_level(logger, level) + + return logger + + +def make_dir(p): + '''Checks if a path exists, if it doesn't, creates it + + + Parameters: + + p (str): The path to check + + + Returns + + bool: ``True``, always + + ''' + + if not os.path.exists(p): + logger.info("Creating directory `%s'...", p) + os.makedirs(p) + + return True + + +def check_path_does_not_exist(p): + '''Checks if a path exists, if it does, raises an exception + + + Parameters: + + p (str): The path to check + + + Returns: + + bool: ``True``, always + + + Raises: + + schema.SchemaError: if the path exists + + ''' + + if os.path.exists(p): + raise schema.SchemaError("path to {} exists".format(p)) + + return True + + +def check_path_exists(p): + '''Checks if a path exists, if it doesn't, raises an exception + + + Parameters: + + p (str): The path to check + + + Returns: + + bool: ``True``, always + + + Raises: + + schema.SchemaError: if the path doesn't exist + + ''' + + if not os.path.exists(p): + raise schema.SchemaError("path to {} does not exist".format(p)) + + return True + + +def check_model_does_not_exist(p): + '''Checks if the path to any potential model file does not exist + + + Parameters: + + p (str): The path to check + + + Returns: + + bool: ``True``, always + + + Raises: + + schema.SchemaError: if the path exists + + ''' + + files = glob.glob(p + '.*') + if files: + raise schema.SchemaError("{} already exists".format(files)) + + return True + + +def open_multipage_pdf_file(s): + '''Returns an opened matplotlib multi-page file + + + Parameters: + + p (str): The path to the file to open + + + Returns: + + matplotlib.backends.backend_pdf.PdfPages: with the handle to the multipage + PDF file + + + Raises: + + schema.SchemaError: if the path exists + + ''' + import matplotlib.pyplot as mpl + from matplotlib.backends.backend_pdf import PdfPages + return PdfPages(s) -- GitLab