Commit 4de7cc3b authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Implemented CrossValidation Filelist dataset

parent 34bd50fb
from .csv_dataset import CSVDatasetDevEval, CSVToSampleLoader
from .csv_dataset import CSVDatasetDevEval, CSVToSampleLoader, CSVDatasetCrossValidation
from .file import BioFile
from .file import BioFileSet
from .database import BioDatabase
......
......@@ -8,6 +8,8 @@ import csv
import bob.io.base
import functools
from abc import ABCMeta, abstractmethod
import numpy as np
import itertools
class CSVBaseSampleLoader(metaclass=ABCMeta):
......@@ -91,7 +93,10 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
subject = row[1]
kwargs = dict([[h, r] for h, r in zip(header[2:], row[2:])])
return DelayedSample(
functools.partial(self.data_loader, os.path.join(self.dataset_original_directory, path+self.extension)),
functools.partial(
self.data_loader,
os.path.join(self.dataset_original_directory, path + self.extension),
),
key=path,
subject=subject,
**kwargs,
......@@ -118,11 +123,15 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
sample_sets[s.subject] = SampleSet(
[s], **get_attribute_from_sample(s)
)
sample_sets[s.subject].append(s)
else:
sample_sets[s.subject].append(s)
return list(sample_sets.values())
else:
return [SampleSet([s], **get_attribute_from_sample(s), references=references) for s in samples]
return [
SampleSet([s], **get_attribute_from_sample(s), references=references)
for s in samples
]
class CSVDatasetDevEval:
......@@ -194,8 +203,9 @@ class CSVDatasetDevEval:
protocol_na,e: str
The name of the protocol
csv_to_sample_loader:
csv_to_sample_loader: :any:`CSVBaseSampleLoader`
Base class that whose objective is to generate :any:`bob.pipelines.Samples`
and/or :any:`bob.pipelines.SampleSet` from csv rows
"""
......@@ -281,9 +291,6 @@ class CSVDatasetDevEval:
return self.cache["train"]
def _get_subjects_from_samplesets(self, sample_sets):
return list(set([s.subject for s in sample_sets]))
def _get_samplesets(self, group="dev", purpose="enroll", group_by_subject=False):
fetching_probes = False
......@@ -298,9 +305,7 @@ class CSVDatasetDevEval:
references = None
if fetching_probes:
references = self._get_subjects_from_samplesets(
self.references(group=group)
)
references = list(set([s.subject for s in self.references(group=group)]))
samples = self.csv_to_sample_loader(self.__dict__[cache_label])
......@@ -321,3 +326,144 @@ class CSVDatasetDevEval:
return self._get_samplesets(
group=group, purpose="probe", group_by_subject=False
)
class CSVDatasetCrossValidation:
"""
Generic filelist dataset for :any:`bob.bio.base.pipelines.VanillaBiometrics` pipeline that
handles **CROSS VALIDATION**.
Check :ref:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
interface.
This interface will take one `csv_file` as input and split into i-) data for training and
ii-) data for testing.
The data for testing will be further split in data for enrollment and data for probing.
The input CSV file should be casted in the following format:
.. code-block:: text
PATH,SUBJECT
path_1,subject_1
path_2,subject_2
path_i,subject_j
...
Parameters
----------
csv_file_name: str
CSV file containing all the samples from your database
random_state: int
Pseudo-random number generator seed
test_size: float
Percentage of the subjects used for testing
samples_for_enrollment: float
Number of samples used for enrollment
csv_to_sample_loader: :any:`CSVBaseSampleLoader`
Base class that whose objective is to generate :any:`bob.pipelines.Samples`
and/or :any:`bob.pipelines.SampleSet` from csv rows
"""
def __init__(
self,
csv_file_name="metadata.csv",
random_state=0,
test_size=0.8,
samples_for_enrollment=1,
csv_to_sample_loader=CSVToSampleLoader(
data_loader=bob.io.base.load, dataset_original_directory="", extension=""
),
):
def get_dict_cache():
cache = dict()
cache["train"] = None
cache["dev_enroll_csv"] = None
cache["dev_probe_csv"] = None
return cache
self.random_state = random_state
self.cache = get_dict_cache()
self.csv_to_sample_loader = csv_to_sample_loader
self.csv_file_name = csv_file_name
self.samples_for_enrollment = samples_for_enrollment
self.test_size = test_size
if self.test_size < 0 and self.test_size > 1:
raise ValueError(
f"`test_size` should be between 0 and 1. {test_size} is provided"
)
def _do_cross_validation(self):
# Shuffling samples by subject
samples_by_subject = group_samples_by_subject(
self.csv_to_sample_loader(self.csv_file_name)
)
subjects = list(samples_by_subject.keys())
np.random.seed(self.random_state)
np.random.shuffle(subjects)
# Getting the training data
n_samples_for_training = len(subjects) - int(self.test_size * len(subjects))
self.cache["train"] = list(
itertools.chain(
*[samples_by_subject[s] for s in subjects[0:n_samples_for_training]]
)
)
# Splitting enroll and probe
self.cache["dev_enroll_csv"] = []
self.cache["dev_probe_csv"] = []
for s in subjects[n_samples_for_training:]:
samples = samples_by_subject[s]
if len(samples) < self.samples_for_enrollment:
raise ValueError(
f"Not enough samples ({len(samples)}) for enrollment for the subject {s}"
)
# Enrollment samples
self.cache["dev_enroll_csv"].append(
self.csv_to_sample_loader.convert_samples_to_samplesets(
samples[0 : self.samples_for_enrollment]
)[0]
)
self.cache[
"dev_probe_csv"
] += self.csv_to_sample_loader.convert_samples_to_samplesets(
samples[self.samples_for_enrollment :],
group_by_subject=False,
references=subjects[n_samples_for_training:],
)
def _load_from_cache(self, cache_key):
if self.cache[cache_key] is None:
self._do_cross_validation()
return self.cache[cache_key]
def background_model_samples(self):
return self._load_from_cache("train")
def references(self, group="dev"):
return self._load_from_cache("dev_enroll_csv")
def probes(self, group="dev"):
return self._load_from_cache("dev_probe_csv")
def group_samples_by_subject(samples):
# Grouping sample sets
samples_by_subject = dict()
for s in samples:
if s.subject not in samples_by_subject:
samples_by_subject[s.subject] = []
samples_by_subject[s.subject].append(s)
return samples_by_subject
PATH,SUBJECT
s1/9,1
s1/2,1
s1/4,1
s1/5,1
s1/7,1
s1/8,1
s1/1,1
s1/10,1
s1/3,1
s1/6,1
s2/9,2
s2/2,2
s2/4,2
s2/5,2
s2/7,2
s2/8,2
s2/1,2
s2/10,2
s2/3,2
s2/6,2
s5/9,5
s5/2,5
s5/4,5
s5/5,5
s5/7,5
s5/8,5
s5/1,5
s5/10,5
s5/3,5
s5/6,5
s6/9,6
s6/2,6
s6/4,6
s6/5,6
s6/7,6
s6/8,6
s6/1,6
s6/10,6
s6/3,6
s6/6,6
s10/9,10
s10/2,10
s10/4,10
s10/5,10
s10/7,10
s10/8,10
s10/1,10
s10/10,10
s10/3,10
s10/6,10
s11/9,11
s11/2,11
s11/4,11
s11/5,11
s11/7,11
s11/8,11
s11/1,11
s11/10,11
s11/3,11
s11/6,11
s12/9,12
s12/2,12
s12/4,12
s12/5,12
s12/7,12
s12/8,12
s12/1,12
s12/10,12
s12/3,12
s12/6,12
s14/9,14
s14/2,14
s14/4,14
s14/5,14
s14/7,14
s14/8,14
s14/1,14
s14/10,14
s14/3,14
s14/6,14
s16/9,16
s16/2,16
s16/4,16
s16/5,16
s16/7,16
s16/8,16
s16/1,16
s16/10,16
s16/3,16
s16/6,16
s17/9,17
s17/2,17
s17/4,17
s17/5,17
s17/7,17
s17/8,17
s17/1,17
s17/10,17
s17/3,17
s17/6,17
s20/9,20
s20/2,20
s20/4,20
s20/5,20
s20/7,20
s20/8,20
s20/1,20
s20/10,20
s20/3,20
s20/6,20
s21/9,21
s21/2,21
s21/4,21
s21/5,21
s21/7,21
s21/8,21
s21/1,21
s21/10,21
s21/3,21
s21/6,21
s24/9,24
s24/2,24
s24/4,24
s24/5,24
s24/7,24
s24/8,24
s24/1,24
s24/10,24
s24/3,24
s24/6,24
s26/9,26
s26/2,26
s26/4,26
s26/5,26
s26/7,26
s26/8,26
s26/1,26
s26/10,26
s26/3,26
s26/6,26
s27/9,27
s27/2,27
s27/4,27
s27/5,27
s27/7,27
s27/8,27
s27/1,27
s27/10,27
s27/3,27
s27/6,27
s29/9,29
s29/2,29
s29/4,29
s29/5,29
s29/7,29
s29/8,29
s29/1,29
s29/10,29
s29/3,29
s29/6,29
s33/9,33
s33/2,33
s33/4,33
s33/5,33
s33/7,33
s33/8,33
s33/1,33
s33/10,33
s33/3,33
s33/6,33
s34/9,34
s34/2,34
s34/4,34
s34/5,34
s34/7,34
s34/8,34
s34/1,34
s34/10,34
s34/3,34
s34/6,34
s36/9,36
s36/2,36
s36/4,36
s36/5,36
s36/7,36
s36/8,36
s36/1,36
s36/10,36
s36/3,36
s36/6,36
s39/9,39
s39/2,39
s39/4,39
s39/5,39
s39/7,39
s39/8,39
s39/1,39
s39/10,39
s39/3,39
s39/6,39
s3/9,3
s3/2,3
s3/4,3
s3/5,3
s3/7,3
s4/9,4
s4/2,4
s4/4,4
s4/5,4
s4/7,4
s7/9,7
s7/2,7
s7/4,7
s7/5,7
s7/7,7
s8/9,8
s8/2,8
s8/4,8
s8/5,8
s8/7,8
s9/9,9
s9/2,9
s9/4,9
s9/5,9
s9/7,9
s13/9,13
s13/2,13
s13/4,13
s13/5,13
s13/7,13
s15/9,15
s15/2,15
s15/4,15
s15/5,15
s15/7,15
s18/9,18
s18/2,18
s18/4,18
s18/5,18
s18/7,18
s19/9,19
s19/2,19
s19/4,19
s19/5,19
s19/7,19
s22/9,22
s22/2,22
s22/4,22
s22/5,22
s22/7,22
s23/9,23
s23/2,23
s23/4,23
s23/5,23
s23/7,23
s25/9,25
s25/2,25
s25/4,25
s25/5,25
s25/7,25
s28/9,28
s28/2,28
s28/4,28
s28/5,28
s28/7,28
s30/9,30
s30/2,30
s30/4,30
s30/5,30
s30/7,30
s31/9,31
s31/2,31
s31/4,31
s31/5,31
s31/7,31
s32/9,32
s32/2,32
s32/4,32
s32/5,32
s32/7,32
s35/9,35
s35/2,35
s35/4,35
s35/5,35
s35/7,35
s37/9,37
s37/2,37
s37/4,37
s37/5,37
s37/7,37
s38/9,38
s38/2,38
s38/4,38
s38/5,38
s38/7,38
s40/9,40
s40/2,40
s40/4,40
s40/5,40
s40/7,40
s3/8,3
s3/1,3
s3/10,3
s3/3,3
s3/6,3
s4/8,4
s4/1,4
s4/10,4
s4/3,4
s4/6,4
s7/8,7
s7/1,7
s7/10,7
s7/3,7
s7/6,7
s8/8,8
s8/1,8
s8/10,8
s8/3,8
s8/6,8
s9/8,9
s9/1,9
s9/10,9
s9/3,9
s9/6,9
s13/8,13
s13/1,13
s13/10,13
s13/3,13
s13/6,13
s15/8,15
s15/1,15
s15/10,15
s15/3,15
s15/6,15
s18/8,18
s18/1,18
s18/10,18
s18/3,18
s18/6,18
s19/8,19
s19/1,19
s19/10,19
s19/3,19
s19/6,19
s22/8,22
s22/1,22
s22/10,22
s22/3,22
s22/6,22
s23/8,23
s23/1,23
s23/10,23
s23/3,23
s23/6,23
s25/8,25
s25/1,25
s25/10,25
s25/3,25
s25/6,25
s28/8,28
s28/1,28
s28/10,28
s28/3,28
s28/6,28
s30/8,30
s30/1,30
s30/10,30
s30/3,30
s30/6,30
s31/8,31
s31/1,31
s31/10,31
s31/3,31
s31/6,31
s32/8,32
s32/1,32
s32/10,32
s32/3,32
s32/6,32
s35/8,35
s35/1,35
s35/10,35
s35/3,35
s35/6,35
s37/8,37
s37/1,37
s37/10,37
s37/3,37
s37/6,37
s38/8,38
s38/1,38
s38/10,38
s38/3,38
s38/6,38
s40/8,40
s40/1,40
s40/10,40
s40/3,40
s40/6,40
......@@ -7,7 +7,7 @@
import os
import bob.io.base
import bob.io.base.test_utils
from bob.bio.base.database import CSVDatasetDevEval, CSVToSampleLoader
from bob.bio.base.database import CSVDatasetDevEval, CSVToSampleLoader, CSVDatasetCrossValidation