Commit 12b3da0c authored by André Anjos's avatar André Anjos 💬

Fixes to marker detector

parent fab9c86e
......@@ -13,7 +13,7 @@ Usage: %(prog)s [-v...] [--samples=N] [--model=PATH] [--points=N] [--hidden=N]
Arguments:
<database> Name of the database to use for creating the model (options are:
"fv3d")
"fv3d" or "verafinger")
<protocol> Name of the protocol to use for creating the model (options
depend on the database chosen)
<group> Name of the group to use on the database/protocol with the
......@@ -89,7 +89,8 @@ def validate(args):
'''
from .validate import check_model_does_not_exist
from .validate import check_model_does_not_exist, validate_protocol, \
validate_group
sch = schema.Schema({
'--model': check_model_does_not_exist,
......@@ -98,9 +99,9 @@ def validate(args):
'--hidden': schema.Use(int),
'--batch': schema.Use(int),
'--iterations': schema.Use(int),
'<database>': lambda n: n in ('fv3d',),
'<protocol>': lambda n: n in ('central', 'left', 'right'),
'<group>': lambda n: n in ('dev',),
'<database>': lambda n: n in ('fv3d', 'verafinger'),
'<protocol>': validate_protocol(args['<database>']),
'<group>': validate_group(args['<database>']),
str: object, #ignores strings we don't care about
}, ignore_extra_keys=True)
......@@ -134,10 +135,19 @@ def main(user_input=None):
except schema.SchemaError as e:
sys.exit(e)
from ..configurations.fv3d import database as db
if args['<database>'] == 'fv3d':
from ..configurations.fv3d import database as db
elif args['<database>'] == 'verafinger':
from ..configurations.verafinger import database as db
else:
raise schema.SchemaError('Database %s is not supported' % \
args['<database>'])
database_replacement = "%s/.bob_bio_databases.txt" % os.environ["HOME"]
db.replace_directories(database_replacement)
objects = db.objects(protocol=args['<protocol>'], groups=args['<group>'])
if args['--samples'] is None:
args['--samples'] = len(objects)
from ..preprocessor.utils import poly_to_mask
features = None
......@@ -169,6 +179,7 @@ def main(user_input=None):
dtype='bool')
mask = poly_to_mask(image.shape, image.metadata['roi'])
mask = mask[1:-1, 1:-1]
for y in range(windows.shape[0]):
for x in range(windows.shape[1]):
......@@ -232,7 +243,7 @@ def main(user_input=None):
pos_errors = pos_output < 0
hter_train = ((sum(neg_errors) / float(len(negatives))) + \
(sum(pos_errors)) / float(len(positives))) / 2.0
logger.info('Training set HTER: %.2f%%', hter_train)
logger.info('Training set HTER: %.2f%%', 100*hter_train)
logger.info(' Errors on negatives: %d / %d', sum(neg_errors), len(negatives))
logger.info(' Errors on positives: %d / %d', sum(pos_errors), len(positives))
......@@ -241,7 +252,8 @@ def main(user_input=None):
pos_errors = pos_output < -threshold
hter_train = ((sum(neg_errors) / float(len(negatives))) + \
(sum(pos_errors)) / float(len(positives))) / 2.0
logger.info('Training set HTER (threshold=%g): %.2f%%', threshold, hter_train)
logger.info('Training set HTER (threshold=%g): %.2f%%', threshold,
100*hter_train)
logger.info(' Errors on negatives: %d / %d', sum(neg_errors), len(negatives))
logger.info(' Errors on positives: %d / %d', sum(pos_errors), len(positives))
# plot separation threshold
......
......@@ -170,3 +170,79 @@ def open_multipage_pdf_file(s):
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
return PdfPages(s)
class validate_protocol(object):
'''Validates the protocol for a given database
Parameters:
name (str): The name of the database to validate the protocol for
Raises:
schema.SchemaError: if the database is not supported
'''
def __init__(self, name):
self.dbname = name
if name == 'fv3d':
import bob.db.fv3d
self.valid_names = bob.db.fv3d.Database().protocol_names()
elif name == 'verafinger':
import bob.db.verafinger
self.valid_names = bob.db.verafinger.Database().protocol_names()
else:
raise schema.SchemaError("do not support database {}".format(name))
def __call__(self, name):
if name not in self.valid_names:
msg = "{} is not a valid protocol for database {}"
raise schema.SchemaError(msg.format(name, self.dbname))
return True
class validate_group(object):
'''Validates the group for a given database
Parameters:
name (str): The name of the database to validate the group for
Raises:
schema.SchemaError: if the database is not supported
'''
def __init__(self, name):
self.dbname = name
if name == 'fv3d':
import bob.db.fv3d
self.valid_names = bob.db.fv3d.Database().groups()
elif name == 'verafinger':
import bob.db.verafinger
self.valid_names = bob.db.verafinger.Database().groups()
else:
raise schema.SchemaError("do not support database {}".format(name))
def __call__(self, name):
if name not in self.valid_names:
msg = "{} is not a valid group for database {}"
raise schema.SchemaError(msg.format(name, self.dbname))
return True
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment