Commit 7c54afe1 authored by Yannick DAYER's avatar Yannick DAYER

Revert passing Sample objects to Transformers

parent 0b4640c7
Pipeline #44788 passed in 81 minutes and 47 seconds
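In short, this revert switches the annotator API back from Sample-based to data-based: instead of receiving `Sample` objects and writing into `sample.annotations`, annotators again receive batches of raw data and return one annotation dictionary per input. A minimal sketch of the two styles, inferred from the diffs below (illustrative, not verbatim library code):

from bob.bio.base.annotator import Annotator

# Pre-revert style (removed by this commit): annotators mutate Sample objects.
class SampleBasedAnnotator(Annotator):
    def transform(self, samples, **kwargs):
        for sample in samples:
            sample.annotations = {"topleft": (0, 0), "bottomright": sample.data.shape}
        return samples

# Post-revert style (restored by this commit): annotators work on raw data batches.
class DataBasedAnnotator(Annotator):
    def transform(self, sample_batch, **kwargs):
        return [{"topleft": (0, 0), "bottomright": data.shape} for data in sample_batch]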
@@ -18,5 +18,5 @@ class Callable(Annotator):
super(Callable, self).__init__(**kwargs)
self.callable = callable
def annotate(self, sample, **kwargs):
def transform(self, sample, **kwargs):
return self.callable(sample, **kwargs)
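For context, `Callable` simply forwards its input to the wrapped function, so any plain function mapping a batch of data to a list of annotation dictionaries can serve as an annotator. A minimal, hypothetical usage sketch (the `detect_eyes` function is illustrative only, not part of the library):

import numpy
from bob.bio.base.annotator import Callable

def detect_eyes(image_batch, **kwargs):
    # Hypothetical detector: returns one annotation dict per image in the batch.
    return [{"leye": (10, 20), "reye": (10, 40)} for _ in image_batch]

annotator = Callable(detect_eyes)
annotations = annotator([numpy.zeros((32, 32)), numpy.zeros((32, 32))])
# annotations is a list of two annotation dicts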
@@ -34,33 +34,35 @@ class FailSafe(Annotator):
self.required_keys = list(required_keys)
self.only_required_keys = only_required_keys
def transform(self, samples, **kwargs):
for sample in samples:
if 'annotations' not in kwargs or kwargs['annotations'] is None:
kwargs['annotations'] = {}
def transform(self, sample_batch, **kwargs):
if 'annotations' not in kwargs or kwargs['annotations'] is None:
kwargs['annotations'] = {}
all_annotations = []
for sample in sample_batch:
annotations = kwargs['annotations'].copy()
for annotator in self.annotators:
try:
sample = annotator([sample], **kwargs)[0]
annot = annotator([sample], **kwargs)[0]
except Exception:
logger.debug(
"The annotator `%s' failed to annotate!", annotator,
exc_info=True)
sample.annotations = None
if not sample.annotations:
annot = None
if not annot:
logger.debug(
"Annotator `%s' returned empty annotations.", annotator)
else:
logger.debug("Annotator `%s' succeeded!", annotator)
kwargs['annotations'].update(sample.annotations or {})
annotations.update(annot or {})
# check if we have all the required annotations
if all(key in kwargs['annotations'] for key in self.required_keys):
if all(key in annotations for key in self.required_keys):
break
else: # this else is for the for loop
# we don't want to return half of the annotations
kwargs['annotations'] = None
annotations = None
if self.only_required_keys:
for key in list(kwargs['annotations'].keys()):
for key in list(annotations.keys()):
if key not in self.required_keys:
del kwargs['annotations'][key]
sample.annotations = kwargs['annotations']
return samples
del annotations[key]
all_annotations.append(annotations)
return all_annotations
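A note on the `for ... else` construct used above: the `else` branch runs only when the loop finishes without hitting `break`, which is how `FailSafe` detects that no annotator produced all the required keys (and therefore resets the annotations to `None`). A small standalone illustration:

for value in [1, 2, 3]:
    if value == 2:
        break
else:
    print("loop ended without break")  # skipped: we broke out at 2

for value in [1, 3, 5]:
    if value == 2:
        break
else:
    print("loop ended without break")  # printed: no element matched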
from bob.pipelines import CheckpointWrapper, SampleSet
from bob.pipelines.wrappers import _frmt
from os.path import dirname, isfile, expanduser, join
from os import makedirs
import logging
import json
logger = logging.getLogger(__name__)
class SaveAnnotationsWrapper(CheckpointWrapper):
"""
A specialization of bob.pipelines.CheckpointWrapper that saves annotations.
Saves :py:attr:`~bob.pipelines.Sample.annotations` to the disk instead of
:py:attr:`~bob.pipelines.Sample.data` (default in
:py:class:`~bob.pipelines.CheckpointWrapper`).
The annotations of each sample are "dumped" with JSON (by default) into a file
mirroring the one in the original dataset (i.e. following the same path
structure, using the :py:attr:`~bob.pipelines.Sample.key` attribute of each
sample).
Parameters
----------
estimator: Annotator Transformer
Transformer that places each sample's annotations in
:py:attr:`~bob.pipelines.Sample.annotations`.
annotations_dir: str
The root path where the annotations will be saved.
extension: str
The extension of the annotations files [default: ``.json``].
save_func: function
The function used to save each sample [default: :py:func:`json.dump`].
overwrite: bool
When ``True``, existing annotation files are overwritten. Otherwise, samples
whose annotation file (same ``key``) already exists are skipped.
"""
def __init__(
self,
estimator,
annotations_dir,
extension=".json",
save_func=None,
overwrite=False,
**kwargs,
):
save_func = save_func or self._save_json
super(SaveAnnotationsWrapper, self).__init__(
estimator=estimator,
features_dir=annotations_dir,
extension=extension,
save_func=save_func,
**kwargs,
)
self.overwrite = overwrite
def save(self, sample):
"""
Saves one sample's annotations to a file on disk.
Overrides :py:meth:`bob.pipelines.CheckpointWrapper.save`
Parameters
----------
sample: :py:class:`~bob.pipelines.Sample`
One sample containing an :py:attr:`~bob.pipelines.Sample.annotations`
attribute.
"""
path = self.make_path(sample)
makedirs(dirname(path), exist_ok=True)
try:
self.save_func(sample.annotations, path)
except Exception as e:
raise RuntimeError(
f"Could not save annotations of {sample}\n"
f"(annotations are: {sample.annotations})\n"
f"during {self}.save"
) from e
def _checkpoint_transform(self, samples, method_name):
"""
Checks if a transform needs to be saved to the disk.
Overrides :py:meth:`bob.pipelines.CheckpointWrapper._checkpoint_transform`
"""
# Transform either samples or samplesets
method = getattr(self.estimator, method_name)
logger.debug(f"{_frmt(self)}.{method_name}")
# if features_dir is None, just transform all samples at once
if self.features_dir is None:
return method(samples)
def _transform_samples(samples):
paths = [self.make_path(s) for s in samples]
should_compute_list = [
p is None or not isfile(p) or self.overwrite
for p in paths
]
skipped_count = should_compute_list.count(False)
if skipped_count != 0:
logger.info(f"Skipping {skipped_count} already existing files.")
# call method on non-checkpointed samples
non_existing_samples = [
s
for s, should_compute in zip(samples, should_compute_list)
if should_compute
]
# non_existing_samples could be empty
computed_features = []
if non_existing_samples:
computed_features = method(non_existing_samples)
# return computed features and checkpointed features
features, com_feat_index = [], 0
for s, p, should_compute in zip(samples, paths, should_compute_list):
if should_compute:
feat = computed_features[com_feat_index]
com_feat_index += 1
# save the computed feature
if p is not None:
self.save(feat)
feat = self.load(s, p)
features.append(feat)
else:
features.append(self.load(s, p))
return features
if isinstance(samples[0], SampleSet):
return [SampleSet(_transform_samples(s.samples), parent=s) for s in samples]
else:
return _transform_samples(samples)
def _save_json(self, annot, path):
"""
Saves the annotations in json format in the file ``path``.
This is the default ``save_func``, used when none is passed to
:py:class:`~bob.bio.base.annotator.SaveAnnotationsWrapper`.
Parameters
----------
annot: dict
Any dictionary (containing annotations for example).
path: str
A path to a file inside an existing directory.
"""
logger.debug(f"Writing annotations '{annot}' to file '{path}'.")
with open(path, "w") as f:
json.dump(annot, f, indent=1, allow_nan=False)
\ No newline at end of file
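For reference, a hypothetical sketch of how `SaveAnnotationsWrapper` (removed by this commit in favour of the generic checkpoint wrapping shown further down) might be used with a Sample-based annotator; the annotator class, paths and keys are placeholders, not library fixtures:

import numpy
from bob.pipelines import Sample
from bob.bio.base.annotator import SaveAnnotationsWrapper

# `SimpleAnnotator` stands for any Sample-based Annotator that fills
# `sample.annotations` (see the pre-revert class in the documentation
# example further below).
annotator = SaveAnnotationsWrapper(
    estimator=SimpleAnnotator(),
    annotations_dir="/tmp/annotations",
    overwrite=True,
)
samples = [Sample(data=numpy.zeros((32, 32)), key="subject1/image1")]
annotated = annotator.transform(samples)
# The annotations end up in /tmp/annotations/subject1/image1.json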
from .Annotator import Annotator
from .FailSafe import FailSafe
from .Callable import Callable
from .SaveAnnotationsWrapper import SaveAnnotationsWrapper
# gets sphinx autodoc done right - don't remove it
@@ -27,7 +26,6 @@ __appropriate__(
Annotator,
FailSafe,
Callable,
SaveAnnotationsWrapper,
)
__all__ = [_ for _ in dir() if not _.startswith('_')]
@@ -2,6 +2,7 @@
"""
import logging
import click
import json
import functools
from bob.extension.scripts.click_helper import (
verbosity_option,
@@ -10,9 +11,11 @@ from bob.extension.scripts.click_helper import (
log_parameters,
)
from bob.pipelines import wrap, ToDaskBag
from bob.bio.base.annotator import SaveAnnotationsWrapper
logger = logging.getLogger(__name__)
def save_annotations_to_json(data, path):
with open(path, "w") as f:
json.dump(data, f)
def annotate_common_options(func):
@click.option(
@@ -92,11 +95,16 @@ def annotate(
"""
log_parameters(logger)
# Wrapping that will save each sample at {output_dir}/{sample.key}.json
annotator = SaveAnnotationsWrapper(
annotator,
annotations_dir=output_dir,
overwrite=force,
# Allows passing of Sample objects as parameters
annotator = wrap(["sample"], annotator)
# Will save the annotations in the `data` fields to a json file
annotator = wrap(
bases=["checkpoint"],
estimator=annotator,
features_dir=output_dir,
save_func=save_annotations_to_json,
extension=".json",
)
# Allows reception of Dask Bags
......
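For context, `wrap(["sample"], annotator)` from bob.pipelines turns a transformer that operates on raw data into one that accepts `Sample` objects; conceptually it does something like the simplified sketch below (the real `SampleWrapper` also handles extra arguments and delayed samples, so this is only an approximation):

from bob.pipelines import Sample

def sample_wrapped_transform(samples):
    # Extract the raw data, run the data-based annotator, and wrap the
    # results back into Sample objects that keep a reference to their parent.
    raw_results = annotator.transform([s.data for s in samples])
    return [Sample(result, parent=s) for result, s in zip(raw_results, samples)]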
from random import random
from bob.bio.base.annotator import FailSafe, Annotator
from bob.bio.base.annotator import FailSafe, Callable
class SimpleAnnotator(Annotator):
def transform(self, samples, **kwargs):
for sample in samples:
sample.annotations = {
def simple_annotator(image_batch, **kwargs):
all_annotations = []
for image in image_batch:
all_annotations.append(
{
'topleft': (0, 0),
'bottomright': sample.data.shape,
'bottomright': image.shape,
}
return samples
)
return all_annotations
class MoodyAnnotator(Annotator):
def transform(self, samples, **kwargs):
for sample in samples:
sample.annotations = {'topleft': (0,0)}
if random() > 0.5:
sample.annotations['bottomright'] = sample.data.shape
return samples
def moody_annotator(image_batch, **kwargs):
all_annotations = simple_annotator(image_batch, **kwargs)
for annot in all_annotations:
if random() < 0.5:
del annot['bottomright']
return all_annotations
class FailAnnotator(Annotator):
def transform(self, samples, **kwargs):
return {}
def fail_annotator(image_batch, **kwargs):
all_annotations = []
for image in image_batch:
all_annotations.append({})
return all_annotations
annotator = FailSafe(
[FailAnnotator(),
SimpleAnnotator()],
[Callable(fail_annotator),
Callable(simple_annotator)],
required_keys=['topleft', 'bottomright'],
)
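With the function-based annotators above, calling the `FailSafe` instance on a batch of raw images returns one annotation dictionary per image, for example:

import numpy

images = [numpy.zeros((64, 64)), numpy.zeros((128, 128))]
annotations = annotator(images)
# e.g. [{'topleft': (0, 0), 'bottomright': (64, 64)},
#       {'topleft': (0, 0), 'bottomright': (128, 128)}]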
@@ -3,9 +3,8 @@ import os
import shutil
from click.testing import CliRunner
from bob.bio.base.script.annotate import annotate
from bob.bio.base.annotator import Annotator, FailSafe
from bob.bio.base.annotator import Callable, FailSafe
from bob.db.base import read_annotation_file
from bob.pipelines import Sample
def test_annotate():
@@ -14,12 +13,12 @@ def test_annotate():
tmp_dir = tempfile.mkdtemp(prefix="bobtest_")
runner = CliRunner()
result = runner.invoke(annotate, args=(
'-d', 'dummy', '-a', 'dummy', '-g', 'dev', '-o', tmp_dir))
'-d', 'dummy', '-g', 'dev', '-a', 'dummy', '-o', tmp_dir))
assertion_error_message = (
'Command exited with this output: `{}\' \n'
'If the output is empty, you can run this script locally to see '
'what is wrong:\n'
'bin/bob bio annotate -vvv --force -d dummy -a dummy -o /tmp/temp_annotations'
'bin/bob bio annotate -vvv --force -d dummy -g dev -a dummy -o /tmp/temp_annotations'
''.format(result.output))
assert result.exit_code == 0, assertion_error_message
@@ -35,23 +34,18 @@ def test_annotate():
shutil.rmtree(tmp_dir)
class dummy_extra_key_annotator(Annotator):
def transform(self, samples, **kwargs):
for s in samples:
s.annotations = {'leye': 0, 'reye': 0, 'topleft': 0}
return samples
def dummy_extra_key_annotator(data_batch, **kwargs):
return [{'leye': 0, 'reye': 0, 'topleft': 0}]
def test_failsafe():
annotator = FailSafe([dummy_extra_key_annotator()], ['leye', 'reye'])
samples = [Sample(data=1, key="dummy_sample")]
annotated = annotator(samples)[0]
assert all(key in annotated.annotations for key in ['leye', 'reye', 'topleft'])
annotator = FailSafe([dummy_extra_key_annotator()], ['leye', 'reye'], True)
samples = [Sample(data=1, key="dummy_sample")]
annotated = annotator(samples)[0]
assert all(key in annotated.annotations for key in ['leye', 'reye'])
samples = [Sample(data=1, key="dummy_sample")]
annotated = annotator(samples)[0]
assert all(key not in annotated.annotations for key in ['topleft'])
annotator = FailSafe([Callable(dummy_extra_key_annotator)],
['leye', 'reye'])
annotations = annotator([1])
assert all(key in annotations[0] for key in ['leye', 'reye', 'topleft'])
annotator = FailSafe([Callable(dummy_extra_key_annotator)],
['leye', 'reye'], True)
annotations = annotator([1])
assert all(key in annotations[0] for key in ['leye', 'reye'])
assert all(key not in annotations[0] for key in ['topleft'])