Commit a9ff0696 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[annotators] nit

parent 64bc0eb0
......@@ -11,8 +11,10 @@ from bob.extension.scripts.click_helper import (
log_parameters,
)
from bob.pipelines import wrap, ToDaskBag, DelayedSample
logger = logging.getLogger(__name__)
def save_json(data, path):
"""
Saves a dictionnary ``data`` in a json file at ``path``.
......@@ -20,6 +22,7 @@ def save_json(data, path):
with open(path, "w") as f:
json.dump(data, f)
def load_json(path):
"""
Returns a dictionnary from a json file at ``path``.
......@@ -27,6 +30,7 @@ def load_json(path):
with open(path, "r") as f:
return json.load(f)
def annotate_common_options(func):
@click.option(
"--annotator",
......@@ -35,7 +39,7 @@ def annotate_common_options(func):
cls=ResourceOption,
entry_point_group="bob.bio.annotator",
help="An annotator (instance of class inheriting from "
"bob.bio.base.Annotator) or an annotator resource name.",
"bob.bio.base.Annotator) or an annotator resource name.",
)
@click.option(
"--output-dir",
......@@ -50,7 +54,7 @@ def annotate_common_options(func):
"dask_client",
entry_point_group="dask.client",
help="Dask client for the execution of the pipeline. If not specified, "
"uses a single threaded, local Dask Client.",
"uses a single threaded, local Dask Client.",
cls=ResourceOption,
)
@functools.wraps(func)
......@@ -76,7 +80,7 @@ Examples:
cls=ResourceOption,
entry_point_group="bob.bio.database",
help="Biometric Database (class that implements the methods: "
"`background_model_samples`, `references` and `probes`).",
"`background_model_samples`, `references` and `probes`).",
)
@click.option(
"--groups",
......@@ -88,9 +92,7 @@ Examples:
)
@annotate_common_options
@verbosity_option(cls=ResourceOption)
def annotate(
database, groups, annotator, output_dir, dask_client, **kwargs
):
def annotate(database, groups, annotator, output_dir, dask_client, **kwargs):
"""Annotates a database.
The annotations are written in text file (json) format which can be read
......@@ -103,8 +105,8 @@ def annotate(
# Will save the annotations in the `data` fields to a json file
annotator = wrap(
bases=["checkpoint"],
estimator=annotator,
["checkpoint"],
annotator,
features_dir=output_dir,
extension=".json",
save_func=save_json,
......@@ -118,7 +120,6 @@ def annotate(
# Transformer that splits the samples into several Dask Bags
to_dask_bags = ToDaskBag(npartitions=50)
logger.debug("Retrieving background model samples from database.")
background_model_samples = database.background_model_samples()
......@@ -131,22 +132,14 @@ def annotate(
# Unravels all samples in one list (no SampleSets)
samples = background_model_samples
samples.extend([
sample
for r in references_samplesets
for sample in r.samples
])
samples.extend([
sample
for p in probes_samplesets
for sample in p.samples
])
samples.extend([sample for r in references_samplesets for sample in r.samples])
samples.extend([sample for p in probes_samplesets for sample in p.samples])
# Sets the scheduler to local if no dask_client is specified
if dask_client is not None:
scheduler=dask_client
scheduler = dask_client
else:
scheduler="single-threaded"
scheduler = "single-threaded"
# Splits the samples list into bags
dask_bags = to_dask_bags.transform(samples)
......@@ -241,14 +234,14 @@ def annotate_samples(
to_dask_bags = ToDaskBag(npartitions=50)
if dask_client is not None:
scheduler=dask_client
scheduler = dask_client
else:
scheduler="single-threaded"
scheduler = "single-threaded"
# Converts samples into a list of DelayedSample objects
samples_obj = [
DelayedSample(
load=functools.partial(reader,s),
load=functools.partial(reader, s),
key=make_key(s),
)
for s in samples
......
......@@ -5,6 +5,7 @@ from click.testing import CliRunner
from bob.bio.base.script.annotate import annotate, annotate_samples
from bob.bio.base.annotator import Callable, FailSafe
from bob.db.base import read_annotation_file
from bob.extension.scripts.click_helper import assert_click_runner_result
def test_annotate():
......@@ -14,16 +15,10 @@ def test_annotate():
runner = CliRunner()
result = runner.invoke(annotate, args=(
'-d', 'dummy', '-g', 'dev', '-a', 'dummy', '-o', tmp_dir))
assertion_error_message = (
'Command exited with this output: `{}\' \n'
'If the output is empty, you can run this script locally to see '
'what is wrong:\n'
'bin/bob bio annotate -vvv -d dummy -g dev -a dummy -o /tmp/temp_annotations'
''.format(result.output))
assert result.exit_code == 0, assertion_error_message
assert_click_runner_result(result)
# test if annotations exist
for dirpath, dirnames, filenames in os.walk(tmp_dir):
for dirpath, _, filenames in os.walk(tmp_dir):
for filename in filenames:
path = os.path.join(dirpath, filename)
annot = read_annotation_file(path, 'json')
......@@ -40,13 +35,7 @@ def test_annotate_samples():
runner = CliRunner()
result = runner.invoke(annotate_samples, args=(
'dummy_samples', '-a', 'dummy', '-o', tmp_dir))
assertion_error_message = (
'Command exited with this output: `{}\' \n'
'If the output is empty, you can run this script locally to see '
'what is wrong:\n'
'bin/bob bio annotate-samples -vvv dummy_samples -a dummy -o /tmp/temp_annotations'
''.format(result.output))
assert result.exit_code == 0, assertion_error_message
assert_click_runner_result(result)
# test if annotations exist
for dirpath, dirnames, filenames in os.walk(tmp_dir):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment