Commit 20bf6fe4 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

[db] Add a method to get all samples of a dataset.

Add it to the legacy database connector
parent f22059b3
......@@ -317,6 +317,23 @@ class Database(metaclass=ABCMeta):
"""
pass
@abstractmethod
def all_samples(self, groups=None):
"""Returns all the samples of the dataset
Parameters
----------
groups: list or `None`
List of groups to consider (like 'dev' or 'eval'). If `None`, will
return samples from all the groups.
Returns
-------
samples: list
List of all the samples of the dataset.
"""
pass
class ScoreWriter(metaclass=ABCMeta):
"""
......
......@@ -178,6 +178,24 @@ class DatabaseConnector(Database):
return list(probes.values())
def all_samples(self, groups=None):
"""Returns all the legacy database files in Sample format
Parameters
----------
groups: list or `None`
List of groups to consider (like 'dev' or 'eval'). If `None`, will
return samples from all the groups.
Returns
-------
samples: list
List of all the samples of a database, conforming to the pipeline
API. See, e.g., :py:func:`bob.pipelines.first`.
"""
objects = self.database.all_files(groups=groups)
return [_biofile_to_delayed_sample(k, self.database) for k in objects]
class BioAlgorithmLegacy(BioAlgorithm):
"""Biometric Algorithm that handles :py:class:`bob.bio.base.algorithm.Algorithm`
......
......@@ -119,28 +119,8 @@ def annotate(
to_dask_bags = ToDaskBag(npartitions=50)
logger.debug("Retrieving background model samples from database.")
background_model_samples = database.background_model_samples()
logger.debug("Retrieving references and probes samples from database.")
references_samplesets = []
probes_samplesets = []
for group in groups:
references_samplesets.extend(database.references(group=group))
probes_samplesets.extend(database.probes(group=group))
# Unravels all samples in one list (no SampleSets)
samples = background_model_samples
samples.extend([
sample
for r in references_samplesets
for sample in r.samples
])
samples.extend([
sample
for p in probes_samplesets
for sample in p.samples
])
logger.debug("Retrieving samples from database.")
samples = database.all_samples(groups)
# Sets the scheduler to local if no dask_client is specified
if dask_client is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment