Skip to content
Snippets Groups Projects
Commit 570ea2a9 authored by Amir Mohammadi's avatar Amir Mohammadi
Browse files

makes the annotations method implementation mandatory

parent edae49e1
No related branches found
No related tags found
1 merge request!56makes the annotations method implementation mandatory
Pipeline #
......@@ -151,6 +151,55 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
return "%s(%s)" % (str(self.__class__), params)
def replace_directories(self, replacements=None):
"""This helper function replaces the ``original_directory`` and the ``annotation_directory`` of the database with the directories read from the given replacement file.
This function is provided for convenience, so that the database configuration files do not need to be modified.
Instead, this function uses the given dictionary of replacements to change the original directory and the original extension (if given).
The given ``replacements`` can be of type ``dict``, including all replacements, or a file name (as a ``str``), in which case the file is read.
The structure of the file should be:
.. code-block:: text
# Comments starting with # and empty lines are ignored
[YOUR_..._DATA_DIRECTORY] = /path/to/your/data
[YOUR_..._ANNOTATION_DIRECTORY] = /path/to/your/annotations
If no annotation files are available (e.g. when they are stored inside the ``database``), the annotation directory can be left out.
**Parameters:**
replacements : dict or str
A dictionary with replacements, or a name of a file to read the dictionary from.
If the file name does not exist, no directories are replaced.
"""
if replacements is None:
return
if isinstance(replacements, str):
if not os.path.exists(replacements):
return
# Open the database replacement file and reads its content
with open(replacements) as f:
replacements = {}
for line in f:
if line.strip() and not line.startswith("#"):
splits = line.split("=")
assert len(splits) == 2
replacements[splits[0].strip()] = splits[1].strip()
assert isinstance(replacements, dict)
if self.original_directory in replacements:
self.original_directory = replacements[self.original_directory]
try:
if self.annotation_directory in replacements:
self.annotation_directory = replacements[self.annotation_directory]
except AttributeError:
pass
###########################################################################
# Helper functions that you might want to use in derived classes
###########################################################################
......@@ -188,31 +237,6 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
files_by_clients.append(client_files[client])
return files_by_clients
def annotations(self, file):
"""
Returns the annotations for the given File object, if available.
It uses `bob.db.base.read_annotation_file` to load the annotations.
**Parameters:**
file : :py:class:`bob.bio.base.database.BioFile`
The file for which annotations should be returned.
**Returns:**
annots : dict or None
The annotations for the file, if available.
"""
if self.annotation_directory:
try:
from bob.db.base.annotations import read_annotation_file
annotation_path = os.path.join(self.annotation_directory, file.path + self.annotation_extension)
return read_annotation_file(annotation_path, self.annotation_type)
except ImportError as e:
raise NotImplementedError(str(e) + " Annotations are not read." % e)
return None
def file_names(self, files, directory, extension):
"""file_names(files, directory, extension) -> paths
......@@ -266,23 +290,6 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
"""
raise NotImplementedError("Please implement this function in derived classes")
def model_ids(self, groups='dev'):
"""model_ids(group = 'dev') -> ids
Returns a list of model ids for the given group, respecting the current protocol.
**Parameters:**
group : one of ``('dev', 'eval')``
The group to get the model ids for.
**Returns:**
ids : [int] or [str]
The list of (unique) model ids for models of the given group.
"""
return sorted(self.model_ids_with_protocol(groups=groups, protocol=self.protocol))
@abc.abstractmethod
def objects(self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs):
"""This function returns lists of File objects, which fulfill the given restrictions.
......@@ -311,10 +318,45 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
"""
raise NotImplementedError("This function must be implemented in your derived class.")
@abc.abstractmethod
def annotations(self, file):
"""
Returns the annotations for the given File object, if available.
It uses `bob.db.base.read_annotation_file` to load the annotations.
**Parameters:**
file : :py:class:`bob.bio.base.database.BioFile`
The file for which annotations should be returned.
**Returns:**
annots : dict or None
The annotations for the file, if available.
"""
raise NotImplementedError("This function must be implemented in your derived class.")
#################################################################
######### Methods to provide common functionality ###############
#################################################################
def model_ids(self, groups='dev'):
"""model_ids(group = 'dev') -> ids
Returns a list of model ids for the given group, respecting the current protocol.
**Parameters:**
group : one of ``('dev', 'eval')``
The group to get the model ids for.
**Returns:**
ids : [int] or [str]
The list of (unique) model ids for models of the given group.
"""
return sorted(self.model_ids_with_protocol(groups=groups, protocol=self.protocol))
def all_files(self, groups=None):
"""all_files(groups=None) -> files
......
......@@ -16,16 +16,16 @@ class DummyDatabase(ZTBioDatabase):
models_depend_on_protocol=False
)
import bob.db.atnt
self.__db = bob.db.atnt.Database()
self._db = bob.db.atnt.Database()
def _make_bio(self, files):
return [BioFile(client_id=f.client_id, path=f.path, file_id=f.id) for f in files]
def model_ids_with_protocol(self, groups=None, protocol=None, **kwargs):
return self.__db.model_ids(groups, protocol)
return self._db.model_ids(groups, protocol)
def objects(self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs):
return self._make_bio(self.__db.objects(model_ids, groups, purposes, protocol, **kwargs))
return self._make_bio(self._db.objects(model_ids, groups, purposes, protocol, **kwargs))
def tobjects(self, groups=None, protocol=None, model_ids=None, **kwargs):
return []
......@@ -34,7 +34,7 @@ class DummyDatabase(ZTBioDatabase):
return []
def tmodel_ids_with_protocol(self, protocol=None, groups=None, **kwargs):
return self.__db.model_ids(groups)
return self._db.model_ids(groups)
def t_enroll_files(self, t_model_id, group='dev'):
return self.enroll_files(t_model_id, group)
......@@ -42,4 +42,8 @@ class DummyDatabase(ZTBioDatabase):
def z_probe_files(self, group='dev'):
return self.probe_files(None, group)
def annotations(self, file):
return None
database = DummyDatabase()
......@@ -15,7 +15,7 @@ class DummyDatabase(ZTBioDatabase):
models_depend_on_protocol=False
)
import bob.db.atnt
self.__db = bob.db.atnt.Database()
self._db = bob.db.atnt.Database()
def uses_probe_file_sets(self):
return True
......@@ -31,10 +31,10 @@ class DummyDatabase(ZTBioDatabase):
return file_sets
def model_ids_with_protocol(self, groups=None, protocol=None, **kwargs):
return self.__db.model_ids(groups, protocol)
return self._db.model_ids(groups, protocol)
def objects(self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs):
return self._make_bio(self.__db.objects(model_ids, groups, purposes, protocol, **kwargs))
return self._make_bio(self._db.objects(model_ids, groups, purposes, protocol, **kwargs))
def tobjects(self, groups=None, protocol=None, model_ids=None, **kwargs):
return []
......@@ -43,7 +43,7 @@ class DummyDatabase(ZTBioDatabase):
return []
def tmodel_ids_with_protocol(self, protocol=None, groups=None, **kwargs):
return self.__db.model_ids(groups)
return self._db.model_ids(groups)
def t_enroll_files(self, t_model_id, group='dev'):
return self.enroll_files(t_model_id, group)
......@@ -54,4 +54,8 @@ class DummyDatabase(ZTBioDatabase):
def z_probe_file_sets(self, group='dev'):
return self.probe_file_sets(None, group)
def annotations(self, file):
return None
database = DummyDatabase()
from bob.bio.base.preprocessor import Preprocessor
class DummyPreprocessor (Preprocessor):
def __init__(self, return_none=False, **kwargs):
Preprocessor.__init__(self)
......@@ -11,4 +12,5 @@ class DummyPreprocessor (Preprocessor):
return None
return data
preprocessor = DummyPreprocessor()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment