#!/usr/bin/env python
# vim: set fileencoding=utf-8 :


import os
from bob.pipelines import Sample, DelayedSample, SampleSet
from bob.db.base.utils import check_parameters_for_validity
import csv
import bob.io.base
import functools
from abc import ABCMeta, abstractmethod
import numpy as np
import itertools
import logging
import bob.db.base
from bob.extension.download import find_element_in_tarball
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database

logger = logging.getLogger(__name__)


#####
# ANNOTATIONS LOADERS
#####
class AnnotationsLoader:
    """
    Metadata loader that loads annotations in the Idiap format using the function
    :any:`bob.db.base.read_annotation_file`
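
    A minimal usage sketch (the directories and the ``.png`` extension below are
    placeholders, not defaults of this module):

    .. code-block:: python

       import bob.io.base

       annotations_loader = AnnotationsLoader(
           annotation_directory="/path/to/annotations",
           annotation_extension=".json",
           annotation_type="json",
       )

       # typically plugged into a CSVToSampleLoader, so that every sample
       # carries its annotations as metadata
       sample_loader = CSVToSampleLoader(
           data_loader=bob.io.base.load,
           metadata_loader=annotations_loader,
           dataset_original_directory="/path/to/raw/data",
           extension=".png",
       )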
    """

    def __init__(
        self,
        annotation_directory=None,
        annotation_extension=".json",
        annotation_type="json",
    ):
        self.annotation_directory = annotation_directory
        self.annotation_extension = annotation_extension
        self.annotation_type = annotation_type

    def __call__(self, row, header=None):
        if self.annotation_directory is None:
            return None

        path = row[0]

        # since the file id is equal to the file name, we can simply use it
        annotation_file = os.path.join(
            self.annotation_directory, path + self.annotation_extension
        )

        # return the annotations as read from file
        annotation = {
            "annotations": bob.db.base.read_annotation_file(
                annotation_file, self.annotation_type
            )
        }
        return annotation


#######
# SAMPLE LOADERS
# CONVERT CSV LINES TO SAMPLES
#######


class CSVBaseSampleLoader(metaclass=ABCMeta):
    """    
    Base class that converts the lines of a CSV file, like the one below to
70 71 72 73
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`

    .. code-block:: text

74 75 76 77
       PATH,REFERENCE_ID
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
78 79 80 81 82 83 84 85 86 87 88 89
       ...

    .. note::
       This class should be extended

    Parameters
    ----------

        data_loader:
            A python function that can be called parameterlessly, to load the
            sample in question from whatever medium

90 91 92 93 94 95 96 97
        metadata_loader:
            AnnotationsLoader

        dataset_original_directory: str
            Path of where data is stored
        
        extension: str
            Default file extension
98 99 100

    """

    def __init__(
        self,
        data_loader,
        metadata_loader=None,
        dataset_original_directory="",
        extension="",
    ):
        self.data_loader = data_loader
        self.extension = extension
        self.dataset_original_directory = dataset_original_directory
        self.metadata_loader = metadata_loader

    @abstractmethod
    def __call__(self, filename):
        pass

    @abstractmethod
    def convert_row_to_sample(self, row, header):
        pass

    def convert_samples_to_samplesets(
        self, samples, group_by_reference_id=True, references=None
    ):
        if group_by_reference_id:

            # Grouping sample sets
            sample_sets = dict()
            for s in samples:
                if s.reference_id not in sample_sets:
                    sample_sets[s.reference_id] = (
                        SampleSet([s], parent=s)
                        if references is None
                        else SampleSet([s], parent=s, references=references)
                    )
                else:
                    sample_sets[s.reference_id].append(s)
            return list(sample_sets.values())

        else:
            return (
                [SampleSet([s], parent=s) for s in samples]
                if references is None
                else [SampleSet([s], parent=s, references=references) for s in samples]
            )


class CSVToSampleLoader(CSVBaseSampleLoader):
    """
    Simple mechanism that converts the lines of a CSV file to
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
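
    A minimal usage sketch (the CSV path, raw-data directory and ``.png`` extension
    below are placeholders):

    .. code-block:: python

       import bob.io.base

       loader = CSVToSampleLoader(
           data_loader=bob.io.base.load,
           dataset_original_directory="/path/to/raw/data",
           extension=".png",
       )
       with open("/path/to/my_dataset/my_protocol/dev/for_probes.csv") as f:
           samples = loader(f)  # one DelayedSample per CSV row
       sample_sets = loader.convert_samples_to_samplesets(samples)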
    """

    def check_header(self, header):
        """
        A header should contain at least the `reference_id` and `path` columns
        """
        header = [h.lower() for h in header]
        if "reference_id" not in header:
            raise ValueError(
                "The field `reference_id` is not available in your dataset."
            )

        if "path" not in header:
            raise ValueError("The field `path` is not available in your dataset.")

    def __call__(self, f):
        f.seek(0)
        reader = csv.reader(f)
        header = next(reader)

        self.check_header(header)
        return [self.convert_row_to_sample(row, header) for row in reader]

    def convert_row_to_sample(self, row, header):
        path = row[0]
        reference_id = row[1]

        # any extra CSV column becomes a (lower-cased) sample attribute
        kwargs = {str(h).lower(): r for h, r in zip(header[2:], row[2:])}

        if self.metadata_loader is not None:
            metadata = self.metadata_loader(row, header=header)
            kwargs.update(metadata)

        return DelayedSample(
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
            key=path,
            reference_id=reference_id,
            **kwargs,
        )


class LSTToSampleLoader(CSVBaseSampleLoader):
    """
    Simple mechanism that converts the lines of an LST file to
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
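
    Lines are whitespace-separated and lines starting with ``#`` are skipped. A sketch
    of how rows are interpreted, inferred from
    :any:`LSTToSampleLoader.convert_row_to_sample` (the third column of the
    four-column form is not used by this loader):

    .. code-block:: text

       # two columns: path reference_id
       path_1 reference_id_1
       # three columns: path reference_id subject
       path_2 reference_id_2 subject_2
       # four columns: path compare_reference_id <unused> reference_id
       path_3 compare_reference_id_3 x reference_id_3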
    """

    def __call__(self, f):
        f.seek(0)
        reader = csv.reader(f, delimiter=" ")
        samples = []
        for row in reader:
            if row[0][0] == "#":
                continue
            samples.append(self.convert_row_to_sample(row))

        return samples

    def convert_row_to_sample(self, row, header=None):

        if len(row) == 4:
            path = row[0]
            compare_reference_id = row[1]
            reference_id = str(row[3])
            kwargs = {"compare_reference_id": str(compare_reference_id)}
        else:
            path = row[0]
            reference_id = str(row[1])
            kwargs = dict()
            if len(row) == 3:
                subject = row[2]
                kwargs = {"subject": str(subject)}

        if self.metadata_loader is not None:
            metadata = self.metadata_loader(row, header=header)
            kwargs.update(metadata)

        return DelayedSample(
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
            key=path,
            reference_id=reference_id,
            **kwargs,
        )


#####
# DATABASE INTERFACES
#####


def path_discovery(dataset_protocol_path, option1, option2):
    """
    Returns an open file-like object for `option1` if it exists, otherwise for
    `option2`, or `None` if neither exists. Works both when
    `dataset_protocol_path` is a directory and when it is a tarball.
    """

    # If the input is a directory
    if os.path.isdir(dataset_protocol_path):
        option1 = os.path.join(dataset_protocol_path, option1)
        option2 = os.path.join(dataset_protocol_path, option2)
        if os.path.exists(option1):
            return open(option1)
        else:
            return open(option2) if os.path.exists(option2) else None

    # If it's not a directory, it's a tarball
    op1 = find_element_in_tarball(dataset_protocol_path, option1)
    return op1 if op1 else find_element_in_tarball(dataset_protocol_path, option2)


class CSVDataset(Database):
    """
    Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline.
    Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
    interface.

    To create a new dataset, you need to provide a directory structure similar to the one below:

    .. code-block:: text

       my_dataset/
       my_dataset/my_protocol/norm/train_world.csv
       my_dataset/my_protocol/dev/for_models.csv
       my_dataset/my_protocol/dev/for_probes.csv
       my_dataset/my_protocol/eval/for_models.csv
       my_dataset/my_protocol/eval/for_probes.csv
       ...


    In the above structure, `my_dataset` contains one directory per evaluation protocol
    this dataset might have.
    The `my_protocol` directory should contain at least two csv files (inside `dev`):

     - for_models.csv
     - for_probes.csv


    Each row of those csv files contains i) the path to the raw data and ii) the reference_id label
    used for enrollment (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.references`) and
    probing (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.probes`).
    The structure of each CSV file should be as below:

    .. code-block:: text

       PATH,reference_id
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
       ...


    You might want to ship metadata within your Samples (e.g. gender, age, annotations, ...).
    To do so, simply add the extra columns as below:

    .. code-block:: text

       PATH,reference_id,METADATA_1,METADATA_2,METADATA_k
       path_1,reference_id_1,A,B,C
       path_2,reference_id_2,A,B,1
       path_i,reference_id_j,2,3,4
       ...


    The files `my_dataset/my_protocol/eval/for_models.csv` and `my_dataset/my_protocol/eval/for_probes.csv`
    are optional and are used in case a protocol contains data for evaluation.

    Finally, the content of the file `my_dataset/my_protocol/norm/train_world.csv` is used in case a protocol
    contains data for training (`bob.bio.base.pipelines.vanilla_biometrics.Database.background_model_samples`).

    Parameters
    ----------

        dataset_protocol_path: str
          Absolute path or a tarball of the dataset protocol description.

        protocol_name: str
          The name of the protocol

        csv_to_sample_loader: :any:`bob.bio.base.database.CSVBaseSampleLoader`
            Base class whose objective is to generate :any:`bob.pipelines.Sample`
            and/or :any:`bob.pipelines.SampleSet` from csv rows

        is_sparse: bool
            If `True`, the probe lists are sparse scoring lists (`for_scores.lst` style),
            in which each probe line also carries the `compare_reference_id` it should be
            scored against

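    A minimal usage sketch (the dataset path, protocol name, raw-data directory and
    ``.png`` extension below are placeholders):

    .. code-block:: python

       import bob.io.base

       database = CSVDataset(
           dataset_protocol_path="/path/to/my_dataset",
           protocol_name="my_protocol",
           csv_to_sample_loader=CSVToSampleLoader(
               data_loader=bob.io.base.load,
               dataset_original_directory="/path/to/raw/data",
               extension=".png",
           ),
       )

       train = database.background_model_samples()
       references = database.references(group="dev")
       probes = database.probes(group="dev")
       # extra CSV columns (metadata_1, metadata_2, ...) become sample attributes
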
    """

    def __init__(
        self,
        dataset_protocol_path,
        protocol_name,
        csv_to_sample_loader=CSVToSampleLoader(
            data_loader=bob.io.base.load,
            metadata_loader=None,
            dataset_original_directory="",
            extension="",
        ),
        is_sparse=False,
    ):
        self.dataset_protocol_path = dataset_protocol_path
        self.is_sparse = is_sparse
        self.protocol_name = protocol_name

        def get_paths():

            if not os.path.exists(dataset_protocol_path):
                raise ValueError(f"The path `{dataset_protocol_path}` was not found")

            # Handle both the legacy (.lst) and the new (.csv) file lists
            train_csv = path_discovery(
                dataset_protocol_path,
                os.path.join(protocol_name, "norm", "train_world.lst"),
                os.path.join(protocol_name, "norm", "train_world.csv"),
            )

            dev_enroll_csv = path_discovery(
                dataset_protocol_path,
                os.path.join(protocol_name, "dev", "for_models.lst"),
                os.path.join(protocol_name, "dev", "for_models.csv"),
            )

            legacy_probe = "for_scores.lst" if self.is_sparse else "for_probes.lst"
            dev_probe_csv = path_discovery(
                dataset_protocol_path,
                os.path.join(protocol_name, "dev", legacy_probe),
                os.path.join(protocol_name, "dev", "for_probes.csv"),
            )

            eval_enroll_csv = path_discovery(
                dataset_protocol_path,
                os.path.join(protocol_name, "eval", "for_models.lst"),
                os.path.join(protocol_name, "eval", "for_models.csv"),
            )

            eval_probe_csv = path_discovery(
                dataset_protocol_path,
                os.path.join(protocol_name, "eval", legacy_probe),
                os.path.join(protocol_name, "eval", "for_probes.csv"),
            )

            # The minimum required is to have `dev_enroll_csv` and `dev_probe_csv`

            # Dev
            if dev_enroll_csv is None:
                raise ValueError(
                    f"A `{protocol_name}/dev/for_models.csv` (or `.lst`) file is "
                    "required and it was not found"
                )

            if dev_probe_csv is None:
                raise ValueError(
                    f"A `{protocol_name}/dev/for_probes.csv` (or `.lst`) file is "
                    "required and it was not found"
                )

            return (
                train_csv,
                dev_enroll_csv,
                dev_probe_csv,
                eval_enroll_csv,
                eval_probe_csv,
            )

        (
            self.train_csv,
            self.dev_enroll_csv,
            self.dev_probe_csv,
            self.eval_enroll_csv,
            self.eval_probe_csv,
        ) = get_paths()

        def get_dict_cache():
            cache = dict()
            cache["train"] = None
            cache["dev_enroll_csv"] = None
            cache["dev_probe_csv"] = None
            cache["eval_enroll_csv"] = None
            cache["eval_probe_csv"] = None
            return cache

        self.cache = get_dict_cache()
        self.csv_to_sample_loader = csv_to_sample_loader

    def background_model_samples(self):
        self.cache["train"] = (
            self.csv_to_sample_loader(self.train_csv)
            if self.cache["train"] is None
            else self.cache["train"]
        )

        return self.cache["train"]

    def _get_samplesets(
        self,
        group="dev",
        cache_label=None,
        group_by_reference_id=False,
        fetching_probes=False,
        is_sparse=False,
    ):

        if self.cache[cache_label] is not None:
            return self.cache[cache_label]

        # Getting samples from CSV
        samples = self.csv_to_sample_loader(self.__getattribute__(cache_label))

        references = None
        if fetching_probes and is_sparse:

            # Checking if `is_sparse` was set properly
            if len(samples) > 0 and not hasattr(samples[0], "compare_reference_id"):
                raise ValueError(
                    f"Attribute `compare_reference_id` not found in `{samples[0]}`. "
                    "Make sure this attribute exists in your dataset if `is_sparse=True`"
                )

            sparse_samples = dict()
            for s in samples:
                if s.key in sparse_samples:
                    sparse_samples[s.key].references.append(s.compare_reference_id)
                else:
                    s.references = [s.compare_reference_id]
                    sparse_samples[s.key] = s
            samples = sparse_samples.values()
        else:
            if fetching_probes:
                references = list(
                    set([s.reference_id for s in self.references(group=group)])
                )

        sample_sets = self.csv_to_sample_loader.convert_samples_to_samplesets(
            samples, group_by_reference_id=group_by_reference_id, references=references,
        )

        self.cache[cache_label] = sample_sets

        return self.cache[cache_label]

    def references(self, group="dev"):
        cache_label = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"

        return self._get_samplesets(
            group=group, cache_label=cache_label, group_by_reference_id=True
        )

    def probes(self, group="dev"):
        cache_label = "dev_probe_csv" if group == "dev" else "eval_probe_csv"

        return self._get_samplesets(
            group=group,
            cache_label=cache_label,
            group_by_reference_id=False,
            fetching_probes=True,
            is_sparse=self.is_sparse,
        )

    def all_samples(self, groups=None):
        """
        Reads and returns all the samples in `groups`.

        Parameters
        ----------
        groups: list or None
            Groups to consider ('train', 'dev', and/or 'eval'). If `None` is
            given, returns the samples from all groups.

        Returns
        -------
        samples: list
            List of :class:`bob.pipelines.Sample` objects.
        """
        valid_groups = ["train"]
        if self.dev_enroll_csv and self.dev_probe_csv:
            valid_groups.append("dev")
        if self.eval_enroll_csv and self.eval_probe_csv:
            valid_groups.append("eval")
        groups = check_parameters_for_validity(
            parameters=groups,
            parameter_description="groups",
            valid_parameters=valid_groups,
            default_parameters=valid_groups,
        )

        samples = []

        # Get train samples (background_model_samples returns a list of samples)
        if "train" in groups:
            samples = samples + self.background_model_samples()
            groups.remove("train")

        # Get enroll and probe samples
        for group in groups:
            for purpose in ("enroll", "probe"):
                label = f"{group}_{purpose}_csv"
                samples = samples + self.csv_to_sample_loader(
                    self.__getattribute__(label)
                )
        return samples

    def groups(self):
        """This function returns the list of groups for this database.

        Returns
        -------

        [str]
          A list of groups
        """

        # We always have dev-set
        groups = ["dev"]

        if self.train_csv is not None:
            groups.append("train")

        if self.eval_enroll_csv is not None:
            groups.append("eval")

        return groups


class CSVDatasetZTNorm(Database):
    """
    Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.ZTNormPipeline` pipelines.
    Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
    interface.

    This dataset interface takes a :any:`CSVDataset` as input and adds two extra methods:
    :any:`CSVDatasetZTNorm.zprobes` and :any:`CSVDatasetZTNorm.treferences`.

    To create a new dataset, you need to provide a directory structure similar to the one below:

    .. code-block:: text

       my_dataset/
       my_dataset/my_protocol/norm/train_world.csv
       my_dataset/my_protocol/norm/for_znorm.csv
       my_dataset/my_protocol/norm/for_tnorm.csv
       my_dataset/my_protocol/dev/for_models.csv
       my_dataset/my_protocol/dev/for_probes.csv
       my_dataset/my_protocol/eval/for_models.csv
       my_dataset/my_protocol/eval/for_probes.csv

    Parameters
    ----------

      database: :any:`CSVDataset`
         :any:`CSVDataset` to be aggregated

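    A minimal usage sketch (``database`` is an already constructed :any:`CSVDataset`
    whose protocol also provides the ``norm/for_znorm`` and ``norm/for_tnorm`` lists):

    .. code-block:: python

       ztnorm_database = CSVDatasetZTNorm(database)

       zprobes = ztnorm_database.zprobes(group="dev", proportion=1.0)
       treferences = ztnorm_database.treferences(proportion=1.0)
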
    """

    def __init__(self, database):
        self.database = database
        self.cache = self.database.cache
        self.csv_to_sample_loader = self.database.csv_to_sample_loader
        self.protocol_name = self.database.protocol_name
        self.dataset_protocol_path = self.database.dataset_protocol_path
        self._get_samplesets = self.database._get_samplesets

        ## create_cache
        self.cache["znorm_csv"] = None
        self.cache["tnorm_csv"] = None

        znorm_csv = path_discovery(
            self.dataset_protocol_path,
            os.path.join(self.protocol_name, "norm", "for_znorm.lst"),
            os.path.join(self.protocol_name, "norm", "for_znorm.csv"),
        )

        tnorm_csv = path_discovery(
            self.dataset_protocol_path,
            os.path.join(self.protocol_name, "norm", "for_tnorm.lst"),
            os.path.join(self.protocol_name, "norm", "for_tnorm.csv"),
        )

        if znorm_csv is None:
            raise ValueError(
                f"The file `for_znorm.csv` (or `for_znorm.lst`) is required and it was not found in `{self.protocol_name}/norm`"
            )

        if tnorm_csv is None:
            raise ValueError(
                f"The file `for_tnorm.csv` (or `for_tnorm.lst`) is required and it was not found in `{self.protocol_name}/norm`"
            )

        self.database.znorm_csv = znorm_csv
        self.database.tnorm_csv = tnorm_csv

    def background_model_samples(self):
        return self.database.background_model_samples()

    def references(self, group="dev"):
        return self.database.references(group=group)

    def probes(self, group="dev"):
        return self.database.probes(group=group)

    def all_samples(self, groups=None):
        return self.database.all_samples(groups=groups)

    def groups(self):
        return self.database.groups()

    def zprobes(self, group="dev", proportion=1.0):

        if proportion <= 0 or proportion > 1:
            raise ValueError(
                f"Invalid proportion value ({proportion}). Values allowed are in the range (0, 1]"
            )

        cache_label = "znorm_csv"
        samplesets = self._get_samplesets(
            group=group,
            cache_label=cache_label,
            group_by_reference_id=False,
            fetching_probes=True,
            is_sparse=False,
        )

        zprobes = samplesets[: int(len(samplesets) * proportion)]

        return zprobes

    def treferences(self, covariate="sex", proportion=1.0):

        if proportion <= 0 or proportion > 1:
            raise ValueError(
                f"Invalid proportion value ({proportion}). Values allowed are in the range (0, 1]"
            )

        cache_label = "tnorm_csv"
        samplesets = self._get_samplesets(
            group="dev", cache_label=cache_label, group_by_reference_id=True,
        )

        treferences = samplesets[: int(len(samplesets) * proportion)]

        return treferences


class CSVDatasetCrossValidation:
    """
    Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline that
    handles **CROSS VALIDATION**.

    Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
    interface.


    This interface takes one `csv_file` as input and splits it into i) data for training and
    ii) data for testing.
    The data for testing is further split into data for enrollment and data for probing.
    The input CSV file should be formatted as follows:

    .. code-block:: text

       PATH,reference_id
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
       ...

    Parameters
    ----------

    csv_file_name: str
      CSV file containing all the samples from your database

    random_state: int
      Pseudo-random number generator seed

    test_size: float
      Fraction (between 0 and 1) of the reference_ids used for testing

    samples_for_enrollment: int
      Number of samples used for enrollment

    csv_to_sample_loader: :any:`bob.bio.base.database.CSVBaseSampleLoader`
        Base class whose objective is to generate :any:`bob.pipelines.Sample`
        and/or :any:`bob.pipelines.SampleSet` from csv rows

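    A minimal usage sketch (the CSV path, raw-data directory and ``.png`` extension
    below are placeholders):

    .. code-block:: python

       import bob.io.base

       database = CSVDatasetCrossValidation(
           csv_file_name="/path/to/metadata.csv",
           random_state=0,
           test_size=0.8,
           samples_for_enrollment=1,
           csv_to_sample_loader=CSVToSampleLoader(
               data_loader=bob.io.base.load,
               dataset_original_directory="/path/to/raw/data",
               extension=".png",
           ),
       )

       train = database.background_model_samples()
       references = database.references()
       probes = database.probes()
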
    """

    def __init__(
        self,
        csv_file_name="metadata.csv",
        random_state=0,
        test_size=0.8,
        samples_for_enrollment=1,
        csv_to_sample_loader=CSVToSampleLoader(
            data_loader=bob.io.base.load, dataset_original_directory="", extension=""
        ),
    ):
        def get_dict_cache():
            cache = dict()
            cache["train"] = None
            cache["dev_enroll_csv"] = None
            cache["dev_probe_csv"] = None
            return cache

        self.random_state = random_state
        self.cache = get_dict_cache()
        self.csv_to_sample_loader = csv_to_sample_loader
        self.csv_file_name = open(csv_file_name)
        self.samples_for_enrollment = samples_for_enrollment
        self.test_size = test_size

        if self.test_size < 0 or self.test_size > 1:
            raise ValueError(
                f"`test_size` should be between 0 and 1. {test_size} is provided"
            )

    def _do_cross_validation(self):

        # Shuffling samples by reference_id
        samples_by_reference_id = group_samples_by_reference_id(
            self.csv_to_sample_loader(self.csv_file_name)
        )
        reference_ids = list(samples_by_reference_id.keys())
        np.random.seed(self.random_state)
        np.random.shuffle(reference_ids)

        # Getting the training data
        n_samples_for_training = len(reference_ids) - int(
            self.test_size * len(reference_ids)
        )
        self.cache["train"] = list(
            itertools.chain(
                *[
                    samples_by_reference_id[s]
                    for s in reference_ids[0:n_samples_for_training]
                ]
            )
        )

        # Splitting enroll and probe
        self.cache["dev_enroll_csv"] = []
        self.cache["dev_probe_csv"] = []
        for s in reference_ids[n_samples_for_training:]:
            samples = samples_by_reference_id[s]
            if len(samples) < self.samples_for_enrollment:
                raise ValueError(
                    f"Not enough samples ({len(samples)}) for enrollment for the reference_id {s}"
                )

            # Enrollment samples
            self.cache["dev_enroll_csv"].append(
                self.csv_to_sample_loader.convert_samples_to_samplesets(
                    samples[0 : self.samples_for_enrollment]
                )[0]
            )

            self.cache[
                "dev_probe_csv"
            ] += self.csv_to_sample_loader.convert_samples_to_samplesets(
                samples[self.samples_for_enrollment :],
                group_by_reference_id=False,
                references=reference_ids[n_samples_for_training:],
            )

    def _load_from_cache(self, cache_key):
        if self.cache[cache_key] is None:
            self._do_cross_validation()
        return self.cache[cache_key]

    def background_model_samples(self):
        return self._load_from_cache("train")

    def references(self, group="dev"):
        return self._load_from_cache("dev_enroll_csv")

    def probes(self, group="dev"):
        return self._load_from_cache("dev_probe_csv")

    def all_samples(self, groups=None):
        """
        Reads and returns all the samples in `groups`.

        Parameters
        ----------
        groups: list or None
            Groups to consider ('train' and/or 'dev'). If `None` is given,
            returns the samples from all groups.

        Returns
        -------
        samples: list
            List of :class:`bob.pipelines.Sample` objects.
        """
        valid_groups = ["train", "dev"]
        groups = check_parameters_for_validity(
            parameters=groups,
            parameter_description="groups",
            valid_parameters=valid_groups,
            default_parameters=valid_groups,
        )

        samples = []

        # Get train samples (background_model_samples returns a list of samples)
        if "train" in groups:
            samples = samples + self.background_model_samples()
            groups.remove("train")

        # Get enroll and probe samples
        for group in groups:
            samples = samples + [s for s_set in self.references(group) for s in s_set]
            samples = samples + [s for s_set in self.probes(group) for s in s_set]
        return samples


def group_samples_by_reference_id(samples):

    # Group samples by their reference_id
    samples_by_reference_id = dict()
    for s in samples:
        if s.reference_id not in samples_by_reference_id:
            samples_by_reference_id[s.reference_id] = []
        samples_by_reference_id[s.reference_id].append(s)
    return samples_by_reference_id