csv_dataset.py 27 KB
Newer Older
1
2
3
4
5
6
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :


import os
from bob.pipelines import Sample, DelayedSample, SampleSet
7
from bob.db.base.utils import check_parameters_for_validity
8
9
10
11
import csv
import bob.io.base
import functools
from abc import ABCMeta, abstractmethod
12
13
import numpy as np
import itertools
14
import logging
15
import bob.db.base
16
from bob.extension.download import find_element_in_tarball
17
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database
18

19
logger = logging.getLogger(__name__)
20

21
22
23
24

#####
# ANNOTATIONS LOADERS
####
25
class AnnotationsLoader:
26
    """
27
28
    Metadata loader that loads annotations in the Idiap format using the function
    :any:`bob.db.base.read_annotation_file`
29
30
31
32
33
    """

    def __init__(
        self,
        annotation_directory=None,
34
        annotation_extension=".json",
35
        annotation_type="json",
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
    ):
        self.annotation_directory = annotation_directory
        self.annotation_extension = annotation_extension
        self.annotation_type = annotation_type

    def __call__(self, row, header=None):
        if self.annotation_directory is None:
            return None

        path = row[0]

        # since the file id is equal to the file name, we can simply use it
        annotation_file = os.path.join(
            self.annotation_directory, path + self.annotation_extension
        )

        # return the annotations as read from file
        annotation = {
            "annotations": bob.db.base.read_annotation_file(
                annotation_file, self.annotation_type
            )
        }
        return annotation


#######
# SAMPLE LOADERS
# CONVERT CSV LINES TO SAMPLES
#######


Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
67
class CSVBaseSampleLoader(metaclass=ABCMeta):
68
69
    """    
    Base class that converts the lines of a CSV file, like the one below to
70
71
72
73
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`

    .. code-block:: text

74
75
76
77
       PATH,REFERENCE_ID
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
78
79
80
81
82
83
84
85
86
87
88
89
       ...

    .. note::
       This class should be extended

    Parameters
    ----------

        data_loader:
            A python function that can be called parameterlessly, to load the
            sample in question from whatever medium

90
91
92
93
94
95
96
97
        metadata_loader:
            AnnotationsLoader

        dataset_original_directory: str
            Path of where data is stored
        
        extension: str
            Default file extension
98
99
100

    """

101
102
103
104
105
106
107
    def __init__(
        self,
        data_loader,
        metadata_loader=None,
        dataset_original_directory="",
        extension="",
    ):
108
109
        self.data_loader = data_loader
        self.extension = extension
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
110
        self.dataset_original_directory = dataset_original_directory
111
        self.metadata_loader = metadata_loader
112
113
114
115
116
117
118
119
120

    @abstractmethod
    def __call__(self, filename):
        pass

    @abstractmethod
    def convert_row_to_sample(self, row, header):
        pass

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
121
    def convert_samples_to_samplesets(
122
        self, samples, group_by_reference_id=True, references=None
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
123
    ):
124
125
126
127
128
129
        if group_by_reference_id:

            # Grouping sample sets
            sample_sets = dict()
            for s in samples:
                if s.reference_id not in sample_sets:
130
131
132
133
                    sample_sets[s.reference_id] = (
                        SampleSet([s], parent=s)
                        if references is None
                        else SampleSet([s], parent=s, references=references)
134
135
136
137
138
139
                    )
                else:
                    sample_sets[s.reference_id].append(s)
            return list(sample_sets.values())

        else:
140
141
142
143
144
            return (
                [SampleSet([s], parent=s) for s in samples]
                if references is None
                else [SampleSet([s], parent=s, references=references) for s in samples]
            )
145
146


Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
147
class CSVToSampleLoader(CSVBaseSampleLoader):
148
    """
149
    Simple mechanism that converts the lines of a CSV file to
150
151
152
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
    """

153
154
    def check_header(self, header):
        """
155
        A header should have at least "reference_id" AND "PATH"
156
157
        """
        header = [h.lower() for h in header]
158
159
160
161
        if not "reference_id" in header:
            raise ValueError(
                "The field `reference_id` is not available in your dataset."
            )
162

163
164
165
        if not "path" in header:
            raise ValueError("The field `path` is not available in your dataset.")

166
167
168
169
    def __call__(self, f):
        f.seek(0)
        reader = csv.reader(f)
        header = next(reader)
170

