Commit 42322c0a authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Merge branch 'dask-annotators' into 'master'

Dask annotators

See merge request !202
parents f58b7d75 854eb190
Pipeline #45326 failed with stages in 4 minutes and 23 seconds
from bob.bio.base import read_original_data as base_read
from sklearn.base import TransformerMixin, BaseEstimator


class Annotator(TransformerMixin, BaseEstimator):
    """Annotator class for all annotators. This class is meant to be used in
    conjunction with the bob bio annotate script or to be used in pipelines.

    Attributes
    ----------
    read_original_data : callable
        A function that loads the samples. The syntax is like
        `bob.bio.base.read_original_data`.
    """

    def __init__(self, read_original_data=None, **kwargs):
        super(Annotator, self).__init__(**kwargs)
        self.read_original_data = read_original_data or base_read

    def transform(self, samples, **kwargs):
        """Annotates samples and returns annotations in a dictionary.

        Parameters
        ----------
        samples : numpy.ndarray
            The samples that are being annotated.
        **kwargs
            The extra arguments that may be passed.
@@ -35,6 +25,6 @@ class Annotator(object):
"""
raise NotImplementedError
# Alias call to annotate
def __call__(self, sample, **kwargs):
return self.annotate(sample, **kwargs)
# Alias call to transform
def __call__(self, samples, **kwargs):
return self.transform(samples, **kwargs)
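For orientation, a concrete annotator only needs to override ``transform``. A minimal sketch, where the ``ConstantAnnotator`` class and its landmark values are made up for illustration:

import numpy as np

class ConstantAnnotator(Annotator):
    """Toy annotator returning a full-frame bounding box for every sample."""

    def transform(self, samples, **kwargs):
        # One annotation dictionary per input sample.
        return [
            {"topleft": (0, 0), "bottomright": sample.shape[:2]}
            for sample in samples
        ]

annotator = ConstantAnnotator()
annotations = annotator([np.zeros((112, 112))])  # __call__ aliases transform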
@@ -18,5 +18,5 @@ class Callable(Annotator):
        super(Callable, self).__init__(**kwargs)
        self.callable = callable

    def transform(self, sample, **kwargs):
        return self.callable(sample, **kwargs)
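``Callable`` simply delegates to a user-supplied function, so an existing batch function can be plugged in directly. A hedged sketch; ``estimate_quality`` is a hypothetical function, not part of the library:

def estimate_quality(samples, **kwargs):
    # Hypothetical metric: one annotation dictionary per sample.
    return [{"quality": float(s.mean())} for s in samples]

annotator = Callable(estimate_quality)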
@@ -34,31 +34,35 @@ class FailSafe(Annotator):
        self.required_keys = list(required_keys)
        self.only_required_keys = only_required_keys

    def transform(self, sample_batch, **kwargs):
        if 'annotations' not in kwargs or kwargs['annotations'] is None:
            kwargs['annotations'] = {}
        all_annotations = []
        for sample in sample_batch:
            annotations = kwargs['annotations'].copy()
            for annotator in self.annotators:
                try:
                    annot = annotator([sample], **kwargs)[0]
                except Exception:
                    logger.debug(
                        "The annotator `%s' failed to annotate!", annotator,
                        exc_info=True)
                    annot = None
                if not annot:
                    logger.debug(
                        "Annotator `%s' returned empty annotations.", annotator)
                else:
                    logger.debug("Annotator `%s' succeeded!", annotator)
                annotations.update(annot or {})
                # check if we have all the required annotations
                if all(key in annotations for key in self.required_keys):
                    break
            else:  # this else is for the for loop
                # we don't want to return half of the annotations
                annotations = None
            # guard against annotations being None when every annotator failed
            if annotations is not None and self.only_required_keys:
                for key in list(annotations.keys()):
                    if key not in self.required_keys:
                        del annotations[key]
            all_annotations.append(annotations)
        return all_annotations
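To see the fallback logic in action, here is a hedged usage sketch; the two annotator instances are hypothetical stand-ins for real detectors, and the constructor keywords are assumptions based on the attributes set above:

# Hypothetical instances; any two Annotator subclasses work here.
fail_safe = FailSafe(
    annotators=[precise_annotator, coarse_annotator],
    required_keys=["topleft", "bottomright"],
    only_required_keys=True,
)
# One annotation dict per sample; None where every annotator failed.
all_annotations = fail_safe.transform(samples)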
@@ -80,7 +80,7 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.FileDatabase)):
        original_directory=None,
        original_extension=None,
        annotation_directory=None,
        annotation_extension=None,
        annotation_type=None,
        protocol='Default',
        training_depends_on_protocol=False,
@@ -106,8 +106,8 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.FileDatabase)):
        self._kwargs = {}

        self.annotation_directory = annotation_directory
        self.annotation_extension = annotation_extension or ".json"
        self.annotation_type = annotation_type or "json"

        self.protocol = protocol
        self.training_depends_on_protocol = training_depends_on_protocol
        self.models_depend_on_protocol = models_depend_on_protocol
@@ -338,7 +338,6 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.FileDatabase)):
"""
raise NotImplementedError("This function must be implemented in your derived class.")
@abc.abstractmethod
def annotations(self, file):
"""
Returns the annotations for the given File object, if available.
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

import bob.db.base
import bob.io.base  # needed by BioFile.load below

from bob.db.base.annotations import read_annotation_file
from bob.pipelines.sample import _ReprMixin
class BioFile(bob.db.base.File, _ReprMixin):
    """A simple base class that defines basic properties of File object for the
    use in verification experiments.

    Attributes
    ----------
    client_id : str or int
        The id of the client this file belongs to.
        Its type depends on your implementation.
        If you use an SQL database, this should be an SQL type like Integer or
        String.
    path : object
        see :py:class:`bob.db.base.File` constructor
    file_id : object
        see :py:class:`bob.db.base.File` constructor
    original_directory : str or None
        The path to the original directory of the file.
    original_extension : str or None
        The extension of the original files. This attribute is deprecated.
        Please try to include the extension in the ``path`` attribute.
    annotation_directory : str or None
        The path to the directory of the annotations.
    annotation_extension : str or None
        The extension of annotation files. Default is ``.json``.
    annotation_type : str or None
        The type of the annotation file, see
        :any:`bob.db.base.annotations.read_annotation_file`. Default is
        ``json``.
    """
    def __init__(
        self,
        client_id,
        path,
        file_id=None,
        original_directory=None,
        original_extension=None,
        annotation_directory=None,
        annotation_extension=None,
        annotation_type=None,
        **kwargs,
    ):
        super(BioFile, self).__init__(path, file_id, **kwargs)

        # just copy the information
        self.client_id = client_id
        """The id of the client to which this file belongs."""
        self.original_directory = original_directory
        self.original_extension = original_extension
        self.annotation_directory = annotation_directory
        self.annotation_extension = annotation_extension or ".json"
        self.annotation_type = annotation_type or "json"
    def load(self, original_directory=None, original_extension=None):
        """Loads the data at the specified location and using the given extension.
        Override it if you need to load differently.

        Parameters
        ----------
        original_directory : str (optional)
            The path to the root of the dataset structure.
            If ``None``, will try to use ``self.original_directory``.
        original_extension : str (optional)
            The filename extension of every file in the dataset.
            If ``None``, will try to use ``self.original_extension``.

        Returns
        -------
        object
            The loaded data (normally :py:class:`numpy.ndarray`).
        """
        if original_directory is None:
            original_directory = self.original_directory
        if original_extension is None:
            original_extension = self.original_extension
        # get the path
        path = self.make_path(original_directory or "", original_extension or "")
        return bob.io.base.load(path)
    @property
    def annotations(self):
        path = self.make_path(
            self.annotation_directory or "", self.annotation_extension or ""
        )
        return read_annotation_file(path, annotation_type=self.annotation_type)
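Putting the pieces together, a ``BioFile`` can now lazily load its data and annotations. A hedged sketch with hypothetical directories and file names:

f = BioFile(
    client_id=1,
    path="subject1/sample1",
    original_directory="/datasets/mydb",       # hypothetical
    original_extension=".png",
    annotation_directory="/annotations/mydb",  # hypothetical
)
data = f.load()         # reads /datasets/mydb/subject1/sample1.png
annot = f.annotations   # reads /annotations/mydb/subject1/sample1.json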
class BioFileSet(BioFile):
@@ -41,11 +104,11 @@ class BioFileSet(BioFile):
    ----------
    file_set_id : str or int
        A unique ID that identifies the file set.
    files : [:py:class:`bob.bio.base.database.BioFile`]
        A non-empty list of BioFile objects that should be stored inside this file.
        All files of that list need to have the same client ID.
    """

    def __init__(self, file_set_id, files, path=None, **kwargs):
@@ -56,7 +119,9 @@ class BioFileSet(BioFile):
        super(BioFileSet, self).__init__(
            files[0].client_id,
            "+".join(f.path for f in files) if path is None else path,
            file_set_id,
            **kwargs,
        )

        # check that all files come from the same client
        assert all(f.client_id == self.client_id for f in files)
......
@@ -298,15 +298,16 @@ class ScoreWriter(metaclass=ABCMeta):
    @abstractmethod
    def write(self, sampleset, path):
        pass

    def post_process(self, score_paths, filename):
        def _post_process(score_paths, filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, "w") as f:
                for path in score_paths:
                    with open(path) as f2:
                        f.writelines(f2.readlines())
            return filename

        import dask.bag
        import dask

        if isinstance(score_paths, dask.bag.Bag):
......
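The diff is truncated here. A plausible continuation, offered only as an unverified sketch (not the actual merged code), would defer the concatenation when the paths arrive as a Dask bag and run the helper eagerly otherwise:

# Hedged sketch of the truncated branch:
if isinstance(score_paths, dask.bag.Bag):
    # Defer the concatenation until the dask graph is computed.
    return dask.delayed(_post_process)(score_paths, filename)
return _post_process(score_paths, filename)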
@@ -136,7 +136,7 @@ class DatabaseConnector(Database):
        return retval

    def probes(self, group="dev"):
        """Returns :py:class:`Probe` objects to score biometric references
......
@@ -201,14 +201,14 @@ def check_valid_pipeline(vanilla_pipeline):
    else:
        raise ValueError(
            "VanillaBiometricsPipeline.transformer should be instance of either "
            "`sklearn.pipeline.Pipeline` or `sklearn.base.BaseEstimator`, "
            f"not {vanilla_pipeline.transformer}"
        )

    # Checking the Biometric algorithm
    if not isinstance(vanilla_pipeline.biometric_algorithm, BioAlgorithm):
        raise ValueError(
            "VanillaBiometricsPipeline.biometric_algorithm should be instance of "
            f"`BioAlgorithm`, not {vanilla_pipeline.biometric_algorithm}"
        )

    return True
"""A script to help annotate databases.
"""
import logging
import os    # used by indices() below
import math  # used by indices() below
import json
import functools

import click
from os.path import dirname, isfile, expanduser
from bob.extension.scripts.click_helper import (
    verbosity_option,
    ConfigCommand,
    ResourceOption,
    log_parameters,
)
from bob.io.base import create_directories_safe
from bob.pipelines import wrap, ToDaskBag, DelayedSample
logger = logging.getLogger(__name__)
def save_json(data, path):
    """
    Saves a dictionary ``data`` in a json file at ``path``.
    """
    with open(path, "w") as f:
        json.dump(data, f)
def indices(list_to_split, number_of_parallel_jobs, task_id=None):
    """Returns the first and last index of the files for the current job ID.
    If no job id is set (e.g., because a sub-job is executed locally), it
    simply returns all indices."""
    if number_of_parallel_jobs is None or number_of_parallel_jobs == 1:
        return None

    # test if the 'SGE_TASK_ID' environment variable is set
    sge_task_id = os.getenv('SGE_TASK_ID') if task_id is None else task_id
    if sge_task_id is None:
        # task id is not set, so this function is not called from a grid job
        # hence, we process the whole list
        return (0, len(list_to_split))
    else:
        job_id = int(sge_task_id) - 1
        # compute the number of files to be processed per job
        number_of_objects_per_job = int(
            math.ceil(len(list_to_split) / float(number_of_parallel_jobs))
        )
        start = job_id * number_of_objects_per_job
        end = min((job_id + 1) * number_of_objects_per_job, len(list_to_split))
        return (start, end)
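For example, since SGE task ids count from 1, splitting 100 files over 4 array jobs gives each task a contiguous slice:

start, end = indices(list(range(100)), number_of_parallel_jobs=4, task_id=2)
assert (start, end) == (25, 50)  # task 2 handles items 25..49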
def load_json(path):
    """
    Returns a dictionary from a json file at ``path``.
    """
    with open(path, "r") as f:
        return json.load(f)
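These two helpers round-trip plain dictionaries; a quick sketch with a hypothetical path:

annotations = {"topleft": [0, 0], "bottomright": [112, 112]}
save_json(annotations, "/tmp/sample1.json")  # hypothetical path
assert load_json("/tmp/sample1.json") == annotations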
def annotate_common_options(func):
    @click.option(
@@ -45,8 +34,8 @@ def annotate_common_options(func):
        required=True,
        cls=ResourceOption,
        entry_point_group="bob.bio.annotator",
        help="A Transformer instance that takes a series of samples and returns "
        "the modified samples with annotations as a dictionary.",
    )
    @click.option(
        "--output-dir",
@@ -56,18 +45,13 @@ def annotate_common_options(func):
        help="The directory to save the annotations.",
    )
    @click.option(
        "--dask-client",
        "-l",
        "dask_client",
        entry_point_group="dask.client",
        help="Dask client for the execution of the pipeline. If not specified, "
        "uses a single threaded, local Dask Client.",
        cls=ResourceOption,
    )
    @functools.wraps(func)
    def wrapper(*args, **kwds):
@@ -83,7 +67,6 @@ def annotate_common_options(func):
Examples:

  $ bob bio annotate -vvv -d <database> -a <annotator> -o /tmp/annotations
""",
)
@click.option(
@@ -92,18 +75,21 @@
    required=True,
    cls=ResourceOption,
    entry_point_group="bob.bio.database",
    help="Biometric Database (class that implements the methods: "
    "`background_model_samples`, `references` and `probes`).",
)
@click.option(
    "--groups",
    "-g",
    multiple=True,
    default=["dev", "eval"],
    show_default=True,
    help="Biometric Database group that will be annotated.",
)
@annotate_common_options
@verbosity_option(cls=ResourceOption)
def annotate(
    database, groups, annotator, output_dir, dask_client, **kwargs
):
    """Annotates a database.
@@ -112,24 +98,66 @@ def annotate(
"""
log_parameters(logger)
# Some databases need their original_directory to be replaced
database.replace_directories(database_directories_file)
biofiles = database.objects(groups=None, protocol=database.protocol)
samples = sorted(biofiles)
def reader(biofile):
return annotator.read_original_data(
biofile, database.original_directory, database.original_extension
)
def make_path(biofile, output_dir):
return biofile.make_path(output_dir, ".json")
return annotate_generic(
samples, reader, make_path, annotator, output_dir, force, array
# Allows passing of Sample objects as parameters
annotator = wrap(["sample"], annotator, output_attribute="annotations")
# Will save the annotations in the `data` fields to a json file
annotator = wrap(
bases=["checkpoint"],
estimator=annotator,
features_dir=output_dir,
extension=".json",
save_func=save_json,
load_func=load_json,
sample_attribute="annotations",
)
# Allows reception of Dask Bags
annotator = wrap(["dask"], annotator)
# Transformer that splits the samples into several Dask Bags
to_dask_bags = ToDaskBag(npartitions=50)
logger.debug("Retrieving background model samples from database.")
background_model_samples = database.background_model_samples()
logger.debug("Retrieving references and probes samples from database.")
references_samplesets = []
probes_samplesets = []
for group in groups:
references_samplesets.extend(database.references(group=group))
probes_samplesets.extend(database.probes(group=group))
# Unravels all samples in one list (no SampleSets)
samples = background_model_samples
samples.extend([
sample
for r in references_samplesets
for sample in r.samples
])
samples.extend([
sample
for p in probes_samplesets
for sample in p.samples
])
    # Sets the scheduler to local if no dask_client is specified
    if dask_client is not None:
        scheduler = dask_client
    else:
        scheduler = "single-threaded"

    logger.info(f"Saving annotations in {output_dir}.")
    logger.info(f"Annotating {len(samples)} samples...")
    dask_bags = to_dask_bags.transform(samples)
    annotator.transform(dask_bags).compute(scheduler=scheduler)

    if dask_client is not None:
        logger.info("Shutdown workers...")
        dask_client.shutdown()
    logger.info("Done.")
@click.command(
entry_point_group="bob.bio.config",
@@ -138,94 +166,102 @@ def annotate(
Examples:

  $ bob bio annotate-samples -vvv config.py -a <annotator> -o /tmp/annotations

You have to define ``samples``, ``reader``, and ``make_key`` in python files
(config.py) as in examples.
""",
)
@click.option(
    "--samples",
    entry_point_group="bob.bio.config",
    required=True,
    cls=ResourceOption,
    help="A list of all samples that you want to annotate. They will be passed "
    "as is to the ``reader`` and ``make-key`` functions.",
)
@click.option(
    "--reader",
    required=True,
    cls=ResourceOption,
    help="A function with the signature of ``data = reader(sample)`` which "
    "takes a sample and returns the loaded data. The returned data is given to "
    "the annotator.",
)
@click.option(
    "--make-key",
    required=True,
    cls=ResourceOption,
    help="A function with the signature of ``key = make_key(sample)`` which "
    "takes a sample and returns a unique str identifier for that sample that "
    "will be used to save it in output_dir. ``key`` generally is the relative "
    "path to a sample's file from the dataset's root directory.",
)
@annotate_common_options
@verbosity_option(cls=ResourceOption)
def annotate_samples(
    samples, reader, make_key, annotator, output_dir, dask_client, **kwargs
):
"""Annotates a list of samples.
This command is very similar to ``bob bio annotate`` except that it works without a
database interface. You only need to provide a list of **sorted** samples to be
annotated and two functions::
This command is very similar to ``bob bio annotate`` except that it works
without a database interface. You must provide a list of samples as well as
two functions:
        def reader(sample):
            # Loads data from a sample.
            # For example:
            data = bob.io.base.load(sample)
            # data will be given to the annotator
            return data