Commit b1e1beb5 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira

Merge branch 'db-all-samples' into 'master'

Add a method to retrieve all the samples of a dataset

Closes #146

See merge request !217
parents 97cc2b19 2e2a7d97
Pipeline #46051 passed with stages in 8 minutes and 55 seconds
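In short, the new method can be used like this (a minimal sketch; `database` stands for any of the database interfaces touched in the diffs below):

    # All samples of the protocol across all groups: train + enroll + probe
    all_files = database.all_samples(groups=None)
    # Restrict the enroll/probe part to the development set
    # (train samples are always included)
    dev_files = database.all_samples(groups=["dev"])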
@@ -10,7 +10,9 @@ import functools
from abc import ABCMeta, abstractmethod
import numpy as np
import itertools
import logging
logger = logging.getLogger(__name__)
class CSVBaseSampleLoader(metaclass=ABCMeta):
"""
@@ -316,6 +318,29 @@ class CSVDatasetDevEval:
group=group, purpose="probe", group_by_subject=False
)
def all_samples(self, groups=None):
    """
    Reads and returns all the samples in `groups`.

    Parameters
    ----------
    groups: list or None
        Groups to consider, or all groups if `None` is given.
    """
    # Get train samples (background_model_samples returns a list of samples)
    samples = self.background_model_samples()

    # Get enroll and probe samples
    groups = ["dev", "eval"] if not groups else groups
    if "eval" in groups and (not self.eval_enroll_csv or not self.eval_probe_csv):
        logger.warning("'eval' requested, but this dataset has no 'eval' group.")
        # Drop 'eval' without mutating the caller's list
        groups = [g for g in groups if g != "eval"]
    for group in groups:
        for purpose in ("enroll", "probe"):
            label = f"{group}_{purpose}_csv"
            samples = samples + self.csv_to_sample_loader(getattr(self, label))
    return samples
class CSVDatasetCrossValidation:
"""
@@ -446,6 +471,28 @@ class CSVDatasetCrossValidation:
def probes(self, group="dev"):
return self._load_from_cache("dev_probe_csv")
def all_samples(self, groups=None):
    """
    Reads and returns all the samples in `groups`.

    Parameters
    ----------
    groups: list or None
        Groups to consider, or all groups if `None` is given.
    """
    # Get train samples (background_model_samples returns a list of samples)
    samples = self.background_model_samples()

    # Get enroll and probe samples
    groups = ["dev"] if not groups else groups
    if "eval" in groups:
        logger.info("'eval' requested, but there is no 'eval' group defined.")
        # Drop 'eval' without mutating the caller's list
        groups = [g for g in groups if g != "eval"]
    for group in groups:
        samples = samples + [s for s_set in self.references(group) for s in s_set]
        samples = samples + [s for s_set in self.probes(group) for s in s_set]
    return samples
def group_samples_by_subject(samples):
@@ -671,12 +671,15 @@ class ZTBioDatabase(BioDatabase):
files = self.objects(protocol=self.protocol, groups=groups, **self.all_files_options)

# add all files that belong to the ZT-norm
if add_zt_files and groups:
    for group in groups:
        if group == 'world':
            continue
        files += self.tobjects(protocol=self.protocol, groups=group, model_ids=None)
        files += self.zobjects(protocol=self.protocol, groups=group, **self.z_probe_options)
elif add_zt_files:
    files += self.tobjects(protocol=self.protocol, groups=groups, model_ids=None)
    files += self.zobjects(protocol=self.protocol, groups=groups, **self.z_probe_options)

return self.sort(files)
@abc.abstractmethod
@@ -311,6 +311,23 @@ class Database(metaclass=ABCMeta):
"""
pass
@abstractmethod
def all_samples(self, groups=None):
"""Returns all the samples of the dataset
Parameters
----------
groups: list or `None`
List of groups to consider (like 'dev' or 'eval'). If `None`, will
return samples from all the groups.
Returns
-------
samples: list
List of all the samples of the dataset.
"""
pass
class ScoreWriter(metaclass=ABCMeta):
"""
@@ -178,6 +178,24 @@ class DatabaseConnector(Database):
return list(probes.values())
def all_samples(self, groups=None):
"""Returns all the legacy database files in Sample format
Parameters
----------
groups: list or `None`
List of groups to consider (like 'dev' or 'eval'). If `None`, will
return samples from all the groups.
Returns
-------
samples: list
List of all the samples of a database, conforming to the pipeline
API. See, e.g., :py:func:`bob.pipelines.first`.
"""
objects = self.database.all_files(groups=groups)
return [_biofile_to_delayed_sample(k, self.database) for k in objects]
class BioAlgorithmLegacy(BioAlgorithm):
"""Biometric Algorithm that handles :py:class:`bob.bio.base.algorithm.Algorithm`
@@ -120,20 +120,9 @@ def annotate(database, groups, annotator, output_dir, dask_client, **kwargs):
# Transformer that splits the samples into several Dask Bags
to_dask_bags = ToDaskBag(npartitions=50)
logger.debug("Retrieving background model samples from database.")
background_model_samples = database.background_model_samples()
logger.debug("Retrieving references and probes samples from database.")
references_samplesets = []
probes_samplesets = []
for group in groups:
references_samplesets.extend(database.references(group=group))
probes_samplesets.extend(database.probes(group=group))
# Unravels all samples in one list (no SampleSets)
samples = background_model_samples
samples.extend([sample for r in references_samplesets for sample in r.samples])
samples.extend([sample for p in probes_samplesets for sample in p.samples])
logger.debug("Retrieving samples from database.")
samples = database.all_samples(groups)
# Sets the scheduler to local if no dask_client is specified
if dask_client is not None:
@@ -10,7 +10,8 @@ Very simple tests for Implementations
import os
from bob.bio.base.database import BioDatabase, ZTBioDatabase
from bob.bio.base.test.dummy.database import database as dummy_database
from bob.pipelines import DelayedSample
def check_database(database, groups=('dev',), protocol=None, training_depends=False, models_depend=False, skip_train=False, check_zt=False):
database_legacy = database.database
@@ -51,3 +52,8 @@ def check_database_zt(database, groups=('dev', 'eval'), protocol=None, training_
assert database_legacy.client_id_from_model_id(t_model_ids[0], group) is not None
assert len(database_legacy.t_enroll_files(t_model_ids[0], group)) > 0
assert len(database_legacy.z_probe_files(group)) > 0
def test_all_samples():
all_samples = dummy_database.all_samples(groups=None)
assert len(all_samples) == 400
assert all([isinstance(s, DelayedSample) for s in all_samples])
@@ -107,6 +107,9 @@ def test_csv_file_list_dev_eval():
assert len(dataset.probes(group="eval")) == 13
assert check_all_true(dataset.probes(group="eval"), SampleSet)
assert len(dataset.all_samples(groups=None)) == 49
assert check_all_true(dataset.all_samples(groups=None), DelayedSample)
def test_csv_file_list_atnt():
@@ -114,6 +117,32 @@ def test_csv_file_list_atnt():
assert len(dataset.background_model_samples()) == 200
assert len(dataset.references()) == 20
assert len(dataset.probes()) == 100
assert len(dataset.all_samples(groups=["dev"])) == 400
assert len(dataset.all_samples(groups=None)) == 400
def data_loader(path):
    import bob.io.image  # noqa: F401 (registers the image format plugin used by bob.io.base.load)
    return bob.io.base.load(path)
def test_csv_cross_validation_atnt():
dataset = CSVDatasetCrossValidation(
csv_file_name=atnt_protocol_path_cross_validation,
random_state=0,
test_size=0.8,
csv_to_sample_loader=CSVToSampleLoader(
data_loader=data_loader,
dataset_original_directory=atnt_database_directory(),
extension=".pgm",
),
)
assert len(dataset.background_model_samples()) == 80
assert len(dataset.references("dev")) == 32
assert len(dataset.probes("dev")) == 288
assert len(dataset.all_samples(groups=None)) == 400
def run_experiment(dataset):
@@ -131,12 +160,6 @@ def run_experiment(dataset):
)
def test_atnt_experiment():
dataset = CSVDatasetDevEval(
@@ -160,7 +183,7 @@ def test_atnt_experiment_cross_validation():
total_identities = 40
samples_for_enrollment = 1
def run_cross_validation_experiment(test_size=0.9):
dataset = CSVDatasetCrossValidation(
csv_file_name=atnt_protocol_path_cross_validation,
random_state=0,
@@ -179,9 +202,9 @@ def test_atnt_experiment_cross_validation():
* (samples_per_identity - samples_for_enrollment)
)
run_cross_validation_experiment(test_size=0.9)
run_cross_validation_experiment(test_size=0.8)
run_cross_validation_experiment(test_size=0.5)
####
@@ -167,8 +167,9 @@ This will create a database interface with:
- The elements in ``train.csv`` returned by :py:meth:`~bob.db.base.Database.background_model_samples`,
- The elements in ``*_enroll.csv`` returned by :py:meth:`~bob.db.base.Database.references`,
- The elements in ``*_probe.csv`` returned by :py:meth:`~bob.db.base.Database.probes`.
An aggregation of all of the above is available with the :py:meth:`~bob.db.base.Database.all_samples` method, which returns all the samples of the protocol.
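For example, an annotation job can fetch every sample of the protocol in a single call instead of combining the three accessors by hand (a minimal sketch; ``database`` is assumed to be the interface built above)::

    all_of_it = database.all_samples(groups=None)    # train + dev (+ eval) samples
    dev_only = database.all_samples(groups=["dev"])  # train + dev samples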
.. _bob.bio.base.database.csv_cross_validation:
@@ -226,6 +227,7 @@ When a vanilla-biometrics pipeline requests data from that class, it will call t
The group parameter (*dev* or *eval*) can be given to specify from which set of individuals the data comes.
Each :py:class:`~bob.pipelines.SampleSet` must contain a :py:attr:`~bob.pipelines.SampleSet.subject`, a :py:attr:`~bob.pipelines.SampleSet.references` list, and a list of :py:attr:`~bob.pipelines.Sample` containing at least the :py:attr:`~bob.pipelines.Sample.key` attribute as well as the :py:attr:`~bob.pipelines.Sample.data` of the sample.
Furthermore, the :py:meth:`~bob.db.base.Database.all_samples` method must return a list of all the existing samples in the dataset. This functionality is used for annotating a whole dataset.
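For instance, a single probe entry with the required attributes could be built as follows (a minimal sketch; the subject, key, and data values are hypothetical)::

    from bob.pipelines import Sample, SampleSet

    probe_set = SampleSet(
        samples=[Sample(data=my_data, key="subj1_sample1")],  # my_data: raw sample data (hypothetical)
        subject="subj1",
        references=["subj1", "subj2"],  # biometric references this probe is scored against
    )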
Here is a code snippet of a simple database interface:
@@ -235,10 +237,10 @@ Here is a code snippet of a simple database interface:
class CustomDatabase:
def background_model_samples(self):
    train_samples = []
    for a_sample in dataset_train_subjects:
        train_samples.append( Sample(data=a_sample.data, key=a_sample.sample_id) )
    return train_samples
def references(self, group="dev"):
all_references = []
@@ -258,6 +260,13 @@ Here is a code snippet of a simple database interface:
all_probes.append(current_sampleset)
return all_probes
def all_samples(self, groups=None):
all_subjects = dataset_train_subjects + dataset_dev_subjects
all_samples = []
for a_sample in all_subjects:
all_samples.append( Sample(data=a_sample.data, key=a_sample.sample_id) )
return all_samples
allow_scoring_with_all_biometric_references = True
database = CustomDatabase()
@@ -303,7 +312,7 @@ When doing so, the output of each :py:class:`Transformer` of the pipeline will b
.. WARNING::
You have to be careful when using checkpoints: If you modify an early step of an experiment, the created checkpoints are not valid anymore, but the system has no way of knowing that.
**You** have to take care of removing invalid checkpoint files.
When changing the pipeline or the dataset of an experiment, you should change the output folder (``-o``) accordingly. Otherwise, the system could try to load a checkpoint of an older experiment, or samples from another dataset.