diff --git a/bob/ip/binseg/engine/__init__.py b/bob/ip/binseg/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca5e07cb73f0bdddcb863ef497955964087e301 --- /dev/null +++ b/bob/ip/binseg/engine/__init__.py @@ -0,0 +1,3 @@ +# see https://docs.python.org/3/library/pkgutil.html +from pkgutil import extend_path +__path__ = extend_path(__path__, __name__) \ No newline at end of file diff --git a/bob/ip/binseg/engine/inferencer.py b/bob/ip/binseg/engine/inferencer.py index 650583fc64c97d8db03686721f7303e5cdab56ed..c8166a4729153335eafd7e15afbd1aa7385afe7b 100644 --- a/bob/ip/binseg/engine/inferencer.py +++ b/bob/ip/binseg/engine/inferencer.py @@ -160,7 +160,6 @@ def do_inference( logger.info("Saving average over all input images: {}".format(metrics_file)) avg_metrics = df_metrics.groupby('threshold').mean() - avg_metrics["model_name"] = model.name avg_metrics.to_csv(metrics_path) avg_metrics["f1_score"] = 2* avg_metrics["precision"]*avg_metrics["recall"]/ \ @@ -175,7 +174,7 @@ def do_inference( np_avg_metrics = avg_metrics.to_numpy().T fig_name = "precision_recall.pdf".format(model.name) logger.info("saving {}".format(fig_name)) - fig = precision_recall_f1iso([np_avg_metrics[0]],[np_avg_metrics[1]], np_avg_metrics[-1]) + fig = precision_recall_f1iso([np_avg_metrics[0]],[np_avg_metrics[1]], [model.name,None]) fig_filename = os.path.join(results_subfolder, fig_name) fig.savefig(fig_filename) diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py index 0db77074075b814fb3b140a6cabdea008d63c901..d57406f707c752a06af384c1cd009c383eb15659 100644 --- a/bob/ip/binseg/script/binseg.py +++ b/bob/ip/binseg/script/binseg.py @@ -25,6 +25,8 @@ from bob.ip.binseg.utils.checkpointer import DetectronCheckpointer from torch.utils.data import DataLoader from bob.ip.binseg.engine.trainer import do_train from bob.ip.binseg.engine.inferencer import do_inference +from bob.ip.binseg.utils.plot import plot_overview +from 
bob.ip.binseg.utils.click import OptionEatAll logger = logging.getLogger(__name__) @@ -275,4 +277,28 @@ def testcheckpoints(model # checkpointer, load last model in dir checkpointer = DetectronCheckpointer(model, save_dir = output_subfolder, save_to_disk=False) checkpointer.load(checkpoint) - do_inference(model, data_loader, device, output_subfolder) \ No newline at end of file + do_inference(model, data_loader, device, output_subfolder) + +# Plot comparison +@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand) +@click.option( + '--output-path-list', + '-l', + required=True, + help='Pass all output paths as arguments', + cls=OptionEatAll, + ) +@click.option( + '--output-path', + '-o', + required=True, + ) +@verbosity_option(cls=ResourceOption) +def compare(output_path_list, output_path,**kwargs): + """ Compares multiple metrics files that are stored in the format mymodel/results/Metrics.csv """ + logger.debug("Output paths: {}".format(output_path_list)) + logger.info('Plotting precision vs recall curves for {}'.format(output_path_list)) + fig = plot_overview(output_path_list) + fig_filename = os.path.join(output_path, 'precision_recall_comparison.pdf') + logger.info('saving {}'.format(fig_filename)) + fig.savefig(fig_filename) diff --git a/bob/ip/binseg/utils/click.py b/bob/ip/binseg/utils/click.py new file mode 100644 index 0000000000000000000000000000000000000000..fa3582878cf1cc3d1634fcd66f416fa468448030 --- /dev/null +++ b/bob/ip/binseg/utils/click.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# https://stackoverflow.com/questions/48391777/nargs-equivalent-for-options-in-click + +import click + +class OptionEatAll(click.Option): + + def __init__(self, *args, **kwargs): + self.save_other_options = kwargs.pop('save_other_options', True) + nargs = kwargs.pop('nargs', -1) + assert nargs == -1, 'nargs, if set, must be -1 not {}'.format(nargs) + super(OptionEatAll, self).__init__(*args, **kwargs) + 
self._previous_parser_process = None + self._eat_all_parser = None + + def add_to_parser(self, parser, ctx): + + def parser_process(value, state): + # method to hook to the parser.process + done = False + value = [value] + if self.save_other_options: + # grab everything up to the next option + while state.rargs and not done: + for prefix in self._eat_all_parser.prefixes: + if state.rargs[0].startswith(prefix): + done = True + if not done: + value.append(state.rargs.pop(0)) + else: + # grab everything remaining + value += state.rargs + state.rargs[:] = [] + value = tuple(value) + + # call the actual process + self._previous_parser_process(value, state) + + retval = super(OptionEatAll, self).add_to_parser(parser, ctx) + for name in self.opts: + our_parser = parser._long_opt.get(name) or parser._short_opt.get(name) + if our_parser: + self._eat_all_parser = our_parser + self._previous_parser_process = our_parser.process + our_parser.process = parser_process + break + return retval \ No newline at end of file diff --git a/bob/ip/binseg/utils/plot.py b/bob/ip/binseg/utils/plot.py index 8dbdb1f710a79d521edbfea24aa11dd216e6f2f5..98c9adef42af9dc3555864923cfec8a55063d126 100644 --- a/bob/ip/binseg/utils/plot.py +++ b/bob/ip/binseg/utils/plot.py @@ -5,6 +5,8 @@ # author_email='andre.anjos@idiap.ch', import numpy as np +import os +import csv def precision_recall_f1iso(precision, recall, names, title=None, human_perf_bsds500=False): '''Creates a precision-recall plot of the given data. 
@@ -115,4 +117,55 @@ def loss_curve(df): ax1.set_xlabel('epoch') plt.tight_layout() fig = ax1.get_figure() - return fig \ No newline at end of file + return fig + + +def read_metricscsv(file): + """ + Read precision and recall from csv file + + Arguments + --------- + file: str + path to file + + Returns + ------- + precision : :py:class:`np.ndarray` + recall : :py:class:`np.ndarray` + """ + with open (file, "r") as infile: + metricsreader = csv.reader(infile) + # skip header row + next(metricsreader) + precision = [] + recall = [] + for row in metricsreader: + precision.append(float(row[1])) + recall.append(float(row[2])) + return np.array(precision), np.array(recall) + + +def plot_overview(outputfolders): + """ + Plots comparison chart of all trained models + Arguments + --------- + outputfolder : list + list containing output paths of all evaluated models (e.g. ['output/model1', 'output/model2']) + + Returns + ------- + fig : matplotlib.figure.Figure + """ + precisions = [] + recalls = [] + names = [] + for folder in outputfolders: + metrics_path = os.path.join(folder,'results/Metrics.csv') + pr, re = read_metricscsv(metrics_path) + precisions.append(pr) + recalls.append(re) + names.append(folder) + fig = precision_recall_f1iso(precisions,recalls,names) + return fig diff --git a/conda/meta.yaml b/conda/meta.yaml index b85fcc209bee58e5baddff85524e69decb709df0..1e569ac4d7897703dd3a32cb0e2ceb885c478172 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -6,6 +6,7 @@ package: version: {{ environ.get('BOB_PACKAGE_VERSION', '0.0.1') }} build: + skip: true # [not linux] number: {{ environ.get('BOB_BUILD_NUMBER', 0) }} run_exports: - {{ pin_subpackage(name) }} @@ -38,6 +39,7 @@ requirements: - {{ pin_compatible('numpy') }} - pandas - matplotlib + - tqdm - bob.db.drive - bob.db.stare - bob.db.chasedb1 @@ -56,6 +58,8 @@ test: # test commands ("script" entry-points) from your package here - bob binseg --help - bob binseg train --help + - bob binseg test --help + - bob 
binseg compare --help - nosetests --with-coverage --cover-package={{ name }} -sv {{ name }} - sphinx-build -aEW {{ project_dir }}/doc {{ project_dir }}/sphinx - sphinx-build -aEb doctest {{ project_dir }}/doc sphinx diff --git a/metrics.pkl b/metrics.pkl deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/precision_recall_comparison.pdf b/precision_recall_comparison.pdf new file mode 100644 index 0000000000000000000000000000000000000000..1568e9c1a33e6421a2fd77af3bb6d720f1af156d Binary files /dev/null and b/precision_recall_comparison.pdf differ diff --git a/requirements.txt b/requirements.txt index ca8e629373cf02f9d9dd85245f1efb8f22198554..d2f0e4456468493efa7de752dcf6cd4a69feada2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,5 @@ bob.db.iostar torch torchvision pandas -matplotlib \ No newline at end of file +matplotlib +tqdm \ No newline at end of file diff --git a/setup.py b/setup.py index b3423a5dd6a2931690d985eb0c3c5bbd8a1ff434..8c8bb0bfadff16e0e6de4a83c428ab648b34ffdf 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,7 @@ setup( 'bob.ip.binseg.cli': [ 'train = bob.ip.binseg.script.binseg:train', 'test = bob.ip.binseg.script.binseg:test', + 'compare = bob.ip.binseg.script.binseg:compare', 'testcheckpoints = bob.ip.binseg.script.binseg:testcheckpoints', ],