From 4de7cc3b9f9abd342f01a9d3aa0147878e1a3c0c Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Wed, 7 Oct 2020 15:22:11 +0200
Subject: [PATCH] Implemented CrossValidation Filelist dataset

---
 bob/bio/base/database/__init__.py             |   2 +-
 bob/bio/base/database/csv_dataset.py          | 168 +++++++-
 .../data/atnt/cross_validation/metadata.csv   | 401 ++++++++++++++++++
 bob/bio/base/test/test_filelist.py            |  69 ++-
 4 files changed, 611 insertions(+), 29 deletions(-)
 create mode 100644 bob/bio/base/test/data/atnt/cross_validation/metadata.csv

diff --git a/bob/bio/base/database/__init__.py b/bob/bio/base/database/__init__.py
index 3d728e2b..1ff6325e 100644
--- a/bob/bio/base/database/__init__.py
+++ b/bob/bio/base/database/__init__.py
@@ -1,4 +1,4 @@
-from .csv_dataset import CSVDatasetDevEval, CSVToSampleLoader
+from .csv_dataset import CSVDatasetDevEval, CSVToSampleLoader, CSVDatasetCrossValidation
 from .file import BioFile
 from .file import BioFileSet
 from .database import BioDatabase
diff --git a/bob/bio/base/database/csv_dataset.py b/bob/bio/base/database/csv_dataset.py
index 8336a4cf..bafb21a7 100644
--- a/bob/bio/base/database/csv_dataset.py
+++ b/bob/bio/base/database/csv_dataset.py
@@ -8,6 +8,8 @@ import csv
 import bob.io.base
 import functools
 from abc import ABCMeta, abstractmethod
+import numpy as np
+import itertools
 
 
 class CSVBaseSampleLoader(metaclass=ABCMeta):
@@ -91,7 +93,10 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
         subject = row[1]
         kwargs = dict([[h, r] for h, r in zip(header[2:], row[2:])])
         return DelayedSample(
-            functools.partial(self.data_loader, os.path.join(self.dataset_original_directory, path+self.extension)),
+            functools.partial(
+                self.data_loader,
+                os.path.join(self.dataset_original_directory, path + self.extension),
+            ),
             key=path,
             subject=subject,
             **kwargs,
@@ -118,11 +123,15 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
                     sample_sets[s.subject] = SampleSet(
                         [s], **get_attribute_from_sample(s)
                     )
-                sample_sets[s.subject].append(s)
+                else:
+                    sample_sets[s.subject].append(s)
             return list(sample_sets.values())
 
         else:
-            return [SampleSet([s], **get_attribute_from_sample(s), references=references) for s in samples]
+            return [
+                SampleSet([s], **get_attribute_from_sample(s), references=references)
+                for s in samples
+            ]
 
 
 class CSVDatasetDevEval:
@@ -194,8 +203,9 @@ class CSVDatasetDevEval:
         protocol_name: str
           The name of the protocol
 
-        csv_to_sample_loader:
-
+        csv_to_sample_loader: :any:`CSVBaseSampleLoader`
+            Base class whose objective is to generate :any:`bob.pipelines.Sample`
+            and/or :any:`bob.pipelines.SampleSet` objects from CSV rows
 
     """
 
@@ -281,9 +291,6 @@ class CSVDatasetDevEval:
 
         return self.cache["train"]
 
-    def _get_subjects_from_samplesets(self, sample_sets):
-        return list(set([s.subject for s in sample_sets]))
-
     def _get_samplesets(self, group="dev", purpose="enroll", group_by_subject=False):
 
         fetching_probes = False
@@ -298,9 +305,7 @@ class CSVDatasetDevEval:
 
         references = None
         if fetching_probes:
-            references = self._get_subjects_from_samplesets(
-                self.references(group=group)
-            )
+            references = list(set([s.subject for s in self.references(group=group)]))
 
         samples = self.csv_to_sample_loader(self.__dict__[cache_label])
 
@@ -321,3 +326,144 @@ class CSVDatasetDevEval:
         return self._get_samplesets(
             group=group, purpose="probe", group_by_subject=False
         )
+
+
+class CSVDatasetCrossValidation:
+    """
+    Generic filelist dataset for the :any:`bob.bio.base.pipelines.VanillaBiometrics` pipeline
+    that handles **cross-validation**.
+
+    Check :ref:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
+    interface.
+
+
+    This interface takes one `csv_file` as input and splits it into: i) data for training;
+    and ii) data for testing.
+    The data for testing is further split into data for enrollment and data for probing.
+    The input CSV file should be formatted as follows:
+
+    .. code-block:: text
+
+       PATH,SUBJECT
+       path_1,subject_1
+       path_2,subject_2
+       path_i,subject_j
+       ...
+
+    Parameters
+    ----------
+
+    csv_file_name: str
+      CSV file containing all the samples from your database
+
+    random_state: int
+      Pseudo-random number generator seed
+
+    test_size: float
+      Fraction of the subjects reserved for testing (a value between 0 and 1)
+
+    samples_for_enrollment: int
+      Number of samples per subject used for enrollment
+
+    csv_to_sample_loader: :any:`CSVBaseSampleLoader`
+        Base class whose objective is to generate :any:`bob.pipelines.Sample`
+        and/or :any:`bob.pipelines.SampleSet` objects from CSV rows
+
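+    Examples
+    --------
+
+    A minimal usage sketch (the CSV path and parameter values below are illustrative)::
+
+        dataset = CSVDatasetCrossValidation(
+            csv_file_name="metadata.csv", test_size=0.8, samples_for_enrollment=1
+        )
+        train_samples = dataset.background_model_samples()
+        enroll_samplesets = dataset.references()
+        probe_samplesets = dataset.probes()
+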
+    """
+
+    def __init__(
+        self,
+        csv_file_name="metadata.csv",
+        random_state=0,
+        test_size=0.8,
+        samples_for_enrollment=1,
+        csv_to_sample_loader=CSVToSampleLoader(
+            data_loader=bob.io.base.load, dataset_original_directory="", extension=""
+        ),
+    ):
+        def get_dict_cache():
+            cache = dict()
+            cache["train"] = None
+            cache["dev_enroll_csv"] = None
+            cache["dev_probe_csv"] = None
+            return cache
+
+        self.random_state = random_state
+        self.cache = get_dict_cache()
+        self.csv_to_sample_loader = csv_to_sample_loader
+        self.csv_file_name = csv_file_name
+        self.samples_for_enrollment = samples_for_enrollment
+        self.test_size = test_size
+
+        if self.test_size < 0 or self.test_size > 1:
+            raise ValueError(
+                f"`test_size` should be between 0 and 1. {test_size} was provided"
+            )
+
+    def _do_cross_validation(self):
+
+        # Group samples by subject and shuffle the subjects (seeded for reproducibility)
+        samples_by_subject = group_samples_by_subject(
+            self.csv_to_sample_loader(self.csv_file_name)
+        )
+        subjects = list(samples_by_subject.keys())
+        np.random.seed(self.random_state)
+        np.random.shuffle(subjects)
+
+        # Training data: all samples of the subjects not reserved for testing
+        n_subjects_for_training = len(subjects) - int(self.test_size * len(subjects))
+        self.cache["train"] = list(
+            itertools.chain(
+                *[samples_by_subject[s] for s in subjects[0:n_subjects_for_training]]
+            )
+        )
+
+        # Splitting enroll and probe
+        self.cache["dev_enroll_csv"] = []
+        self.cache["dev_probe_csv"] = []
+        for s in subjects[n_subjects_for_training:]:
+            samples = samples_by_subject[s]
+            if len(samples) < self.samples_for_enrollment:
+                raise ValueError(
+                    f"Not enough samples ({len(samples)}) for enrollment for the subject {s}"
+                )
+
+            # Enrollment samples
+            self.cache["dev_enroll_csv"].append(
+                self.csv_to_sample_loader.convert_samples_to_samplesets(
+                    samples[0 : self.samples_for_enrollment]
+                )[0]
+            )
+
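+            # Probe samples: the remaining samples of this subject, to be
+            # compared against every test subject (the `references` list below)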
+            self.cache[
+                "dev_probe_csv"
+            ] += self.csv_to_sample_loader.convert_samples_to_samplesets(
+                samples[self.samples_for_enrollment :],
+                group_by_subject=False,
+                references=subjects[n_subjects_for_training:],
+            )
+
+    def _load_from_cache(self, cache_key):
+        if self.cache[cache_key] is None:
+            self._do_cross_validation()
+        return self.cache[cache_key]
+
+    def background_model_samples(self):
+        return self._load_from_cache("train")
+
+    def references(self, group="dev"):
+        return self._load_from_cache("dev_enroll_csv")
+
+    def probes(self, group="dev"):
+        return self._load_from_cache("dev_probe_csv")
+
+
+def group_samples_by_subject(samples):
+    """Group a flat list of samples into a dictionary keyed by ``sample.subject``."""
+
+    samples_by_subject = dict()
+    for s in samples:
+        if s.subject not in samples_by_subject:
+            samples_by_subject[s.subject] = []
+        samples_by_subject[s.subject].append(s)
+    return samples_by_subject
diff --git a/bob/bio/base/test/data/atnt/cross_validation/metadata.csv b/bob/bio/base/test/data/atnt/cross_validation/metadata.csv
new file mode 100644
index 00000000..21bf0ae0
--- /dev/null
+++ b/bob/bio/base/test/data/atnt/cross_validation/metadata.csv
@@ -0,0 +1,401 @@
+PATH,SUBJECT
+s1/9,1
+s1/2,1
+s1/4,1
+s1/5,1
+s1/7,1
+s1/8,1
+s1/1,1
+s1/10,1
+s1/3,1
+s1/6,1
+s2/9,2
+s2/2,2
+s2/4,2
+s2/5,2
+s2/7,2
+s2/8,2
+s2/1,2
+s2/10,2
+s2/3,2
+s2/6,2
+s5/9,5
+s5/2,5
+s5/4,5
+s5/5,5
+s5/7,5
+s5/8,5
+s5/1,5
+s5/10,5
+s5/3,5
+s5/6,5
+s6/9,6
+s6/2,6
+s6/4,6
+s6/5,6
+s6/7,6
+s6/8,6
+s6/1,6
+s6/10,6
+s6/3,6
+s6/6,6
+s10/9,10
+s10/2,10
+s10/4,10
+s10/5,10
+s10/7,10
+s10/8,10
+s10/1,10
+s10/10,10
+s10/3,10
+s10/6,10
+s11/9,11
+s11/2,11
+s11/4,11
+s11/5,11
+s11/7,11
+s11/8,11
+s11/1,11
+s11/10,11
+s11/3,11
+s11/6,11
+s12/9,12
+s12/2,12
+s12/4,12
+s12/5,12
+s12/7,12
+s12/8,12
+s12/1,12
+s12/10,12
+s12/3,12
+s12/6,12
+s14/9,14
+s14/2,14
+s14/4,14
+s14/5,14
+s14/7,14
+s14/8,14
+s14/1,14
+s14/10,14
+s14/3,14
+s14/6,14
+s16/9,16
+s16/2,16
+s16/4,16
+s16/5,16
+s16/7,16
+s16/8,16
+s16/1,16
+s16/10,16
+s16/3,16
+s16/6,16
+s17/9,17
+s17/2,17
+s17/4,17
+s17/5,17
+s17/7,17
+s17/8,17
+s17/1,17
+s17/10,17
+s17/3,17
+s17/6,17
+s20/9,20
+s20/2,20
+s20/4,20
+s20/5,20
+s20/7,20
+s20/8,20
+s20/1,20
+s20/10,20
+s20/3,20
+s20/6,20
+s21/9,21
+s21/2,21
+s21/4,21
+s21/5,21
+s21/7,21
+s21/8,21
+s21/1,21
+s21/10,21
+s21/3,21
+s21/6,21
+s24/9,24
+s24/2,24
+s24/4,24
+s24/5,24
+s24/7,24
+s24/8,24
+s24/1,24
+s24/10,24
+s24/3,24
+s24/6,24
+s26/9,26
+s26/2,26
+s26/4,26
+s26/5,26
+s26/7,26
+s26/8,26
+s26/1,26
+s26/10,26
+s26/3,26
+s26/6,26
+s27/9,27
+s27/2,27
+s27/4,27
+s27/5,27
+s27/7,27
+s27/8,27
+s27/1,27
+s27/10,27
+s27/3,27
+s27/6,27
+s29/9,29
+s29/2,29
+s29/4,29
+s29/5,29
+s29/7,29
+s29/8,29
+s29/1,29
+s29/10,29
+s29/3,29
+s29/6,29
+s33/9,33
+s33/2,33
+s33/4,33
+s33/5,33
+s33/7,33
+s33/8,33
+s33/1,33
+s33/10,33
+s33/3,33
+s33/6,33
+s34/9,34
+s34/2,34
+s34/4,34
+s34/5,34
+s34/7,34
+s34/8,34
+s34/1,34
+s34/10,34
+s34/3,34
+s34/6,34
+s36/9,36
+s36/2,36
+s36/4,36
+s36/5,36
+s36/7,36
+s36/8,36
+s36/1,36
+s36/10,36
+s36/3,36
+s36/6,36
+s39/9,39
+s39/2,39
+s39/4,39
+s39/5,39
+s39/7,39
+s39/8,39
+s39/1,39
+s39/10,39
+s39/3,39
+s39/6,39
+s3/9,3
+s3/2,3
+s3/4,3
+s3/5,3
+s3/7,3
+s4/9,4
+s4/2,4
+s4/4,4
+s4/5,4
+s4/7,4
+s7/9,7
+s7/2,7
+s7/4,7
+s7/5,7
+s7/7,7
+s8/9,8
+s8/2,8
+s8/4,8
+s8/5,8
+s8/7,8
+s9/9,9
+s9/2,9
+s9/4,9
+s9/5,9
+s9/7,9
+s13/9,13
+s13/2,13
+s13/4,13
+s13/5,13
+s13/7,13
+s15/9,15
+s15/2,15
+s15/4,15
+s15/5,15
+s15/7,15
+s18/9,18
+s18/2,18
+s18/4,18
+s18/5,18
+s18/7,18
+s19/9,19
+s19/2,19
+s19/4,19
+s19/5,19
+s19/7,19
+s22/9,22
+s22/2,22
+s22/4,22
+s22/5,22
+s22/7,22
+s23/9,23
+s23/2,23
+s23/4,23
+s23/5,23
+s23/7,23
+s25/9,25
+s25/2,25
+s25/4,25
+s25/5,25
+s25/7,25
+s28/9,28
+s28/2,28
+s28/4,28
+s28/5,28
+s28/7,28
+s30/9,30
+s30/2,30
+s30/4,30
+s30/5,30
+s30/7,30
+s31/9,31
+s31/2,31
+s31/4,31
+s31/5,31
+s31/7,31
+s32/9,32
+s32/2,32
+s32/4,32
+s32/5,32
+s32/7,32
+s35/9,35
+s35/2,35
+s35/4,35
+s35/5,35
+s35/7,35
+s37/9,37
+s37/2,37
+s37/4,37
+s37/5,37
+s37/7,37
+s38/9,38
+s38/2,38
+s38/4,38
+s38/5,38
+s38/7,38
+s40/9,40
+s40/2,40
+s40/4,40
+s40/5,40
+s40/7,40
+s3/8,3
+s3/1,3
+s3/10,3
+s3/3,3
+s3/6,3
+s4/8,4
+s4/1,4
+s4/10,4
+s4/3,4
+s4/6,4
+s7/8,7
+s7/1,7
+s7/10,7
+s7/3,7
+s7/6,7
+s8/8,8
+s8/1,8
+s8/10,8
+s8/3,8
+s8/6,8
+s9/8,9
+s9/1,9
+s9/10,9
+s9/3,9
+s9/6,9
+s13/8,13
+s13/1,13
+s13/10,13
+s13/3,13
+s13/6,13
+s15/8,15
+s15/1,15
+s15/10,15
+s15/3,15
+s15/6,15
+s18/8,18
+s18/1,18
+s18/10,18
+s18/3,18
+s18/6,18
+s19/8,19
+s19/1,19
+s19/10,19
+s19/3,19
+s19/6,19
+s22/8,22
+s22/1,22
+s22/10,22
+s22/3,22
+s22/6,22
+s23/8,23
+s23/1,23
+s23/10,23
+s23/3,23
+s23/6,23
+s25/8,25
+s25/1,25
+s25/10,25
+s25/3,25
+s25/6,25
+s28/8,28
+s28/1,28
+s28/10,28
+s28/3,28
+s28/6,28
+s30/8,30
+s30/1,30
+s30/10,30
+s30/3,30
+s30/6,30
+s31/8,31
+s31/1,31
+s31/10,31
+s31/3,31
+s31/6,31
+s32/8,32
+s32/1,32
+s32/10,32
+s32/3,32
+s32/6,32
+s35/8,35
+s35/1,35
+s35/10,35
+s35/3,35
+s35/6,35
+s37/8,37
+s37/1,37
+s37/10,37
+s37/3,37
+s37/6,37
+s38/8,38
+s38/1,38
+s38/10,38
+s38/3,38
+s38/6,38
+s40/8,40
+s40/1,40
+s40/10,40
+s40/3,40
+s40/6,40
diff --git a/bob/bio/base/test/test_filelist.py b/bob/bio/base/test/test_filelist.py
index 9e154c48..939dded8 100644
--- a/bob/bio/base/test/test_filelist.py
+++ b/bob/bio/base/test/test_filelist.py
@@ -7,7 +7,7 @@
 import os
 import bob.io.base
 import bob.io.base.test_utils
-from bob.bio.base.database import CSVDatasetDevEval, CSVToSampleLoader
+from bob.bio.base.database import CSVDatasetDevEval, CSVToSampleLoader, CSVDatasetCrossValidation
 import nose.tools
 from bob.pipelines import DelayedSample, SampleSet
 import numpy as np
@@ -28,6 +28,10 @@ atnt_protocol_path = os.path.realpath(
     bob.io.base.test_utils.datafile(".", __name__, "data/atnt")
 )
 
+atnt_protocol_path_cross_validation = os.path.join(
+    os.path.realpath(
+        bob.io.base.test_utils.datafile(".", __name__, "data/atnt/cross_validation/")
+    ),
+    "metadata.csv",
+)
+
 
 def check_all_true(list_of_something, something):
     """
@@ -100,36 +104,67 @@ def test_csv_file_list_atnt():
     assert len(dataset.probes()) == 100
 
 
-def test_atnt_experiment():
-    def load(path):
-        import bob.io.image
 
-        return bob.io.base.load(path)
+def run_experiment(dataset):
 
     def linearize(X):
         X = np.asarray(X)
         return np.reshape(X, (X.shape[0], -1))
 
-    dataset = CSVDatasetDevEval(
-        dataset_protocol_path=atnt_protocol_path,
-        protocol_name="idiap_protocol",
-        csv_to_sample_loader=CSVToSampleLoader(
-            data_loader=load,
-            dataset_original_directory=atnt_database_directory(),
-            extension=".pgm",
-        ),
-    )
-
     #### Testing it in a real recognition system
     transformer = wrap(["sample"], make_pipeline(FunctionTransformer(linearize)))
 
     vanilla_biometrics_pipeline = VanillaBiometricsPipeline(transformer, Distance())
 
-    scores = vanilla_biometrics_pipeline(
+    return vanilla_biometrics_pipeline(
         dataset.background_model_samples(),
         dataset.references(),
         dataset.probes(),
     )
 
+
+def data_loader(path):
+    # bob.io.image must be imported so that bob.io.base.load can read image files (e.g. .pgm)
+    import bob.io.image
+
+    return bob.io.base.load(path)
+
+
+def test_atnt_experiment():
+
+    dataset = CSVDatasetDevEval(
+        dataset_protocol_path=atnt_protocol_path,
+        protocol_name="idiap_protocol",
+        csv_to_sample_loader=CSVToSampleLoader(
+            data_loader=data_loader,
+            dataset_original_directory=atnt_database_directory(),
+            extension=".pgm",
+        ),
+    )
+
+    scores = run_experiment(dataset)
     assert len(scores)==100
-    assert np.alltrue([len(s)==20] for s in scores)
\ No newline at end of file
+    assert np.alltrue([len(s) == 20 for s in scores])
+
+
+def test_atnt_experiment_cross_validation():
+
+    samples_per_identity = 10
+    total_identities = 40
+    samples_for_enrollment = 1
+
+    def run_cross_validation_experiment(test_size=0.9):
+        dataset = CSVDatasetCrossValidation(
+            csv_file_name=atnt_protocol_path_cross_validation,
+            random_state=0,
+            test_size=test_size,
+            csv_to_sample_loader=CSVToSampleLoader(
+                data_loader=data_loader,
+                dataset_original_directory=atnt_database_directory(),
+                extension=".pgm",
+            ),
+        )
+
+        scores = run_experiment(dataset)
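+        # One score entry per probe: each of the int(total_identities * test_size)
+        # test subjects contributes (samples_per_identity - samples_for_enrollment) probes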
+        assert len(scores) == int(
+            total_identities * test_size * (samples_per_identity - samples_for_enrollment)
+        )
+
+    run_cross_validation_experiment(test_size=0.9)
+    run_cross_validation_experiment(test_size=0.8)
+    run_cross_validation_experiment(test_size=0.5)
-- 
GitLab