171
172
        self.check_header(header)
        return [self.convert_row_to_sample(row, header) for row in reader]
173
174
175

    def convert_row_to_sample(self, row, header):
        path = row[0]
176
177
        reference_id = row[1]

178
        kwargs = dict([[str(h).lower(), r] for h, r in zip(header[2:], row[2:])])
179
180

        if self.metadata_loader is not None:
181
            metadata = self.metadata_loader(row, header=header)
182
183
            kwargs.update(metadata)

184
        return DelayedSample(
185
186
187
188
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
189
            key=path,
190
            reference_id=reference_id,
191
192
193
194
            **kwargs,
        )


195
196
class LSTToSampleLoader(CSVBaseSampleLoader):
    """
197
    Simple mechanism that converts the lines of a LST file to
198
199
200
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
    """

201
202
203
204
205
206
207
208
    def __call__(self, f):
        f.seek(0)
        reader = csv.reader(f, delimiter=" ")
        samples = []
        for row in reader:
            if row[0][0] == "#":
                continue
            samples.append(self.convert_row_to_sample(row))
209

210
        return samples
211
212
213

    def convert_row_to_sample(self, row, header=None):

214
215
216
217
218
219
220
221
222
223
224
225
        if len(row) == 4:
            path = row[0]
            compare_reference_id = row[1]
            reference_id = str(row[3])
            kwargs = {"compare_reference_id": str(compare_reference_id)}
        else:
            path = row[0]
            reference_id = str(row[1])
            kwargs = dict()
            if len(row) == 3:
                subject = row[2]
                kwargs = {"subject": str(subject)}
226
227

        if self.metadata_loader is not None:
228
            metadata = self.metadata_loader(row, header=header)
229
230
231
232
233
234
235
236
237
238
239
            kwargs.update(metadata)

        return DelayedSample(
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
            key=path,
            reference_id=reference_id,
            **kwargs,
        )
240
241


242
243
244
#####
# DATABASE INTERFACES
#####
245

246

247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
def path_discovery(dataset_protocol_path, option1, option2):

    # If the input is a directory
    if os.path.isdir(dataset_protocol_path):
        option1 = os.path.join(dataset_protocol_path, option1)
        option2 = os.path.join(dataset_protocol_path, option2)
        if os.path.exists(option1):
            return open(option1)
        else:
            return open(option2) if os.path.exists(option2) else None

    # If it's not a directory is a tarball
    op1 = find_element_in_tarball(dataset_protocol_path, option1)
    return op1 if op1 else find_element_in_tarball(dataset_protocol_path, option2)


263
class CSVDatasetDevEval(Database):
264
    """
265
266
    Generic filelist dataset for :any:` bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline.
    Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
267
268
269
270
271
272
273
    interface.

    To create a new dataset, you need to provide a directory structure similar to the one below:

    .. code-block:: text

       my_dataset/
274
275
276
277
278
       my_dataset/my_protocol/norm/train_world.csv
       my_dataset/my_protocol/dev/for_models.csv
       my_dataset/my_protocol/dev/for_probes.csv
       my_dataset/my_protocol/eval/for_models.csv
       my_dataset/my_protocol/eval/for_probes.csv
279
280
281
282
283
284
285
       ...


    In the above directory structure, inside of `my_dataset` should contain the directories with all
    evaluation protocols this dataset might have.
    Inside of the `my_protocol` directory should contain at least two csv files:

286
287
     - for_models.csv
     - for_probes.csv
288
289


290
    Those csv files should contain in each row i-) the path to raw data and ii-) the reference_id label
291
292
    for enrollment (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.references`) and
    probing (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.probes`).
293
294
295
296
    The structure of each CSV file should be as below:

    .. code-block:: text

297
298
299
300
       PATH,reference_id
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
301
302
       ...

303

304
305
306
307
308
    You might want to ship metadata within your Samples (e.g gender, age, annotation, ...)
    To do so is simple, just do as below:

    .. code-block:: text

309
310
311
312
       PATH,reference_id,METADATA_1,METADATA_2,METADATA_k
       path_1,reference_id_1,A,B,C
       path_2,reference_id_2,A,B,1
       path_i,reference_id_j,2,3,4
313
314
315
316
317
       ...


    The files `my_dataset/my_protocol/train.csv/eval_enroll.csv` and `my_dataset/my_protocol/train.csv/eval_probe.csv`
    are optional and it is used in case a protocol contains data for evaluation.
318

319
    Finally, the content of the file `my_dataset/my_protocol/train.csv` is used in the case a protocol
320
    contains data for training (`bob.bio.base.pipelines.vanilla_biometrics.Database.background_model_samples`)
321
322
323
324
325

    Parameters
    ----------

        dataset_path: str
326
          Absolute path or a tarball of the dataset protocol description.
327

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
328
        protocol_na,e: str
329
330
          The name of the protocol

331
332
        csv_to_sample_loader: :any:`bob.bio.base.database.CSVBaseSampleLoader`
            Base class that whose objective is to generate :any:`bob.pipelines.Sample`
333
            and/or :any:`bob.pipelines.SampleSet` from csv rows
334
    
335
336
337
338
339

    """

    def __init__(
        self,
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
340
341
        dataset_protocol_path,
        protocol_name,
342
        csv_to_sample_loader=CSVToSampleLoader(
343
344
345
346
            data_loader=bob.io.base.load,
            metadata_loader=None,
            dataset_original_directory="",
            extension="",
347
        ),
348
        is_sparse=False,
349
    ):
