Skip to content
Snippets Groups Projects
Commit 6cb749d8 authored by Manuel Günther's avatar Manuel Günther
Browse files

Implemented open set recognition rate based on threshold

parent f3e64f71
Branches
Tags
No related merge requests found
...@@ -61,6 +61,7 @@ def command_line_arguments(command_line_parameters): ...@@ -61,6 +61,7 @@ def command_line_arguments(command_line_parameters):
parser.add_argument('-m', '--mindcf', action = 'store_true', help = "If given, minDCF will be computed.") parser.add_argument('-m', '--mindcf', action = 'store_true', help = "If given, minDCF will be computed.")
parser.add_argument('--cost', default=0.99, help='Cost for FAR in minDCF') parser.add_argument('--cost', default=0.99, help='Cost for FAR in minDCF')
parser.add_argument('-r', '--rr', action = 'store_true', help = "If given, the Recognition Rate will be computed.") parser.add_argument('-r', '--rr', action = 'store_true', help = "If given, the Recognition Rate will be computed.")
parser.add_argument('-t', '--thresholds', type=float, nargs='+', help = "If given, the Recognition Rate will incorporate an Open Set handling, rejecting all scores that are below the given threshold; when multiple thresholds are given, they are applied in the same order as the --dev-files.")
parser.add_argument('-l', '--legends', nargs='+', help = "A list of legend strings used for ROC, CMC and DET plots; if given, must be the same number than --dev-files.") parser.add_argument('-l', '--legends', nargs='+', help = "A list of legend strings used for ROC, CMC and DET plots; if given, must be the same number than --dev-files.")
parser.add_argument('-F', '--legend-font-size', type=int, default=18, help = "Set the font size of the legends.") parser.add_argument('-F', '--legend-font-size', type=int, default=18, help = "Set the font size of the legends.")
parser.add_argument('-P', '--legend-position', type=int, help = "Set the font size of the legends.") parser.add_argument('-P', '--legend-position', type=int, help = "Set the font size of the legends.")
...@@ -93,6 +94,14 @@ def command_line_arguments(command_line_parameters): ...@@ -93,6 +94,14 @@ def command_line_arguments(command_line_parameters):
if len(args.dev_files) != len(args.legends): if len(args.dev_files) != len(args.legends):
logger.error("The number of --dev-files (%d) and --legends (%d) are not identical", len(args.dev_files), len(args.legends)) logger.error("The number of --dev-files (%d) and --legends (%d) are not identical", len(args.dev_files), len(args.legends))
if args.thresholds is not None:
if len(args.thresholds) == 1:
args.thresholds = args.thresholds * len(args.dev_files)
elif len(args.thresholds) != len(args.dev_files):
logger.error("If given, the number of --thresholds imust be either 1, or the same as --dev-files (%d), but it is %d", len(args.dev_files), len(args.thresholds))
else:
args.thresholds = [None] * len(args.dev_files)
return args return args
...@@ -314,24 +323,24 @@ def main(command_line_parameters=None): ...@@ -314,24 +323,24 @@ def main(command_line_parameters=None):
if args.eval_files: if args.eval_files:
cmcs_eval = [cmc_parser(os.path.join(args.directory, f)) for f in args.eval_files] cmcs_eval = [cmc_parser(os.path.join(args.directory, f)) for f in args.eval_files]
if args.cmc: if args.cmc:
logger.info("Plotting CMC curves to file '%s'", args.cmc) logger.info("Plotting CMC curves to file '%s'", args.cmc)
try: try:
# create a multi-page PDF for the ROC curve # create a multi-page PDF for the ROC curve
pdf = PdfPages(args.cmc) pdf = PdfPages(args.cmc)
# create a separate figure for dev and eval # create a separate figure for dev and eval
pdf.savefig(_plot_cmc(cmcs_dev, colors, args.legends, "CMC curve for development set", args.legend_font_size, args.legend_position)) pdf.savefig(_plot_cmc(cmcs_dev, colors, args.legends, "CMC curve for development set", args.legend_font_size, args.legend_position))
if args.eval_files: if args.eval_files:
pdf.savefig(_plot_cmc(cmcs_eval, colors, args.legends, "CMC curve for evaluation set", args.legend_font_size, args.legend_position)) pdf.savefig(_plot_cmc(cmcs_eval, colors, args.legends, "CMC curve for evaluation set", args.legend_font_size, args.legend_position))
pdf.close() pdf.close()
except RuntimeError as e: except RuntimeError as e:
raise RuntimeError("During plotting of ROC curves, the following exception occured:\n%s\nUsually this happens when the label contains characters that LaTeX cannot parse." % e) raise RuntimeError("During plotting of ROC curves, the following exception occured:\n%s\nUsually this happens when the label contains characters that LaTeX cannot parse." % e)
if args.rr: if args.rr:
logger.info("Computing recognition rate on the development " + ("and on the evaluation set" if args.eval_files else "set")) logger.info("Computing recognition rate on the development " + ("and on the evaluation set" if args.eval_files else "set"))
for i in range(len(cmcs_dev)): for i in range(len(cmcs_dev)):
rr = bob.measure.recognition_rate(cmcs_dev[i]) rr = bob.measure.recognition_rate(cmcs_dev[i], args.thresholds[i])
print("The Recognition Rate of the development set of '%s' is %2.3f%%" % (args.legends[i], rr * 100.))
if args.eval_files:
rr = bob.measure.recognition_rate(cmcs_eval[i])
print("The Recognition Rate of the development set of '%s' is %2.3f%%" % (args.legends[i], rr * 100.)) print("The Recognition Rate of the development set of '%s' is %2.3f%%" % (args.legends[i], rr * 100.))
if args.eval_files:
rr = bob.measure.recognition_rate(cmcs_eval[i], args.thresholds[i])
print("The Recognition Rate of the development set of '%s' is %2.3f%%" % (args.legends[i], rr * 100.))
...@@ -306,6 +306,8 @@ def test_evaluate(): ...@@ -306,6 +306,8 @@ def test_evaluate():
'--roc', plots[0], '--roc', plots[0],
'--det', plots[1], '--det', plots[1],
'--cmc', plots[2], '--cmc', plots[2],
'--rr',
'--thresholds', '5000', '0',
'-v', '-v',
] ]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment