diff --git a/bob/bio/base/database/database.py b/bob/bio/base/database/database.py
index 4fe518172266864987031a87d092c895b2593fa8..7620ef507ed1008c7224af6f04074883106d933f 100644
--- a/bob/bio/base/database/database.py
+++ b/bob/bio/base/database/database.py
@@ -151,6 +151,71 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):

         return "%s(%s)" % (str(self.__class__), params)

+    def replace_directories(self, replacements=None):
+        """This helper function replaces the ``original_directory`` and the ``annotation_directory`` of the database with the directories read from the given replacement file.
+
+        This function is provided for convenience, so that the database configuration files do not need to be modified.
+        Instead, this function uses the given dictionary of replacements to change the original directory and the annotation directory (if given).
+
+        The given ``replacements`` can be of type ``dict``, including all replacements, or a file name (as a ``str``), in which case the file is read.
+        The structure of the file should be:
+
+        .. code-block:: text
+
+           # Comments starting with # and empty lines are ignored
+
+           [YOUR_..._DATA_DIRECTORY] = /path/to/your/data
+           [YOUR_..._ANNOTATION_DIRECTORY] = /path/to/your/annotations
+
+        If no annotation files are available (e.g. when they are stored inside the ``database``), the annotation directory can be left out.
+
+        **Parameters:**
+
+        replacements : dict or str
+          A dictionary with replacements, or a name of a file to read the dictionary from.
+          If the file name does not exist, no directories are replaced.
+          If ``None`` (the default), nothing is changed.
+
+        **Raises:**
+
+        ValueError
+          If a line of the replacement file is neither empty, nor a comment, nor of the form ``key = value``.
+        TypeError
+          If ``replacements`` is neither ``None``, nor a ``dict``, nor a file name.
+        """
+        if replacements is None:
+            return
+        # six.string_types also accepts ``unicode`` file names under Python 2
+        if isinstance(replacements, six.string_types):
+            if not os.path.exists(replacements):
+                return
+            # Read the replacement file; each entry has the form ``key = value``
+            with open(replacements) as f:
+                replacements = {}
+                for line in f:
+                    line = line.strip()
+                    # skip empty lines and comments (also indented ones)
+                    if not line or line.startswith("#"):
+                        continue
+                    # split at the first "=" only, so that values may contain "="
+                    key, sep, value = line.partition("=")
+                    if not sep:
+                        raise ValueError("Cannot parse line %r of the replacement file; expected 'key = value'" % line)
+                    replacements[key.strip()] = value.strip()
+
+        if not isinstance(replacements, dict):
+            raise TypeError("replacements must be a dict or a file name, not %s" % type(replacements).__name__)
+
+        if self.original_directory in replacements:
+            self.original_directory = replacements[self.original_directory]
+
+        # tolerate a missing ``annotation_directory`` attribute
+        try:
+            if self.annotation_directory in replacements:
+                self.annotation_directory = replacements[self.annotation_directory]
+        except AttributeError:
+            pass
+
     ###########################################################################
     # Helper functions that you might want to use in derived classes
     ###########################################################################
@@ -188,31 +253,6 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
             files_by_clients.append(client_files[client])
         return files_by_clients

-    def annotations(self, file):
-        """
-        Returns the annotations for the given File object, if available.
-        It uses `bob.db.base.read_annotation_file` to load the annotations.
-
-        **Parameters:**
-
-        file : :py:class:`bob.bio.base.database.BioFile`
-          The file for which annotations should be returned.
-
-        **Returns:**
-
-        annots : dict or None
-          The annotations for the file, if available.
-        """
-        if self.annotation_directory:
-            try:
-                from bob.db.base.annotations import read_annotation_file
-                annotation_path = os.path.join(self.annotation_directory, file.path + self.annotation_extension)
-                return read_annotation_file(annotation_path, self.annotation_type)
-            except ImportError as e:
-                raise NotImplementedError(str(e) + " Annotations are not read." % e)
-
-        return None
-
     def file_names(self, files, directory, extension):
         """file_names(files, directory, extension) -> paths

@@ -266,23 +306,6 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
         """
         raise NotImplementedError("Please implement this function in derived classes")

-    def model_ids(self, groups='dev'):
-        """model_ids(group = 'dev') -> ids
-
-        Returns a list of model ids for the given group, respecting the current protocol.
-
-        **Parameters:**
-
-        group : one of ``('dev', 'eval')``
-          The group to get the model ids for.
-
-        **Returns:**
-
-        ids : [int] or [str]
-          The list of (unique) model ids for models of the given group.
-        """
-        return sorted(self.model_ids_with_protocol(groups=groups, protocol=self.protocol))
-
     @abc.abstractmethod
     def objects(self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs):
         """This function returns lists of File objects, which fulfill the given restrictions.
@@ -311,10 +334,45 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
         """
         raise NotImplementedError("This function must be implemented in your derived class.")

+    @abc.abstractmethod
+    def annotations(self, file):
+        """
+        Returns the annotations for the given File object, if available.
+        Derived classes must implement this function; the base class provides no default.
+
+        **Parameters:**
+
+        file : :py:class:`bob.bio.base.database.BioFile`
+          The file for which annotations should be returned.
+
+        **Returns:**
+
+        annots : dict or None
+          The annotations for the file, if available.
+        """
+        raise NotImplementedError("This function must be implemented in your derived class.")
+
     #################################################################
     ######### Methods to provide common functionality ###############
     #################################################################

