Skip to content
Snippets Groups Projects
Commit 193cf6fc authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Handle 'train' group in database.all_samples()

Check parameters with existing function
Add tests for train group
parent a2499e03
No related branches found
No related tags found
1 merge request!222Follow-up to "Add a method to retrieve all the samples of a dataset"
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import os import os
from bob.pipelines import Sample, DelayedSample, SampleSet from bob.pipelines import Sample, DelayedSample, SampleSet
from bob.db.base.utils import check_parameters_for_validity
import csv import csv
import bob.io.base import bob.io.base
import functools import functools
...@@ -325,16 +326,34 @@ class CSVDatasetDevEval: ...@@ -325,16 +326,34 @@ class CSVDatasetDevEval:
Parameters Parameters
---------- ----------
groups: list or None groups: list or None
Groups to consider, or all groups if `None` is given. Groups to consider ('train', 'dev', and/or 'eval'). If `None` is
given, returns the samples from all groups.
Returns
-------
samples: list
List of :class:`bob.pipelines.Sample` objects.
""" """
valid_groups = ["train"]
if self.dev_enroll_csv and self.dev_probe_csv:
valid_groups.append("dev")
if self.eval_enroll_csv and self.eval_probe_csv:
valid_groups.append("eval")
groups = check_parameters_for_validity(
parameters=groups,
parameter_description="groups",
valid_parameters=valid_groups,
default_parameters=valid_groups,
)
samples = []
# Get train samples (background_model_samples returns a list of samples) # Get train samples (background_model_samples returns a list of samples)
samples = self.background_model_samples() if "train" in groups:
samples = samples + self.background_model_samples()
groups.remove("train")
# Get enroll and probe samples # Get enroll and probe samples
groups = ["dev", "eval"] if not groups else groups
if "eval" in groups and (not self.eval_enroll_csv or not self.eval_probe_csv):
logger.warning("'eval' requested, but dataset has no 'eval' group.")
groups.remove("eval")
for group in groups: for group in groups:
for purpose in ("enroll", "probe"): for purpose in ("enroll", "probe"):
label = f"{group}_{purpose}_csv" label = f"{group}_{purpose}_csv"
...@@ -478,16 +497,30 @@ class CSVDatasetCrossValidation: ...@@ -478,16 +497,30 @@ class CSVDatasetCrossValidation:
Parameters Parameters
---------- ----------
groups: list or None groups: list or None
Groups to consider, or all groups if `None` is given. Groups to consider ('train' and/or 'dev'). If `None` is given,
returns the samples from all groups.
Returns
-------
samples: list
List of :class:`bob.pipelines.Sample` objects.
""" """
valid_groups = ["train", "dev"]
groups = check_parameters_for_validity(
parameters=groups,
parameter_description="groups",
valid_parameters=valid_groups,
default_parameters=valid_groups,
)
samples = []
# Get train samples (background_model_samples returns a list of samples) # Get train samples (background_model_samples returns a list of samples)
samples = self.background_model_samples() if "train" in groups:
samples = samples + self.background_model_samples()
groups.remove("train")
# Get enroll and probe samples # Get enroll and probe samples
groups = ["dev"] if not groups else groups
if "eval" in groups:
logger.info("'eval' requested but there is no 'eval' group defined.")
groups.remove("eval")
for group in groups: for group in groups:
samples = samples + [s for s_set in self.references(group) for s in s_set] samples = samples + [s for s_set in self.references(group) for s in s_set]
samples = samples + [s for s_set in self.probes(group) for s in s_set] samples = samples + [s for s_set in self.probes(group) for s in s_set]
......
...@@ -12,6 +12,7 @@ from bob.bio.base.algorithm import Algorithm ...@@ -12,6 +12,7 @@ from bob.bio.base.algorithm import Algorithm
from bob.pipelines import DelayedSample from bob.pipelines import DelayedSample
from bob.pipelines import DelayedSampleSet from bob.pipelines import DelayedSampleSet
from bob.pipelines import SampleSet from bob.pipelines import SampleSet
from bob.db.base.utils import check_parameters_for_validity
from .abstract_classes import BioAlgorithm from .abstract_classes import BioAlgorithm
from .abstract_classes import Database from .abstract_classes import Database
...@@ -188,15 +189,23 @@ class DatabaseConnector(Database): ...@@ -188,15 +189,23 @@ class DatabaseConnector(Database):
Parameters Parameters
---------- ----------
groups: list or `None` groups: list or `None`
List of groups to consider (like 'dev' or 'eval'). If `None`, will List of groups to consider ('world'/'train', 'dev', and/or 'eval').
return samples from all the groups. If `None` is given, returns samples from all the groups.
Returns Returns
------- -------
samples: list samples: list
List of all the samples of a database, conforming to the pipeline List of all the samples of a database in :class:`bob.pipelines.Sample`
API. See, e.g., :py:func:`bob.pipelines.first`. objects.
""" """
valid_groups = self.database.groups()
groups = check_parameters_for_validity(
parameters=groups,
parameter_description="groups",
valid_parameters=valid_groups,
default_parameters=valid_groups,
)
logger.debug(f"Fetching all samples of groups '{groups}'.")
objects = self.database.all_files(groups=groups) objects = self.database.all_files(groups=groups)
return [_biofile_to_delayed_sample(k, self.database) for k in objects] return [_biofile_to_delayed_sample(k, self.database) for k in objects]
......
...@@ -57,3 +57,6 @@ def test_all_samples(): ...@@ -57,3 +57,6 @@ def test_all_samples():
all_samples = dummy_database.all_samples(groups=None) all_samples = dummy_database.all_samples(groups=None)
assert len(all_samples) == 400 assert len(all_samples) == 400
assert all([isinstance(s, DelayedSample) for s in all_samples]) assert all([isinstance(s, DelayedSample) for s in all_samples])
assert len(dummy_database.all_samples(groups=["world"])) == 200
assert len(dummy_database.all_samples(groups=["dev"])) == 200
assert len(dummy_database.all_samples(groups=[])) == 400
...@@ -117,7 +117,8 @@ def test_csv_file_list_atnt(): ...@@ -117,7 +117,8 @@ def test_csv_file_list_atnt():
assert len(dataset.background_model_samples()) == 200 assert len(dataset.background_model_samples()) == 200
assert len(dataset.references()) == 20 assert len(dataset.references()) == 20
assert len(dataset.probes()) == 100 assert len(dataset.probes()) == 100
assert len(dataset.all_samples(groups=["dev"])) == 400 assert len(dataset.all_samples(groups=["train"])) == 200
assert len(dataset.all_samples(groups=["dev"])) == 200
assert len(dataset.all_samples(groups=None)) == 400 assert len(dataset.all_samples(groups=None)) == 400
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment