From 4de7cc3b9f9abd342f01a9d3aa0147878e1a3c0c Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Wed, 7 Oct 2020 15:22:11 +0200 Subject: [PATCH] Implemented CrossValidation Filelist dataset --- bob/bio/base/database/__init__.py | 2 +- bob/bio/base/database/csv_dataset.py | 168 +++++++- .../data/atnt/cross_validation/metadata.csv | 401 ++++++++++++++++++ bob/bio/base/test/test_filelist.py | 69 ++- 4 files changed, 611 insertions(+), 29 deletions(-) create mode 100644 bob/bio/base/test/data/atnt/cross_validation/metadata.csv diff --git a/bob/bio/base/database/__init__.py b/bob/bio/base/database/__init__.py index 3d728e2b..1ff6325e 100644 --- a/bob/bio/base/database/__init__.py +++ b/bob/bio/base/database/__init__.py @@ -1,4 +1,4 @@ -from .csv_dataset import CSVDatasetDevEval, CSVToSampleLoader +from .csv_dataset import CSVDatasetDevEval, CSVToSampleLoader, CSVDatasetCrossValidation from .file import BioFile from .file import BioFileSet from .database import BioDatabase diff --git a/bob/bio/base/database/csv_dataset.py b/bob/bio/base/database/csv_dataset.py index 8336a4cf..bafb21a7 100644 --- a/bob/bio/base/database/csv_dataset.py +++ b/bob/bio/base/database/csv_dataset.py @@ -8,6 +8,8 @@ import csv import bob.io.base import functools from abc import ABCMeta, abstractmethod +import numpy as np +import itertools class CSVBaseSampleLoader(metaclass=ABCMeta): @@ -91,7 +93,10 @@ class CSVToSampleLoader(CSVBaseSampleLoader): subject = row[1] kwargs = dict([[h, r] for h, r in zip(header[2:], row[2:])]) return DelayedSample( - functools.partial(self.data_loader, os.path.join(self.dataset_original_directory, path+self.extension)), + functools.partial( + self.data_loader, + os.path.join(self.dataset_original_directory, path + self.extension), + ), key=path, subject=subject, **kwargs, @@ -118,11 +123,15 @@ class CSVToSampleLoader(CSVBaseSampleLoader): sample_sets[s.subject] = SampleSet( [s], **get_attribute_from_sample(s) ) - sample_sets[s.subject].append(s) + else: + sample_sets[s.subject].append(s) return list(sample_sets.values()) else: - return [SampleSet([s], **get_attribute_from_sample(s), references=references) for s in samples] + return [ + SampleSet([s], **get_attribute_from_sample(s), references=references) + for s in samples + ] class CSVDatasetDevEval: @@ -194,8 +203,9 @@ class CSVDatasetDevEval: protocol_na,e: str The name of the protocol - csv_to_sample_loader: - + csv_to_sample_loader: :any:`CSVBaseSampleLoader` + Base class that whose objective is to generate :any:`bob.pipelines.Samples` + and/or :any:`bob.pipelines.SampleSet` from csv rows """ @@ -281,9 +291,6 @@ class CSVDatasetDevEval: return self.cache["train"] - def _get_subjects_from_samplesets(self, sample_sets): - return list(set([s.subject for s in sample_sets])) - def _get_samplesets(self, group="dev", purpose="enroll", group_by_subject=False): fetching_probes = False @@ -298,9 +305,7 @@ class CSVDatasetDevEval: references = None if fetching_probes: - references = self._get_subjects_from_samplesets( - self.references(group=group) - ) + references = list(set([s.subject for s in self.references(group=group)])) samples = self.csv_to_sample_loader(self.__dict__[cache_label]) @@ -321,3 +326,144 @@ class CSVDatasetDevEval: return self._get_samplesets( group=group, purpose="probe", group_by_subject=False ) + + +class CSVDatasetCrossValidation: + """ + Generic filelist dataset for :any:`bob.bio.base.pipelines.VanillaBiometrics` pipeline that + handles **CROSS VALIDATION**. + + Check :ref:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset + interface. + + + This interface will take one `csv_file` as input and split into i-) data for training and + ii-) data for testing. + The data for testing will be further split in data for enrollment and data for probing. + The input CSV file should be casted in the following format: + + .. code-block:: text + + PATH,SUBJECT + path_1,subject_1 + path_2,subject_2 + path_i,subject_j + ... + + Parameters + ---------- + + csv_file_name: str + CSV file containing all the samples from your database + + random_state: int + Pseudo-random number generator seed + + test_size: float + Percentage of the subjects used for testing + + samples_for_enrollment: float + Number of samples used for enrollment + + csv_to_sample_loader: :any:`CSVBaseSampleLoader` + Base class that whose objective is to generate :any:`bob.pipelines.Samples` + and/or :any:`bob.pipelines.SampleSet` from csv rows + + """ + + def __init__( + self, + csv_file_name="metadata.csv", + random_state=0, + test_size=0.8, + samples_for_enrollment=1, + csv_to_sample_loader=CSVToSampleLoader( + data_loader=bob.io.base.load, dataset_original_directory="", extension="" + ), + ): + def get_dict_cache(): + cache = dict() + cache["train"] = None + cache["dev_enroll_csv"] = None + cache["dev_probe_csv"] = None + return cache + + self.random_state = random_state + self.cache = get_dict_cache() + self.csv_to_sample_loader = csv_to_sample_loader + self.csv_file_name = csv_file_name + self.samples_for_enrollment = samples_for_enrollment + self.test_size = test_size + + if self.test_size < 0 and self.test_size > 1: + raise ValueError( + f"`test_size` should be between 0 and 1. {test_size} is provided" + ) + + def _do_cross_validation(self): + + # Shuffling samples by subject + samples_by_subject = group_samples_by_subject( + self.csv_to_sample_loader(self.csv_file_name) + ) + subjects = list(samples_by_subject.keys()) + np.random.seed(self.random_state) + np.random.shuffle(subjects) + + # Getting the training data + n_samples_for_training = len(subjects) - int(self.test_size * len(subjects)) + self.cache["train"] = list( + itertools.chain( + *[samples_by_subject[s] for s in subjects[0:n_samples_for_training]] + ) + ) + + # Splitting enroll and probe + self.cache["dev_enroll_csv"] = [] + self.cache["dev_probe_csv"] = [] + for s in subjects[n_samples_for_training:]: + samples = samples_by_subject[s] + if len(samples) < self.samples_for_enrollment: + raise ValueError( + f"Not enough samples ({len(samples)}) for enrollment for the subject {s}" + ) + + # Enrollment samples + self.cache["dev_enroll_csv"].append( + self.csv_to_sample_loader.convert_samples_to_samplesets( + samples[0 : self.samples_for_enrollment] + )[0] + ) + + self.cache[ + "dev_probe_csv" + ] += self.csv_to_sample_loader.convert_samples_to_samplesets( + samples[self.samples_for_enrollment :], + group_by_subject=False, + references=subjects[n_samples_for_training:], + ) + + def _load_from_cache(self, cache_key): + if self.cache[cache_key] is None: + self._do_cross_validation() + return self.cache[cache_key] + + def background_model_samples(self): + return self._load_from_cache("train") + + def references(self, group="dev"): + return self._load_from_cache("dev_enroll_csv") + + def probes(self, group="dev"): + return self._load_from_cache("dev_probe_csv") + + +def group_samples_by_subject(samples): + + # Grouping sample sets + samples_by_subject = dict() + for s in samples: + if s.subject not in samples_by_subject: + samples_by_subject[s.subject] = [] + samples_by_subject[s.subject].append(s) + return samples_by_subject diff --git a/bob/bio/base/test/data/atnt/cross_validation/metadata.csv b/bob/bio/base/test/data/atnt/cross_validation/metadata.csv new file mode 100644 index 00000000..21bf0ae0 --- /dev/null +++ b/bob/bio/base/test/data/atnt/cross_validation/metadata.csv @@ -0,0 +1,401 @@ +PATH,SUBJECT +s1/9,1 +s1/2,1 +s1/4,1 +s1/5,1 +s1/7,1 +s1/8,1 +s1/1,1 +s1/10,1 +s1/3,1 +s1/6,1 +s2/9,2 +s2/2,2 +s2/4,2 +s2/5,2 +s2/7,2 +s2/8,2 +s2/1,2 +s2/10,2 +s2/3,2 +s2/6,2 +s5/9,5 +s5/2,5 +s5/4,5 +s5/5,5 +s5/7,5 +s5/8,5 +s5/1,5 +s5/10,5 +s5/3,5 +s5/6,5 +s6/9,6 +s6/2,6 +s6/4,6 +s6/5,6 +s6/7,6 +s6/8,6 +s6/1,6 +s6/10,6 +s6/3,6 +s6/6,6 +s10/9,10 +s10/2,10 +s10/4,10 +s10/5,10 +s10/7,10 +s10/8,10 +s10/1,10 +s10/10,10 +s10/3,10 +s10/6,10 +s11/9,11 +s11/2,11 +s11/4,11 +s11/5,11 +s11/7,11 +s11/8,11 +s11/1,11 +s11/10,11 +s11/3,11 +s11/6,11 +s12/9,12 +s12/2,12 +s12/4,12 +s12/5,12 +s12/7,12 +s12/8,12 +s12/1,12 +s12/10,12 +s12/3,12 +s12/6,12 +s14/9,14 +s14/2,14 +s14/4,14 +s14/5,14 +s14/7,14 +s14/8,14 +s14/1,14 +s14/10,14 +s14/3,14 +s14/6,14 +s16/9,16 +s16/2,16 +s16/4,16 +s16/5,16 +s16/7,16 +s16/8,16 +s16/1,16 +s16/10,16 +s16/3,16 +s16/6,16 +s17/9,17 +s17/2,17 +s17/4,17 +s17/5,17 +s17/7,17 +s17/8,17 +s17/1,17 +s17/10,17 +s17/3,17 +s17/6,17 +s20/9,20 +s20/2,20 +s20/4,20 +s20/5,20 +s20/7,20 +s20/8,20 +s20/1,20 +s20/10,20 +s20/3,20 +s20/6,20 +s21/9,21 +s21/2,21 +s21/4,21 +s21/5,21 +s21/7,21 +s21/8,21 +s21/1,21 +s21/10,21 +s21/3,21 +s21/6,21 +s24/9,24 +s24/2,24 +s24/4,24 +s24/5,24 +s24/7,24 +s24/8,24 +s24/1,24 +s24/10,24 +s24/3,24 +s24/6,24 +s26/9,26 +s26/2,26 +s26/4,26 +s26/5,26 +s26/7,26 +s26/8,26 +s26/1,26 +s26/10,26 +s26/3,26 +s26/6,26 +s27/9,27 +s27/2,27 +s27/4,27 +s27/5,27 +s27/7,27 +s27/8,27 +s27/1,27 +s27/10,27 +s27/3,27 +s27/6,27 +s29/9,29 +s29/2,29 +s29/4,29 +s29/5,29 +s29/7,29 +s29/8,29 +s29/1,29 +s29/10,29 +s29/3,29 +s29/6,29 +s33/9,33 +s33/2,33 +s33/4,33 +s33/5,33 +s33/7,33 +s33/8,33 +s33/1,33 +s33/10,33 +s33/3,33 +s33/6,33 +s34/9,34 +s34/2,34 +s34/4,34 +s34/5,34 +s34/7,34 +s34/8,34 +s34/1,34 +s34/10,34 +s34/3,34 +s34/6,34 +s36/9,36 +s36/2,36 +s36/4,36 +s36/5,36 +s36/7,36 +s36/8,36 +s36/1,36 +s36/10,36 +s36/3,36 +s36/6,36 +s39/9,39 +s39/2,39 +s39/4,39 +s39/5,39 +s39/7,39 +s39/8,39 +s39/1,39 +s39/10,39 +s39/3,39 +s39/6,39 +s3/9,3 +s3/2,3 +s3/4,3 +s3/5,3 +s3/7,3 +s4/9,4 +s4/2,4 +s4/4,4 +s4/5,4 +s4/7,4 +s7/9,7 +s7/2,7 +s7/4,7 +s7/5,7 +s7/7,7 +s8/9,8 +s8/2,8 +s8/4,8 +s8/5,8 +s8/7,8 +s9/9,9 +s9/2,9 +s9/4,9 +s9/5,9 +s9/7,9 +s13/9,13 +s13/2,13 +s13/4,13 +s13/5,13 +s13/7,13 +s15/9,15 +s15/2,15 +s15/4,15 +s15/5,15 +s15/7,15 +s18/9,18 +s18/2,18 +s18/4,18 +s18/5,18 +s18/7,18 +s19/9,19 +s19/2,19 +s19/4,19 +s19/5,19 +s19/7,19 +s22/9,22 +s22/2,22 +s22/4,22 +s22/5,22 +s22/7,22 +s23/9,23 +s23/2,23 +s23/4,23 +s23/5,23 +s23/7,23 +s25/9,25 +s25/2,25 +s25/4,25 +s25/5,25 +s25/7,25 +s28/9,28 +s28/2,28 +s28/4,28 +s28/5,28 +s28/7,28 +s30/9,30 +s30/2,30 +s30/4,30 +s30/5,30 +s30/7,30 +s31/9,31 +s31/2,31 +s31/4,31 +s31/5,31 +s31/7,31 +s32/9,32 +s32/2,32 +s32/4,32 +s32/5,32 +s32/7,32 +s35/9,35 +s35/2,35 +s35/4,35 +s35/5,35 +s35/7,35 +s37/9,37 +s37/2,37 +s37/4,37 +s37/5,37 +s37/7,37 +s38/9,38 +s38/2,38 +s38/4,38 +s38/5,38 +s38/7,38 +s40/9,40 +s40/2,40 +s40/4,40 +s40/5,40 +s40/7,40 +s3/8,3 +s3/1,3 +s3/10,3 +s3/3,3 +s3/6,3 +s4/8,4 +s4/1,4 +s4/10,4 +s4/3,4 +s4/6,4 +s7/8,7 +s7/1,7 +s7/10,7 +s7/3,7 +s7/6,7 +s8/8,8 +s8/1,8 +s8/10,8 +s8/3,8 +s8/6,8 +s9/8,9 +s9/1,9 +s9/10,9 +s9/3,9 +s9/6,9 +s13/8,13 +s13/1,13 +s13/10,13 +s13/3,13 +s13/6,13 +s15/8,15 +s15/1,15 +s15/10,15 +s15/3,15 +s15/6,15 +s18/8,18 +s18/1,18 +s18/10,18 +s18/3,18 +s18/6,18 +s19/8,19 +s19/1,19 +s19/10,19 +s19/3,19 +s19/6,19 +s22/8,22 +s22/1,22 +s22/10,22 +s22/3,22 +s22/6,22 +s23/8,23 +s23/1,23 +s23/10,23 +s23/3,23 +s23/6,23 +s25/8,25 +s25/1,25 +s25/10,25 +s25/3,25 +s25/6,25 +s28/8,28 +s28/1,28 +s28/10,28 +s28/3,28 +s28/6,28 +s30/8,30 +s30/1,30 +s30/10,30 +s30/3,30 +s30/6,30 +s31/8,31 +s31/1,31 +s31/10,31 +s31/3,31 +s31/6,31 +s32/8,32 +s32/1,32 +s32/10,32 +s32/3,32 +s32/6,32 +s35/8,35 +s35/1,35 +s35/10,35 +s35/3,35 +s35/6,35 +s37/8,37 +s37/1,37 +s37/10,37 +s37/3,37 +s37/6,37 +s38/8,38 +s38/1,38 +s38/10,38 +s38/3,38 +s38/6,38 +s40/8,40 +s40/1,40 +s40/10,40 +s40/3,40 +s40/6,40 diff --git a/bob/bio/base/test/test_filelist.py b/bob/bio/base/test/test_filelist.py index 9e154c48..939dded8 100644 --- a/bob/bio/base/test/test_filelist.py +++ b/bob/bio/base/test/test_filelist.py @@ -7,7 +7,7 @@ import os import bob.io.base import bob.io.base.test_utils -from bob.bio.base.database import CSVDatasetDevEval, CSVToSampleLoader +from bob.bio.base.database import CSVDatasetDevEval, CSVToSampleLoader, CSVDatasetCrossValidation import nose.tools from bob.pipelines import DelayedSample, SampleSet import numpy as np @@ -28,6 +28,10 @@ atnt_protocol_path = os.path.realpath( bob.io.base.test_utils.datafile(".", __name__, "data/atnt") ) +atnt_protocol_path_cross_validation = os.path.join(os.path.realpath( + bob.io.base.test_utils.datafile(".", __name__, "data/atnt/cross_validation/") +),"metadata.csv") + def check_all_true(list_of_something, something): """ @@ -100,36 +104,67 @@ def test_csv_file_list_atnt(): assert len(dataset.probes()) == 100 -def test_atnt_experiment(): - def load(path): - import bob.io.image - return bob.io.base.load(path) +def run_experiment(dataset): def linearize(X): X = np.asarray(X) return np.reshape(X, (X.shape[0], -1)) - dataset = CSVDatasetDevEval( - dataset_protocol_path=atnt_protocol_path, - protocol_name="idiap_protocol", - csv_to_sample_loader=CSVToSampleLoader( - data_loader=load, - dataset_original_directory=atnt_database_directory(), - extension=".pgm", - ), - ) - #### Testing it in a real recognition systems transformer = wrap(["sample"], make_pipeline(FunctionTransformer(linearize))) vanilla_biometrics_pipeline = VanillaBiometricsPipeline(transformer, Distance()) - scores = vanilla_biometrics_pipeline( + return vanilla_biometrics_pipeline( dataset.background_model_samples(), dataset.references(), dataset.probes(), ) + +def data_loader(path): + import bob.io.image + return bob.io.base.load(path) + +def test_atnt_experiment(): + + dataset = CSVDatasetDevEval( + dataset_protocol_path=atnt_protocol_path, + protocol_name="idiap_protocol", + csv_to_sample_loader=CSVToSampleLoader( + data_loader=data_loader, + dataset_original_directory=atnt_database_directory(), + extension=".pgm", + ), + ) + + scores = run_experiment(dataset) assert len(scores)==100 - assert np.alltrue([len(s)==20] for s in scores) \ No newline at end of file + assert np.alltrue([len(s)==20] for s in scores) + + +def test_atnt_experiment_cross_validation(): + + samples_per_identity = 10 + total_identities = 40 + samples_for_enrollment = 1 + + def run_cross_validataion_experiment(test_size = 0.9): + dataset = CSVDatasetCrossValidation( + csv_file_name=atnt_protocol_path_cross_validation, + random_state=0, + test_size=test_size, + csv_to_sample_loader=CSVToSampleLoader( + data_loader=data_loader, + dataset_original_directory=atnt_database_directory(), + extension=".pgm", + ), + ) + + scores = run_experiment(dataset) + assert len(scores)==int(total_identities*test_size*(samples_per_identity-samples_for_enrollment)) + + run_cross_validataion_experiment(test_size = 0.9) + run_cross_validataion_experiment(test_size = 0.8) + run_cross_validataion_experiment(test_size = 0.5) -- GitLab