+    def model_ids(self, groups='dev'):
+        """model_ids(groups = 'dev') -> ids
+
+        Returns a list of model ids for the given groups, respecting the current protocol.
+
+        **Parameters:**
+
+        groups : one of ``('dev', 'eval')``
+          The groups to get the model ids for.
+
+        **Returns:**
+
+        ids : [int] or [str]
+          The sorted list of model ids for models of the given groups.
+        """
+        return sorted(self.model_ids_with_protocol(groups=groups, protocol=self.protocol))
+
     def all_files(self, groups=None):
         """all_files(groups=None) -> files

diff --git a/bob/bio/base/test/dummy/database.py b/bob/bio/base/test/dummy/database.py
index e8cfac92044055f84a343cae0a226b92a9fb94f8..9c84e23bdbb79f0caf9bd26d15cd8060419fb665 100644
--- a/bob/bio/base/test/dummy/database.py
+++ b/bob/bio/base/test/dummy/database.py
@@ -16,16 +16,16 @@ class DummyDatabase(ZTBioDatabase):
             models_depend_on_protocol=False
         )
         import bob.db.atnt
-        self.__db = bob.db.atnt.Database()
+        self._db = bob.db.atnt.Database()

     def _make_bio(self, files):
         return [BioFile(client_id=f.client_id, path=f.path, file_id=f.id) for f in files]

     def model_ids_with_protocol(self, groups=None, protocol=None, **kwargs):
-        return self.__db.model_ids(groups, protocol)
+        return self._db.model_ids(groups, protocol)

     def objects(self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs):
-        return self._make_bio(self.__db.objects(model_ids, groups, purposes, protocol, **kwargs))
+        return self._make_bio(self._db.objects(model_ids, groups, purposes, protocol, **kwargs))

     def tobjects(self, groups=None, protocol=None, model_ids=None, **kwargs):
         return []
@@ -34,7 +34,7 @@ class DummyDatabase(ZTBioDatabase):
         return []

     def tmodel_ids_with_protocol(self, protocol=None, groups=None, **kwargs):
-        return self.__db.model_ids(groups)
+        return self._db.model_ids(groups)

     def t_enroll_files(self, t_model_id, group='dev'):
         return self.enroll_files(t_model_id, group)
@@ -42,4 +42,8 @@ class DummyDatabase(ZTBioDatabase):
     def z_probe_files(self, group='dev'):
         return self.probe_files(None, group)

+    def annotations(self, file):
+        return None
+
+
 database = DummyDatabase()
diff --git a/bob/bio/base/test/dummy/fileset.py b/bob/bio/base/test/dummy/fileset.py
index 4f2c6461b09b5bd15b20975394848cc4918fa288..0d2a7faf489c369e627fd82921e7a8689c0e2639 100644
--- a/bob/bio/base/test/dummy/fileset.py
+++ b/bob/bio/base/test/dummy/fileset.py
@@ -15,7 +15,7 @@ class DummyDatabase(ZTBioDatabase):
             models_depend_on_protocol=False
         )
         import bob.db.atnt
-        self.__db = bob.db.atnt.Database()
+        self._db = bob.db.atnt.Database()

     def uses_probe_file_sets(self):
         return True
@@ -31,10 +31,10 @@ class DummyDatabase(ZTBioDatabase):
         return file_sets

     def model_ids_with_protocol(self, groups=None, protocol=None, **kwargs):
-        return self.__db.model_ids(groups, protocol)
+        return self._db.model_ids(groups, protocol)

     def objects(self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs):
-        return self._make_bio(self.__db.objects(model_ids, groups, purposes, protocol, **kwargs))
+        return self._make_bio(self._db.objects(model_ids, groups, purposes, protocol, **kwargs))

     def tobjects(self, groups=None, protocol=None, model_ids=None, **kwargs):
         return []
@@ -43,7 +43,7 @@ class DummyDatabase(ZTBioDatabase):
         return []

     def tmodel_ids_with_protocol(self, protocol=None, groups=None, **kwargs):
-        return self.__db.model_ids(groups)
+        return self._db.model_ids(groups)

     def t_enroll_files(self, t_model_id, group='dev'):
         return self.enroll_files(t_model_id, group)
@@ -54,4 +54,8 @@ class DummyDatabase(ZTBioDatabase):
     def z_probe_file_sets(self, group='dev'):
         return self.probe_file_sets(None, group)

+    def annotations(self, file):
+        return None
+
+
 database = DummyDatabase()
diff --git a/bob/bio/base/test/dummy/preprocessor.py b/bob/bio/base/test/dummy/preprocessor.py
index fd32e3ce9776efa53e1b638070688b866094b376..7ccfcdcb7bd60750203bcf883b4d38ed1b973cb4 100644
--- a/bob/bio/base/test/dummy/preprocessor.py
+++ b/bob/bio/base/test/dummy/preprocessor.py
@@ -1,5 +1,6 @@
 from bob.bio.base.preprocessor import Preprocessor

+
 class DummyPreprocessor (Preprocessor):
     def __init__(self, return_none=False, **kwargs):
         Preprocessor.__init__(self)
@@ -11,4 +12,5 @@ class DummyPreprocessor (Preprocessor):
             return None
         return data

+
 preprocessor = DummyPreprocessor()