350
        self.dataset_protocol_path = dataset_protocol_path
351
        self.is_sparse = is_sparse
352
        self.protocol_name = protocol_name
353

354
355
        def get_paths():

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
356
357
            if not os.path.exists(dataset_protocol_path):
                raise ValueError(f"The path `{dataset_protocol_path}` was not found")
358

359
360
            # Here we are handling the legacy
            train_csv = path_discovery(
361
                dataset_protocol_path,
362
363
                os.path.join(protocol_name, "norm", "train_world.lst"),
                os.path.join(protocol_name, "norm", "train_world.csv"),
364
365
366
            )

            dev_enroll_csv = path_discovery(
367
                dataset_protocol_path,
368
369
                os.path.join(protocol_name, "dev", "for_models.lst"),
                os.path.join(protocol_name, "dev", "for_models.csv"),
370
371
            )

372
            legacy_probe = "for_scores.lst" if self.is_sparse else "for_probes.lst"
373
            dev_probe_csv = path_discovery(
374
                dataset_protocol_path,
375
376
                os.path.join(protocol_name, "dev", legacy_probe),
                os.path.join(protocol_name, "dev", "for_probes.csv"),
377
378
379
            )

            eval_enroll_csv = path_discovery(
380
                dataset_protocol_path,
381
382
                os.path.join(protocol_name, "eval", "for_models.lst"),
                os.path.join(protocol_name, "eval", "for_models.csv"),
383
384
385
            )

            eval_probe_csv = path_discovery(
386
                dataset_protocol_path,
387
388
                os.path.join(protocol_name, "eval", legacy_probe),
                os.path.join(protocol_name, "eval", "for_probes.csv"),
389
            )
390
391
392
393

            # The minimum required is to have `dev_enroll_csv` and `dev_probe_csv`

            # Dev
394
            if dev_enroll_csv is None:
395
396
397
398
                raise ValueError(
                    f"The file `{dev_enroll_csv}` is required and it was not found"
                )

399
            if dev_probe_csv is None:
400
401
402
                raise ValueError(
                    f"The file `{dev_probe_csv}` is required and it was not found"
                )
403
404
            dev_enroll_csv = dev_enroll_csv
            dev_probe_csv = dev_probe_csv
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442

            return (
                train_csv,
                dev_enroll_csv,
                dev_probe_csv,
                eval_enroll_csv,
                eval_probe_csv,
            )

        (
            self.train_csv,
            self.dev_enroll_csv,
            self.dev_probe_csv,
            self.eval_enroll_csv,
            self.eval_probe_csv,
        ) = get_paths()

        def get_dict_cache():
            cache = dict()
            cache["train"] = None
            cache["dev_enroll_csv"] = None
            cache["dev_probe_csv"] = None
            cache["eval_enroll_csv"] = None
            cache["eval_probe_csv"] = None
            return cache

        self.cache = get_dict_cache()
        self.csv_to_sample_loader = csv_to_sample_loader

    def background_model_samples(self):
        self.cache["train"] = (
            self.csv_to_sample_loader(self.train_csv)
            if self.cache["train"] is None
            else self.cache["train"]
        )

        return self.cache["train"]

443
    def _get_samplesets(
444
445
446
447
448
449
        self,
        group="dev",
        cache_label=None,
        group_by_reference_id=False,
        fetching_probes=False,
        is_sparse=False,
450
    ):
451
452
453
454

        if self.cache[cache_label] is not None:
            return self.cache[cache_label]

455
        # Getting samples from CSV
456
        samples = self.csv_to_sample_loader(self.__getattribute__(cache_label))
457

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
458
        references = None
459
        if fetching_probes and is_sparse:
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
460

461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
            # Checking if `is_sparse` was set properly
            if len(samples) > 0 and not hasattr(samples[0], "compare_reference_id"):
                ValueError(
                    f"Attribute `compare_reference_id` not found in `{samples[0]}`."
                    "Make sure this attribute exists in your dataset if `is_sparse=True`"
                )

            sparse_samples = dict()
            for s in samples:
                if s.key in sparse_samples:
                    sparse_samples[s.key].references.append(s.compare_reference_id)
                else:
                    s.references = [s.compare_reference_id]
                    sparse_samples[s.key] = s
            samples = sparse_samples.values()
        else:
            if fetching_probes:
                references = list(
                    set([s.reference_id for s in self.references(group=group)])
                )
481
482

        sample_sets = self.csv_to_sample_loader.convert_samples_to_samplesets(
483
            samples, group_by_reference_id=group_by_reference_id, references=references,
484
485
486
487
488
489
490
        )

        self.cache[cache_label] = sample_sets

        return self.cache[cache_label]

    def references(self, group="dev"):
491
492
        cache_label = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"

493
        return self._get_samplesets(
494
            group=group, cache_label=cache_label, group_by_reference_id=True
495
496
497
        )

    def probes(self, group="dev"):
498
499
        cache_label = "dev_probe_csv" if group == "dev" else "eval_probe_csv"

500
        return self._get_samplesets(
501
502
503
504
505
            group=group,
            cache_label=cache_label,
            group_by_reference_id=False,
            fetching_probes=True,
            is_sparse=self.is_sparse,
506
        )
507

508
509
510
511
512
513
514
    def all_samples(self, groups=None):
        """
        Reads and returns all the samples in `groups`.

        Parameters
        ----------
        groups: list or None
515
516
517
518
519
520
521
            Groups to consider ('train', 'dev', and/or 'eval'). If `None` is
            given, returns the samples from all groups.

        Returns
        -------
        samples: list
            List of :class:`bob.pipelines.Sample` objects.
522
        """
523
524
525
526
527
528
529
530
531
532
533
534
535
536
        valid_groups = ["train"]
        if self.dev_enroll_csv and self.dev_probe_csv:
            valid_groups.append("dev")
        if self.eval_enroll_csv and self.eval_probe_csv:
            valid_groups.append("eval")
        groups = check_parameters_for_validity(
            parameters=groups,
            parameter_description="groups",
            valid_parameters=valid_groups,
            default_parameters=valid_groups,
        )

        samples = []

537
        # Get train samples (background_model_samples returns a list of samples)
538
539
540
        if "train" in groups:
            samples = samples + self.background_model_samples()
            groups.remove("train")
541
542
543
544
545

        # Get enroll and probe samples
        for group in groups:
            for purpose in ("enroll", "probe"):
                label = f"{group}_{purpose}_csv"
546
547
548
                samples = samples + self.csv_to_sample_loader(
                    self.__getattribute__(label)
                )
549
550
        return samples

551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
    def groups(self):
        """This function returns the list of groups for this database.

        Returns
        -------

        [str]
          A list of groups
        """

        # We always have dev-set
        groups = ["dev"]

        if self.train_csv is not None:
            groups.append("train")

        if self.eval_enroll_csv is not None:
            groups.append("eval")

        return groups

572

573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
class CSVDatasetDevEvalZTNorm(Database):
    """
    Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.ZTNormPipeline` pipelines.
    Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
    interface. 

    This dataset interface takes as in put a :any:`CSVDatasetDevEval` as input and have two extra methods:
    :any:`CSVDatasetDevEvalZTNorm.zprobes` and :any:`CSVDatasetDevEvalZTNorm.treferences`.

    To create a new dataset, you need to provide a directory structure similar to the one below:

    .. code-block:: text

       my_dataset/
       my_dataset/my_protocol/norm/train_world.csv
       my_dataset/my_protocol/norm/for_znorm.csv
       my_dataset/my_protocol/norm/for_tnorm.csv
       my_dataset/my_protocol/dev/for_models.csv
       my_dataset/my_protocol/dev/for_probes.csv
       my_dataset/my_protocol/eval/for_models.csv
       my_dataset/my_protocol/eval/for_probes.csv

    Parameters
    ----------
    
      database: :any:`CSVDatasetDevEval`
         :any:`CSVDatasetDevEval` to be aggregated

    """

    def __init__(self, database):
        self.database = database
        self.cache = self.database.cache
        self.csv_to_sample_loader = self.database.csv_to_sample_loader
        self.protocol_name = self.database.protocol_name
        self.dataset_protocol_path = self.database.dataset_protocol_path
        self._get_samplesets = self.database._get_samplesets

        ## create_cache
        self.cache["znorm_csv"] = None
        self.cache["tnorm_csv"] = None

        znorm_csv = path_discovery(
            self.dataset_protocol_path,
            os.path.join(self.protocol_name, "norm", "for_znorm.lst"),
            os.path.join(self.protocol_name, "norm", "for_znorm.csv"),
        )

        tnorm_csv = path_discovery(
            self.dataset_protocol_path,
            os.path.join(self.protocol_name, "norm", "for_tnorm.lst"),
            os.path.join(self.protocol_name, "norm", "for_tnorm.csv"),
        )

        if znorm_csv is None:
            raise ValueError(
                f"The file `for_znorm.lst` is required and it was not found in `{self.protocol_name}/norm` "
            )

        if tnorm_csv is None:
            raise ValueError(
                f"The file `for_tnorm.csv` is required and it was not found `{self.protocol_name}/norm`"
            )

        self.database.znorm_csv = znorm_csv
        self.database.tnorm_csv = tnorm_csv

    def background_model_samples(self):
        return self.database.background_model_samples()

    def references(self, group="dev"):
        return self.database.references(group=group)

    def probes(self, group="dev"):
        return self.database.probes(group=group)

    def all_samples(self, groups=None):
        return self.database.all_samples(groups=groups)

    def groups(self):
        return self.database.groups()

    def zprobes(self, group="dev", proportion=1.0):

        if proportion <= 0 or proportion > 1:
            raise ValueError(
                f"Invalid proportion value ({proportion}). Values allowed from [0-1]"
            )

        cache_label = "znorm_csv"
        samplesets = self._get_samplesets(
            group=group,
            cache_label=cache_label,
            group_by_reference_id=False,
            fetching_probes=True,
            is_sparse=False,
        )

        zprobes = samplesets[: int(len(samplesets) * proportion)]

        return zprobes

    def treferences(self, covariate="sex", proportion=1.0):

        if proportion <= 0 or proportion > 1:
            raise ValueError(
                f"Invalid proportion value ({proportion}). Values allowed from [0-1]"
            )

        cache_label = "tnorm_csv"
        samplesets = self._get_samplesets(
            group="dev", cache_label=cache_label, group_by_reference_id=True,
        )

        treferences = samplesets[: int(len(samplesets) * proportion)]

        return treferences


692
693
class CSVDatasetCrossValidation:
    """
694
    Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline that
695
696
    handles **CROSS VALIDATION**.

697
    Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
698
699
700
701
702
703
704
705
706
707
    interface.


    This interface will take one `csv_file` as input and split into i-) data for training and
    ii-) data for testing.
    The data for testing will be further split in data for enrollment and data for probing.
    The input CSV file should be casted in the following format:

    .. code-block:: text

708
709
710
711
       PATH,reference_id
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
712
713
714
715
716
717
718
719
720
721
722
723
       ...

    Parameters
    ----------

    csv_file_name: str
      CSV file containing all the samples from your database

    random_state: int
      Pseudo-random number generator seed

    test_size: float
724
      Percentage of the reference_ids used for testing
725
726
727
728

    samples_for_enrollment: float
      Number of samples used for enrollment

729
730
    csv_to_sample_loader: :any:`bob.bio.base.database.CSVBaseSampleLoader`
        Base class that whose objective is to generate :any:`bob.pipelines.Sample`
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
        and/or :any:`bob.pipelines.SampleSet` from csv rows

    """

    def __init__(
        self,
        csv_file_name="metadata.csv",
        random_state=0,
        test_size=0.8,
        samples_for_enrollment=1,
        csv_to_sample_loader=CSVToSampleLoader(
            data_loader=bob.io.base.load, dataset_original_directory="", extension=""
        ),
    ):
        def get_dict_cache():
            cache = dict()
            cache["train"] = None
            cache["dev_enroll_csv"] = None
            cache["dev_probe_csv"] = None
            return cache

        self.random_state = random_state
        self.cache = get_dict_cache()
        self.csv_to_sample_loader = csv_to_sample_loader
755
        self.csv_file_name = open(csv_file_name)
756
757
758
759
760
761
762
763
764
765
        self.samples_for_enrollment = samples_for_enrollment
        self.test_size = test_size

        if self.test_size < 0 and self.test_size > 1:
            raise ValueError(
                f"`test_size` should be between 0 and 1. {test_size} is provided"
            )

    def _do_cross_validation(self):

766
767
        # Shuffling samples by reference_id
        samples_by_reference_id = group_samples_by_reference_id(
768
769
            self.csv_to_sample_loader(self.csv_file_name)
        )
770
        reference_ids = list(samples_by_reference_id.keys())
771
        np.random.seed(self.random_state)
772
        np.random.shuffle(reference_ids)
773
774

        # Getting the training data
775
776
777
        n_samples_for_training = len(reference_ids) - int(
            self.test_size * len(reference_ids)
        )
778
779
        self.cache["train"] = list(
            itertools.chain(
780
781
782
783
                *[
                    samples_by_reference_id[s]
                    for s in reference_ids[0:n_samples_for_training]
                ]
784
785
786
787
788
789
            )
        )

        # Splitting enroll and probe
        self.cache["dev_enroll_csv"] = []
        self.cache["dev_probe_csv"] = []
790
791
        for s in reference_ids[n_samples_for_training:]:
            samples = samples_by_reference_id[s]
792
793
            if len(samples) < self.samples_for_enrollment:
                raise ValueError(
794
                    f"Not enough samples ({len(samples)}) for enrollment for the reference_id {s}"
795
796
797
798
799
800
801
802
803
804
805
806
807
                )

            # Enrollment samples
            self.cache["dev_enroll_csv"].append(
                self.csv_to_sample_loader.convert_samples_to_samplesets(
                    samples[0 : self.samples_for_enrollment]
                )[0]
            )

            self.cache[
                "dev_probe_csv"
            ] += self.csv_to_sample_loader.convert_samples_to_samplesets(
                samples[self.samples_for_enrollment :],
808
809
                group_by_reference_id=False,
                references=reference_ids[n_samples_for_training:],
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
            )

    def _load_from_cache(self, cache_key):
        if self.cache[cache_key] is None:
            self._do_cross_validation()
        return self.cache[cache_key]

    def background_model_samples(self):
        return self._load_from_cache("train")

    def references(self, group="dev"):
        return self._load_from_cache("dev_enroll_csv")

    def probes(self, group="dev"):
        return self._load_from_cache("dev_probe_csv")

826
827
828
829
830
831
832
    def all_samples(self, groups=None):
        """
        Reads and returns all the samples in `groups`.

        Parameters
        ----------
        groups: list or None
833
834
835
836
837
838
839
            Groups to consider ('train' and/or 'dev'). If `None` is given,
            returns the samples from all groups.

        Returns
        -------
        samples: list
            List of :class:`bob.pipelines.Sample` objects.
840
        """
841
842
843
844
845
846
847
848
849
850
        valid_groups = ["train", "dev"]
        groups = check_parameters_for_validity(
            parameters=groups,
            parameter_description="groups",
            valid_parameters=valid_groups,
            default_parameters=valid_groups,
        )

        samples = []

851
        # Get train samples (background_model_samples returns a list of samples)
852
853
854
        if "train" in groups:
            samples = samples + self.background_model_samples()
            groups.remove("train")
855
856
857

        # Get enroll and probe samples
        for group in groups:
858
859
            samples = samples + [s for s_set in self.references(group) for s in s_set]
            samples = samples + [s for s_set in self.probes(group) for s in s_set]
860
861
        return samples

862

863
def group_samples_by_reference_id(samples):
864
865

    # Grouping sample sets
866
    samples_by_reference_id = dict()
867
    for s in samples:
868
869
870
871
        if s.reference_id not in samples_by_reference_id:
            samples_by_reference_id[s.reference_id] = []
        samples_by_reference_id[s.reference_id].append(s)
    return samples_by_reference_id