Skip to content
Snippets Groups Projects
Commit caf1f710 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Added all_samples to csv dataset classes

parent 20bf6fe4
No related branches found
No related tags found
1 merge request!217Add a method to retrieve all the samples of a dataset
......@@ -326,6 +326,26 @@ class CSVDatasetDevEval:
group=group, purpose="probe", group_by_subject=False
)
def all_samples(self, groups=None):
"""
Reads and returns all the samples in `groups`.
Parameters
----------
groups: list or None
Groups to consider, or all groups if `None` is given.
"""
# Get train samples (background_model_samples returns a list of samples)
samples = self.background_model_samples()
# Get enroll and probe samples
groups = ["dev", "eval"] if not groups else groups
for group in groups:
for purpose in ("enroll", "probe"):
label = f"{group}_{purpose}_csv"
samples.append(self.csv_to_sample_loader(self.__dict__[label]))
return samples
class CSVDatasetCrossValidation:
"""
......@@ -456,6 +476,26 @@ class CSVDatasetCrossValidation:
def probes(self, group="dev"):
return self._load_from_cache("dev_probe_csv")
def all_samples(self, groups=None):
"""
Reads and returns all the samples in `groups`.
Parameters
----------
groups: list or None
Groups to consider, or all groups if `None` is given.
"""
# Get train samples (background_model_samples returns a list of samples)
samples = self.background_model_samples()
# Get enroll and probe samples
groups = ["dev", "eval"] if not groups else groups
for group in groups:
for purpose in ("enroll", "probe"):
label = f"{group}_{purpose}_csv"
samples.append(self.csv_to_sample_loader(self.__dict__[label]))
return samples
def group_samples_by_subject(samples):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment