Commit afdf3f43 authored by Tim Laibacher

Add metrics compare cli

parent 7af97acc
Pipeline #29697 failed
(namespace package __init__.py)
# see https://docs.python.org/3/library/pkgutil.html
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)
\ No newline at end of file
bob/ip/binseg/engine/inferencer.py
@@ -160,7 +160,6 @@ def do_inference(
    logger.info("Saving average over all input images: {}".format(metrics_file))
    avg_metrics = df_metrics.groupby('threshold').mean()
    avg_metrics["model_name"] = model.name
-    avg_metrics.to_csv(metrics_path)
    avg_metrics["f1_score"] = 2 * avg_metrics["precision"] * avg_metrics["recall"] / \
        (avg_metrics["precision"] + avg_metrics["recall"])
@@ -175,7 +174,7 @@ def do_inference(
    np_avg_metrics = avg_metrics.to_numpy().T
    fig_name = "precision_recall.pdf"
    logger.info("saving {}".format(fig_name))
-    fig = precision_recall_f1iso([np_avg_metrics[0]], [np_avg_metrics[1]], np_avg_metrics[-1])
+    fig = precision_recall_f1iso([np_avg_metrics[0]], [np_avg_metrics[1]], [model.name, None])
    fig_filename = os.path.join(results_subfolder, fig_name)
    fig.savefig(fig_filename)
bob/ip/binseg/script/binseg.py
@@ -25,6 +25,8 @@ from bob.ip.binseg.utils.checkpointer import DetectronCheckpointer
from torch.utils.data import DataLoader
from bob.ip.binseg.engine.trainer import do_train
from bob.ip.binseg.engine.inferencer import do_inference
+from bob.ip.binseg.utils.plot import plot_overview
+from bob.ip.binseg.utils.click import OptionEatAll

logger = logging.getLogger(__name__)
@@ -275,4 +277,28 @@ def testcheckpoints(model
    # checkpointer, load last model in dir
    checkpointer = DetectronCheckpointer(model, save_dir=output_subfolder, save_to_disk=False)
    checkpointer.load(checkpoint)
-    do_inference(model, data_loader, device, output_subfolder)
\ No newline at end of file
+    do_inference(model, data_loader, device, output_subfolder)
+
+# Plot comparison
+@binseg.command(entry_point_group='bob.ip.binseg.config', cls=ConfigCommand)
+@click.option(
+    '--output-path-list',
+    '-l',
+    required=True,
+    help='Pass all output paths as arguments',
+    cls=OptionEatAll,
+    )
+@click.option(
+    '--output-path',
+    '-o',
+    required=True,
+    )
+@verbosity_option(cls=ResourceOption)
+def compare(output_path_list, output_path, **kwargs):
+    """Compares multiple metrics files that are stored in the format mymodel/results/Metrics.csv"""
+    logger.debug("Output paths: {}".format(output_path_list))
+    logger.info('Plotting precision vs recall curves for {}'.format(output_path_list))
+    fig = plot_overview(output_path_list)
+    fig_filename = os.path.join(output_path, 'precision_recall_comparison.pdf')
+    logger.info('saving {}'.format(fig_filename))
+    fig.savefig(fig_filename)
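
For a quick sanity check, the new command can be driven in-process with Click's test runner. This is a hedged sketch, not part of the commit: the output folders are made up, and ConfigCommand/ResourceOption may accept extra configuration arguments not shown here.

# Illustrative only: exercise `compare` via click.testing.CliRunner.
# The folders output/model1 and output/model2 are placeholders.
from click.testing import CliRunner
from bob.ip.binseg.script.binseg import compare

runner = CliRunner()
# -l (OptionEatAll) consumes every argument up to the next option;
# -o names the folder that receives precision_recall_comparison.pdf
result = runner.invoke(compare, ["-l", "output/model1", "output/model2", "-o", "comparisons"])
print(result.exit_code, result.output)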
bob/ip/binseg/utils/click.py (new file)
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# https://stackoverflow.com/questions/48391777/nargs-equivalent-for-options-in-click

import click


class OptionEatAll(click.Option):
    """Option that eats all following command-line arguments until the next option."""

    def __init__(self, *args, **kwargs):
        self.save_other_options = kwargs.pop('save_other_options', True)
        nargs = kwargs.pop('nargs', -1)
        assert nargs == -1, 'nargs, if set, must be -1 not {}'.format(nargs)
        super(OptionEatAll, self).__init__(*args, **kwargs)
        self._previous_parser_process = None
        self._eat_all_parser = None

    def add_to_parser(self, parser, ctx):

        def parser_process(value, state):
            # method to hook to the parser.process
            done = False
            value = [value]
            if self.save_other_options:
                # grab everything up to the next option
                while state.rargs and not done:
                    for prefix in self._eat_all_parser.prefixes:
                        if state.rargs[0].startswith(prefix):
                            done = True
                    if not done:
                        value.append(state.rargs.pop(0))
            else:
                # grab everything remaining
                value += state.rargs
                state.rargs[:] = []
            value = tuple(value)
            # call the actual process
            self._previous_parser_process(value, state)

        retval = super(OptionEatAll, self).add_to_parser(parser, ctx)
        for name in self.opts:
            our_parser = parser._long_opt.get(name) or parser._short_opt.get(name)
            if our_parser:
                self._eat_all_parser = our_parser
                self._previous_parser_process = our_parser.process
                our_parser.process = parser_process
                break
        return retval
\ No newline at end of file
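
For context, a self-contained toy command showing what OptionEatAll buys: a single option that greedily consumes arguments until the next option prefix. The script and option names below are illustrative, and the class hooks Click parser internals (parser._long_opt, process), so it is tied to the Click version this project was developed against.

# Illustrative only: `--paths` swallows all following arguments.
import click
from bob.ip.binseg.utils.click import OptionEatAll

@click.command()
@click.option('--paths', cls=OptionEatAll, help='Eats all args up to the next option.')
@click.option('--tag', default=None)
def demo(paths, tag):
    # e.g. `demo --paths a b c --tag x` -> paths == ('a', 'b', 'c'), tag == 'x'
    click.echo('paths: {}, tag: {}'.format(paths, tag))

if __name__ == '__main__':
    demo()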
bob/ip/binseg/utils/plot.py
@@ -5,6 +5,8 @@
# author_email='andre.anjos@idiap.ch',

import numpy as np
import os
+import csv

def precision_recall_f1iso(precision, recall, names, title=None, human_perf_bsds500=False):
    '''Creates a precision-recall plot of the given data.
@@ -115,4 +117,55 @@ def loss_curve(df):
    ax1.set_xlabel('epoch')
    plt.tight_layout()
    fig = ax1.get_figure()
-    return fig
\ No newline at end of file
+    return fig
+
+def read_metricscsv(file):
+    """
+    Read precision and recall from csv file
+
+    Arguments
+    ---------
+    file : str
+        path to file
+
+    Returns
+    -------
+    precision : :py:class:`numpy.ndarray`
+    recall : :py:class:`numpy.ndarray`
+    """
+    with open(file, "r") as infile:
+        metricsreader = csv.reader(infile)
+        # skip header row
+        next(metricsreader)
+        precision = []
+        recall = []
+        for row in metricsreader:
+            precision.append(float(row[1]))
+            recall.append(float(row[2]))
+    return np.array(precision), np.array(recall)
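
read_metricscsv only consumes columns 1 and 2, so it presupposes a Metrics.csv layout along the following lines. The header names and numbers here are illustrative; the real columns come from the DataFrame the inferencer writes.

# Assumed layout (made-up values): column 0 is the threshold index written by
# DataFrame.to_csv, columns 1 and 2 are precision and recall.
#
#   threshold,precision,recall,...
#   0.1,0.71,0.97,...
#   0.2,0.78,0.94,...
#
precision, recall = read_metricscsv("output/model1/results/Metrics.csv")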
+
+def plot_overview(outputfolders):
+    """
+    Plots comparison chart of all trained models
+
+    Arguments
+    ---------
+    outputfolders : list
+        list containing output paths of all evaluated models (e.g. ['output/model1', 'output/model2'])
+
+    Returns
+    -------
+    fig : matplotlib.figure.Figure
+    """
+    precisions = []
+    recalls = []
+    names = []
+    for folder in outputfolders:
+        metrics_path = os.path.join(folder, 'results/Metrics.csv')
+        pr, re = read_metricscsv(metrics_path)
+        precisions.append(pr)
+        recalls.append(re)
+        names.append(folder)
+    fig = precision_recall_f1iso(precisions, recalls, names)
+    return fig
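
Taken together, the two helpers reduce a model comparison to a few lines; a minimal sketch with placeholder folder names, each of which must contain results/Metrics.csv and doubles as the curve label:

fig = plot_overview(['output/model1', 'output/model2'])
fig.savefig('precision_recall_comparison.pdf')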
conda/meta.yaml
@@ -6,6 +6,7 @@ package:
  version: {{ environ.get('BOB_PACKAGE_VERSION', '0.0.1') }}

build:
  skip: true  # [not linux]
  number: {{ environ.get('BOB_BUILD_NUMBER', 0) }}
  run_exports:
    - {{ pin_subpackage(name) }}
@@ -38,6 +39,7 @@ requirements:
    - {{ pin_compatible('numpy') }}
    - pandas
    - matplotlib
+    - tqdm
    - bob.db.drive
    - bob.db.stare
    - bob.db.chasedb1
@@ -56,6 +58,8 @@ test:
    # test commands ("script" entry-points) from your package here
    - bob binseg --help
    - bob binseg train --help
+    - bob binseg test --help
+    - bob binseg compare --help
    - nosetests --with-coverage --cover-package={{ name }} -sv {{ name }}
    - sphinx-build -aEW {{ project_dir }}/doc {{ project_dir }}/sphinx
    - sphinx-build -aEb doctest {{ project_dir }}/doc sphinx
File added
requirements.txt
@@ -13,4 +13,5 @@ bob.db.iostar
torch
torchvision
pandas
-matplotlib
\ No newline at end of file
+matplotlib
+tqdm
\ No newline at end of file
setup.py
@@ -48,6 +48,7 @@ setup(
        'bob.ip.binseg.cli': [
            'train = bob.ip.binseg.script.binseg:train',
            'test = bob.ip.binseg.script.binseg:test',
+            'compare = bob.ip.binseg.script.binseg:compare',
            'testcheckpoints = bob.ip.binseg.script.binseg:testcheckpoints',
        ],