Commit 193cf6fc authored by Yannick DAYER
Browse files

Handle 'train' group in database.all_samples()

Check parameters with existing function
Add tests for train group
parent a2499e03
......@@ -4,6 +4,7 @@
import os
from bob.pipelines import Sample, DelayedSample, SampleSet
from bob.db.base.utils import check_parameters_for_validity
import csv
import bob.io.base
import functools
......@@ -325,16 +326,34 @@ class CSVDatasetDevEval:
Parameters
----------
groups: list or None
Groups to consider, or all groups if `None` is given.
Groups to consider ('train', 'dev', and/or 'eval'). If `None` is
given, returns the samples from all groups.
Returns
-------
samples: list
List of :class:`bob.pipelines.Sample` objects.
"""
valid_groups = ["train"]
if self.dev_enroll_csv and self.dev_probe_csv:
valid_groups.append("dev")
if self.eval_enroll_csv and self.eval_probe_csv:
valid_groups.append("eval")
groups = check_parameters_for_validity(
parameters=groups,
parameter_description="groups",
valid_parameters=valid_groups,
default_parameters=valid_groups,
)
samples = []
# Get train samples (background_model_samples returns a list of samples)
samples = self.background_model_samples()
if "train" in groups:
samples = samples + self.background_model_samples()
groups.remove("train")
# Get enroll and probe samples
groups = ["dev", "eval"] if not groups else groups
if "eval" in groups and (not self.eval_enroll_csv or not self.eval_probe_csv):
logger.warning("'eval' requested, but dataset has no 'eval' group.")
groups.remove("eval")
for group in groups:
for purpose in ("enroll", "probe"):
label = f"{group}_{purpose}_csv"
......@@ -478,19 +497,33 @@ class CSVDatasetCrossValidation:
Parameters
----------
groups: list or None
Groups to consider, or all groups if `None` is given.
Groups to consider ('train' and/or 'dev'). If `None` is given,
returns the samples from all groups.
Returns
-------
samples: list
List of :class:`bob.pipelines.Sample` objects.
"""
valid_groups = ["train", "dev"]
groups = check_parameters_for_validity(
parameters=groups,
parameter_description="groups",
valid_parameters=valid_groups,
default_parameters=valid_groups,
)
samples = []
# Get train samples (background_model_samples returns a list of samples)
samples = self.background_model_samples()
if "train" in groups:
samples = samples + self.background_model_samples()
groups.remove("train")
# Get enroll and probe samples
groups = ["dev"] if not groups else groups
if "eval" in groups:
logger.info("'eval' requested but there is no 'eval' group defined.")
groups.remove("eval")
for group in groups:
samples = samples+ [s for s_set in self.references(group) for s in s_set]
samples = samples+ [s for s_set in self.probes(group) for s in s_set]
samples = samples + [s for s_set in self.references(group) for s in s_set]
samples = samples + [s for s_set in self.probes(group) for s in s_set]
return samples
......
......@@ -12,6 +12,7 @@ from bob.bio.base.algorithm import Algorithm
from bob.pipelines import DelayedSample
from bob.pipelines import DelayedSampleSet
from bob.pipelines import SampleSet
from bob.db.base.utils import check_parameters_for_validity
from .abstract_classes import BioAlgorithm
from .abstract_classes import Database
......@@ -188,15 +189,23 @@ class DatabaseConnector(Database):
Parameters
----------
groups: list or `None`
List of groups to consider (like 'dev' or 'eval'). If `None`, will
return samples from all the groups.
List of groups to consider ('world'/'train', 'dev', and/or 'eval').
If `None` is given, returns samples from all the groups.
Returns
-------
samples: list
List of all the samples of a database, conforming to the pipeline
API. See, e.g., :py:func:`bob.pipelines.first`.
List of all the samples of a database in :class:`bob.pipelines.Sample`
objects.
"""
valid_groups = self.database.groups()
groups = check_parameters_for_validity(
parameters=groups,
parameter_description="groups",
valid_parameters=valid_groups,
default_parameters=valid_groups,
)
logger.debug(f"Fetching all samples of groups '{groups}'.")
objects = self.database.all_files(groups=groups)
return [_biofile_to_delayed_sample(k, self.database) for k in objects]
......
......@@ -57,3 +57,6 @@ def test_all_samples():
all_samples = dummy_database.all_samples(groups=None)
assert len(all_samples) == 400
assert all([isinstance(s, DelayedSample) for s in all_samples])
assert len(dummy_database.all_samples(groups=["world"])) == 200
assert len(dummy_database.all_samples(groups=["dev"])) == 200
assert len(dummy_database.all_samples(groups=[])) == 400
......@@ -117,7 +117,8 @@ def test_csv_file_list_atnt():
assert len(dataset.background_model_samples()) == 200
assert len(dataset.references()) == 20
assert len(dataset.probes()) == 100
assert len(dataset.all_samples(groups=["dev"])) == 400
assert len(dataset.all_samples(groups=["train"])) == 200
assert len(dataset.all_samples(groups=["dev"])) == 200
assert len(dataset.all_samples(groups=None)) == 400
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment