Commit d7e18c01 authored by André Anjos's avatar André Anjos

Make plotting optional; Calculate boundary fore/background thresholds

parent ce7cb318
Pipeline #14691 passed with stages
in 16 minutes 54 seconds
......@@ -5,8 +5,8 @@
"""Trains a new MLP to perform pre-watershed marker detection
Usage: %(prog)s [-v...] [--samples=N] [--model=PATH] [--points=N] [--hidden=N]
[--batch=N] [--iterations=N] [--hollow]
<database> <protocol> <group> <size>
[--batch=N] [--iterations=N] [--hollow] [--plot]
[--maximum-error=F] <database> <protocol> <group> <size>
%(prog)s --help
%(prog)s --version
......@@ -47,7 +47,10 @@ Options:
training and the other half for validation. If all
samples are used for training, then no samples will be
used for validation.
-p N, --points=N Maximum number of samples to use for plotting
-p, --plot Plot samples of the validation set exposed to the just
trained neural network. Useful to visualize where
errors happen more frequently.
-P N, --points=N Maximum number of samples to use for plotting
ground-truth and classification errors. The more
points, the less responsive the plot becomes
[default: 1000]
......@@ -60,6 +63,8 @@ Options:
network will be hollow - the pixels that are not in
the outside border of the window are not passed
(except for the center pixel value).
-e, --maximum-error=F Maximum relative error to allow for foreground and
background marker detection [default: 0.03]
Examples:
......@@ -145,6 +150,7 @@ def validate(args):
'--hidden': schema.Use(int),
'--batch': schema.Use(int),
'--iterations': schema.Use(int),
'--maximum-error': schema.Use(float),
'<database>': lambda n: n in ('fv3d', 'verafinger', 'hkpu', 'thufvdt'),
'<protocol>': validate_protocol(args['<database>']),
'<group>': validate_group(args['<database>']),
......@@ -214,32 +220,57 @@ def load_data(objects, original_directory, original_extension, footprint):
return features, target_float, loaded
def analyze(machine, negatives, positives, name, threshold):
def scan_thresholds(machine, negatives, positives, max_error=0.01):
'''Calculates best thresholds for detecting negatives and positives'''
neg_output = machine(negatives)
pos_output = machine(positives)
step = 0.01
fg = 1.0 - step
bg = None
for k in numpy.arange(0.0+step, 1.0, 0.01):
neg_errors = float((neg_output >= k).sum())/ len(negatives)
pos_errors = float((pos_output < k).sum())/ len(positives)
#print(k, neg_errors, pos_errors)
if neg_errors < max_error and bg is None:
logger.debug('Reset background threshold to %g (error = %g)', k,
neg_errors)
bg = k
if pos_errors < max_error:
logger.debug('Reset foreground threshold to %g (error = %g)', k,
pos_errors)
fg = k
if bg is None:
bg = 0.0 + step
if fg < bg:
bg = fg - step
neg_errors = float((neg_output >= bg).sum())/ len(negatives)
logger.debug('Reset background threshold to %g (error = %g) since it was '
'bigger than the foreground threshold', bg, neg_errors)
return fg, bg
def analyze(machine, negatives, positives, name, fg_threshold, bg_threshold):
'''Prints performance analysis'''
# describe errors
neg_output = machine(negatives)
pos_output = machine(positives)
neg_errors = neg_output >= 0
pos_errors = pos_output < 0
neg_errors = neg_output >= fg_threshold
pos_errors = pos_output < bg_threshold
hter = ((sum(neg_errors) / float(len(negatives))) + \
(sum(pos_errors)) / float(len(positives))) / 2.0
logger.info('%s set HTER: %.2f%%', name.capitalize(), 100*hter)
logger.info(' Errors on negatives: %d / %d', sum(neg_errors), len(negatives))
logger.info(' Errors on positives: %d / %d', sum(pos_errors), len(positives))
neg_errors = neg_output >= threshold
pos_errors = pos_output < -threshold
hter = ((sum(neg_errors) / float(len(negatives))) + \
(sum(pos_errors)) / float(len(positives))) / 2.0
logger.info('%s set HTER (threshold=%g): %.2f%%', name.capitalize(),
threshold, 100*hter)
logger.info(' Errors on negatives: %d / %d', sum(neg_errors), len(negatives))
logger.info(' Errors on positives: %d / %d', sum(pos_errors), len(positives))
def plot(machine, negatives, positives, npoints, sample, directory, extension,
threshold):
fg_threshold, bg_threshold):
'''Provides a graphical overview of errors'''
# plot separation threshold
......@@ -275,8 +306,8 @@ def plot(machine, negatives, positives, npoints, sample, directory, extension,
pos_output = machine(positives)
ax = fig.add_subplot(212, projection='3d')
neg_plot = negatives[neg_output[:,0]>=threshold]
pos_plot = positives[pos_output[:,0]<-threshold]
neg_plot = negatives[neg_output[:,0]>=fg_threshold]
pos_plot = positives[pos_output[:,0]<bg_threshold]
N = numpy.random.randint(min(len(neg_plot), len(pos_plot)),
size=min(len(neg_plot), len(pos_plot), npoints))
ax.scatter(image.shape[1]*neg_plot[N,-1], image.shape[0]*neg_plot[N,-2],
......@@ -390,6 +421,8 @@ def main(user_input=None):
# by default, machine uses hyperbolic tangent output
machine = bob.learn.mlp.Machine(
(train_features.shape[1], args['--hidden'], 1))
logger.debug('Machine architecture is %d-%d-1', train_features.shape[1],
args['--hidden'])
machine.randomize() #initialize weights randomly
loss = bob.learn.mlp.SquareError(machine.output_activation)
train_biases = True
......@@ -423,17 +456,30 @@ def main(user_input=None):
args['--batch'])
break
# replaces output function with a sigmoid (output between 0.0 and 1.0)
import bob.learn.activation
machine.output_activation = bob.learn.activation.Logistic()
# check what is the best threshold for the just trained neural network
fg, bg = scan_thresholds(machine, valid_negatives, valid_positives,
args['--maximum-error'])
logger.info('Background threshold = %g', bg)
logger.info('Foreground threshold = %g', fg)
# runs analysis
threshold = 0.8
analyze(machine, train_negatives, train_positives, 'training', threshold)
analyze(machine, valid_negatives, valid_positives, 'validation', threshold)
analyze(machine, train_negatives, train_positives, 'training', fg, bg)
analyze(machine, valid_negatives, valid_positives, 'validation', fg, bg)
plot(machine, valid_negatives, valid_positives, args['--points'], objects[0],
db.original_directory, db.original_extension, threshold)
if args['--plot']:
plot(machine, valid_negatives, valid_positives, args['--points'],
objects[0], db.original_directory, db.original_extension, fg, bg)
# save models
import bob.io.base
h5f = bob.io.base.HDF5File(args['--model'], 'w')
machine.save(h5f)
h5f['footprint'] = footprint
h5f['fg_threshold'] = fg
h5f['bg_threshold'] = bg
del h5f
logger.info('Saved MLP model and footprint to %s', args['--model'])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment