diff --git a/bob/bio/vein/script/markdet.py b/bob/bio/vein/script/markdet.py index c15bb50a807824cd20ddde2e6dd2595e69f8901c..824e3ca9cb25d89bda3de0e8cbcb32de86fd98af 100644 --- a/bob/bio/vein/script/markdet.py +++ b/bob/bio/vein/script/markdet.py @@ -13,7 +13,7 @@ Usage: %(prog)s [-v...] [--samples=N] [--model=PATH] [--points=N] [--hidden=N] Arguments: <database> Name of the database to use for creating the model (options are: - "fv3d") + "fv3d" or "verafinger") <protocol> Name of the protocol to use for creating the model (options depend on the database chosen) <group> Name of the group to use on the database/protocol with the @@ -89,7 +89,8 @@ def validate(args): ''' - from .validate import check_model_does_not_exist + from .validate import check_model_does_not_exist, validate_protocol, \ + validate_group sch = schema.Schema({ '--model': check_model_does_not_exist, @@ -98,9 +99,9 @@ def validate(args): '--hidden': schema.Use(int), '--batch': schema.Use(int), '--iterations': schema.Use(int), - '<database>': lambda n: n in ('fv3d',), - '<protocol>': lambda n: n in ('central', 'left', 'right'), - '<group>': lambda n: n in ('dev',), + '<database>': lambda n: n in ('fv3d', 'verafinger'), + '<protocol>': validate_protocol(args['<database>']), + '<group>': validate_group(args['<database>']), str: object, #ignores strings we don't care about }, ignore_extra_keys=True) @@ -134,10 +135,19 @@ def main(user_input=None): except schema.SchemaError as e: sys.exit(e) - from ..configurations.fv3d import database as db + if args['<database>'] == 'fv3d': + from ..configurations.fv3d import database as db + elif args['<database>'] == 'verafinger': + from ..configurations.verafinger import database as db + else: + raise schema.SchemaError('Database %s is not supported' % \ + args['<database>']) + database_replacement = "%s/.bob_bio_databases.txt" % os.environ["HOME"] db.replace_directories(database_replacement) objects = db.objects(protocol=args['<protocol>'], groups=args['<group>']) + if args['--samples'] is None: + args['--samples'] = len(objects) from ..preprocessor.utils import poly_to_mask features = None @@ -169,6 +179,7 @@ def main(user_input=None): dtype='bool') mask = poly_to_mask(image.shape, image.metadata['roi']) + mask = mask[1:-1, 1:-1] for y in range(windows.shape[0]): for x in range(windows.shape[1]): @@ -232,7 +243,7 @@ def main(user_input=None): pos_errors = pos_output < 0 hter_train = ((sum(neg_errors) / float(len(negatives))) + \ (sum(pos_errors)) / float(len(positives))) / 2.0 - logger.info('Training set HTER: %.2f%%', hter_train) + logger.info('Training set HTER: %.2f%%', 100*hter_train) logger.info(' Errors on negatives: %d / %d', sum(neg_errors), len(negatives)) logger.info(' Errors on positives: %d / %d', sum(pos_errors), len(positives)) @@ -241,7 +252,8 @@ def main(user_input=None): pos_errors = pos_output < -threshold hter_train = ((sum(neg_errors) / float(len(negatives))) + \ (sum(pos_errors)) / float(len(positives))) / 2.0 - logger.info('Training set HTER (threshold=%g): %.2f%%', threshold, hter_train) + logger.info('Training set HTER (threshold=%g): %.2f%%', threshold, + 100*hter_train) logger.info(' Errors on negatives: %d / %d', sum(neg_errors), len(negatives)) logger.info(' Errors on positives: %d / %d', sum(pos_errors), len(positives)) # plot separation threshold diff --git a/bob/bio/vein/script/validate.py b/bob/bio/vein/script/validate.py index 93d0c5b02109454b8dafd5b09c1848bd79a33dfe..e404f592f1b915a379b934da79b3dbd22232d303 100644 --- a/bob/bio/vein/script/validate.py +++ b/bob/bio/vein/script/validate.py @@ -170,3 +170,79 @@ def open_multipage_pdf_file(s): import matplotlib.pyplot as mpl from matplotlib.backends.backend_pdf import PdfPages return PdfPages(s) + + +class validate_protocol(object): + '''Validates the protocol for a given database + + + Parameters: + + name (str): The name of the database to validate the protocol for + + + Raises: + + schema.SchemaError: if the database is not supported + + ''' + + def __init__(self, name): + + self.dbname = name + + if name == 'fv3d': + import bob.db.fv3d + self.valid_names = bob.db.fv3d.Database().protocol_names() + elif name == 'verafinger': + import bob.db.verafinger + self.valid_names = bob.db.verafinger.Database().protocol_names() + else: + raise schema.SchemaError("do not support database {}".format(name)) + + + def __call__(self, name): + + if name not in self.valid_names: + msg = "{} is not a valid protocol for database {}" + raise schema.SchemaError(msg.format(name, self.dbname)) + + return True + + +class validate_group(object): + '''Validates the group for a given database + + + Parameters: + + name (str): The name of the database to validate the group for + + + Raises: + + schema.SchemaError: if the database is not supported + + ''' + + def __init__(self, name): + + self.dbname = name + + if name == 'fv3d': + import bob.db.fv3d + self.valid_names = bob.db.fv3d.Database().groups() + elif name == 'verafinger': + import bob.db.verafinger + self.valid_names = bob.db.verafinger.Database().groups() + else: + raise schema.SchemaError("do not support database {}".format(name)) + + + def __call__(self, name): + + if name not in self.valid_names: + msg = "{} is not a valid group for database {}" + raise schema.SchemaError(msg.format(name, self.dbname)) + + return True