#!/usr/bin/env python
# vim: set fileencoding=utf-8 :


import os
from bob.pipelines import Sample, DelayedSample, SampleSet
from bob.db.base.utils import check_parameters_for_validity
import csv
import bob.io.base
import functools
from abc import ABCMeta, abstractmethod
import numpy as np
import itertools
import logging
import bob.db.base
from bob.extension.download import find_element_in_tarball
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database

logger = logging.getLogger(__name__)


#####
# ANNOTATIONS LOADERS
#####
class AnnotationsLoader:
    """
    Load annotations in the Idiap format
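
    Example
    -------

    A minimal usage sketch (the directory and the row below are hypothetical):

    .. code-block:: python

       loader = AnnotationsLoader(
           annotation_directory="/path/to/annotations",
           annotation_extension=".json",
           annotation_type="json",
       )
       # `row[0]` must hold the sample path (without extension)
       metadata = loader(["subject_01/sample_001", "subject_01"])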
    """

    def __init__(
        self,
        annotation_directory=None,
        annotation_extension=".json",
        annotation_type="json",
    ):
        self.annotation_directory = annotation_directory
        self.annotation_extension = annotation_extension
        self.annotation_type = annotation_type

    def __call__(self, row, header=None):
        if self.annotation_directory is None:
            return None

        path = row[0]

        # since the file id is equal to the file name, we can simply use it
        annotation_file = os.path.join(
            self.annotation_directory, path + self.annotation_extension
        )

        # return the annotations as read from file
        annotation = {
            "annotations": bob.db.base.read_annotation_file(
                annotation_file, self.annotation_type
            )
        }
        return annotation


#######
# SAMPLE LOADERS
# CONVERT CSV LINES TO SAMPLES
#######


class CSVBaseSampleLoader(metaclass=ABCMeta):
    """
    Convert CSV files in the format below to either a list of
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`

    .. code-block:: text

       PATH,REFERENCE_ID
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
       ...

    .. note::
       This class should be extended

    Parameters
    ----------

        data_loader:
            A python function that, given a file path, loads the sample in
            question from whatever medium

        metadata_loader:
            An optional callable (e.g. :any:`AnnotationsLoader`) that takes a
            CSV row and its header and returns a dictionary of extra metadata
            to attach to the sample

        dataset_original_directory: str
            Path of the directory where the raw data is located

        extension: str
            The file extension appended to every path read from the CSV file
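
    Example
    -------

    A grouping sketch using the concrete :any:`CSVToSampleLoader` defined
    below (file name hypothetical):

    .. code-block:: python

       loader = CSVToSampleLoader(data_loader=bob.io.base.load)
       with open("dev/for_models.csv") as f:
           samples = loader(f)
       # One SampleSet per reference_id, holding all samples of that reference
       sample_sets = loader.convert_samples_to_samplesets(
           samples, group_by_reference_id=True
       )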

    """

    def __init__(
        self,
        data_loader,
        metadata_loader=None,
        dataset_original_directory="",
        extension="",
    ):
        self.data_loader = data_loader
        self.extension = extension
        self.dataset_original_directory = dataset_original_directory
        self.metadata_loader = metadata_loader

    @abstractmethod
    def __call__(self, filename):
        pass

    @abstractmethod
    def convert_row_to_sample(self, row, header):
        pass

    def convert_samples_to_samplesets(
        self, samples, group_by_reference_id=True, references=None
    ):
        if group_by_reference_id:

            # Grouping sample sets
            sample_sets = dict()
            for s in samples:
                if s.reference_id not in sample_sets:
                    sample_sets[s.reference_id] = (
                        SampleSet([s], parent=s)
                        if references is None
                        else SampleSet([s], parent=s, references=references)
                    )
                else:
                    sample_sets[s.reference_id].append(s)
            return list(sample_sets.values())

        else:
            return (
                [SampleSet([s], parent=s) for s in samples]
                if references is None
                else [SampleSet([s], parent=s, references=references) for s in samples]
            )


class CSVToSampleLoader(CSVBaseSampleLoader):
    """
    Simple mechanism to convert CSV files in the format described in
    :any:`CSVBaseSampleLoader` to either a list of
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
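
    Example
    -------

    A minimal usage sketch (directory and file names hypothetical):

    .. code-block:: python

       loader = CSVToSampleLoader(
           data_loader=bob.io.base.load,
           dataset_original_directory="/path/to/raw-data",
           extension=".hdf5",
       )
       with open("dev/for_probes.csv") as f:
           samples = loader(f)  # list of DelayedSample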
    """

    def check_header(self, header):
        """
        A header should contain at least the fields "reference_id" and "path"
        (case-insensitive).
        """
        header = [h.lower() for h in header]
        if "reference_id" not in header:
            raise ValueError(
                "The field `reference_id` is not available in your dataset."
            )

        if "path" not in header:
            raise ValueError("The field `path` is not available in your dataset.")

    def __call__(self, f):
        f.seek(0)
        reader = csv.reader(f)
        header = next(reader)

        self.check_header(header)
        return [self.convert_row_to_sample(row, header) for row in reader]

    def convert_row_to_sample(self, row, header):
        path = row[0]
        reference_id = row[1]

        kwargs = {str(h).lower(): r for h, r in zip(header[2:], row[2:])}

        if self.metadata_loader is not None:
            metadata = self.metadata_loader(row, header=header)
            kwargs.update(metadata)

        return DelayedSample(
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
            key=path,
            reference_id=reference_id,
            **kwargs,
        )


class LSTToSampleLoader(CSVBaseSampleLoader):
    """
    Simple mechanism to convert whitespace-separated LST files in the format
    shown below to either a list of :any:`bob.pipelines.DelayedSample` or
    :any:`bob.pipelines.SampleSet`
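
    Two-, three- and four-column rows are supported (a sketch reconstructed
    from the parsing logic below; column names are illustrative):

    .. code-block:: text

       # comment lines are skipped
       path_1 reference_id_1
       path_2 reference_id_2 subject_2
       path_3 compare_reference_id_3 unused_3 reference_id_3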
    """

    def __call__(self, f):
        f.seek(0)
        reader = csv.reader(f, delimiter=" ")
        samples = []
        for row in reader:
            if row[0].startswith("#"):
                continue
            samples.append(self.convert_row_to_sample(row))

        return samples

    def convert_row_to_sample(self, row, header=None):

        # Four-column rows come from legacy score lists and carry the
        # reference the probe should be compared against
        if len(row) == 4:
            path = row[0]
            compare_reference_id = row[1]
            reference_id = str(row[3])
            kwargs = {"compare_reference_id": str(compare_reference_id)}
        else:
            path = row[0]
            reference_id = str(row[1])
            kwargs = dict()
            # Three-column rows carry an extra `subject` label
            if len(row) == 3:
                subject = row[2]
                kwargs = {"subject": str(subject)}

        if self.metadata_loader is not None:
            metadata = self.metadata_loader(row, header=header)
            kwargs.update(metadata)

        return DelayedSample(
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
            key=path,
            reference_id=reference_id,
            **kwargs,
        )


#####
# DATABASE INTERFACES
#####


class CSVDatasetDevEval(Database):
    """
    Generic filelist dataset for the :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline.
    Check :any:`vanilla_biometrics_features` for more details about the Vanilla
    Biometrics Dataset interface.

    To create a new dataset, you need to provide a directory structure similar to the one below:

    .. code-block:: text

       my_dataset/
       my_dataset/my_protocol/norm/train_world.csv
       my_dataset/my_protocol/dev/for_models.csv
       my_dataset/my_protocol/dev/for_probes.csv
       my_dataset/my_protocol/eval/for_models.csv
       my_dataset/my_protocol/eval/for_probes.csv
       ...


    In the directory structure above, `my_dataset` should contain one directory
    per evaluation protocol this dataset might have.
    Each `my_protocol` directory should contain at least two CSV files:

     - for_models.csv
     - for_probes.csv


    Each row of those CSV files should contain i) the path to the raw data and ii) the reference_id label used
    for enrollment (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.references`) and
    probing (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.probes`).
    The structure of each CSV file should be as below:

    .. code-block:: text

       PATH,reference_id
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
       ...


    You might want to ship metadata within your samples (e.g. gender, age,
    annotations, ...). To do so, simply add the extra columns as below:

    .. code-block:: text

       PATH,reference_id,METADATA_1,METADATA_2,METADATA_k
       path_1,reference_id_1,A,B,C
       path_2,reference_id_2,A,B,1
       path_i,reference_id_j,2,3,4
       ...


    The files `my_dataset/my_protocol/eval/for_models.csv` and `my_dataset/my_protocol/eval/for_probes.csv`
    are optional and are used in case a protocol contains data for evaluation.

    Finally, the content of the file `my_dataset/my_protocol/norm/train_world.csv` is used in case a protocol
    contains data for training (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.background_model_samples`).

    Parameters
    ----------

        dataset_protocol_path: str
          Absolute path of the dataset protocol description (a directory or a
          tarball)

        protocol_name: str
          The name of the protocol

        csv_to_sample_loader: :any:`bob.bio.base.database.CSVBaseSampleLoader`
            Loader whose objective is to generate :any:`bob.pipelines.Sample`
            and/or :any:`bob.pipelines.SampleSet` from csv rows
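
        is_sparse: bool
            If True, probes are read in the "sparse" layout, where each row
            also carries a `compare_reference_id` naming the reference the
            probe should be scored against (legacy `for_scores.lst` files)

    Example
    -------

    A minimal usage sketch (paths and protocol name are hypothetical):

    .. code-block:: python

       database = CSVDatasetDevEval(
           dataset_protocol_path="/path/to/my_dataset",
           protocol_name="my_protocol",
       )
       train = database.background_model_samples()
       references = database.references(group="dev")
       probes = database.probes(group="dev")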

    """

    def __init__(
        self,
        dataset_protocol_path,
        protocol_name,
        csv_to_sample_loader=CSVToSampleLoader(
            data_loader=bob.io.base.load,
            metadata_loader=None,
            dataset_original_directory="",
            extension="",
        ),
        is_sparse=False,
    ):
        self.dataset_protocol_path = dataset_protocol_path
        self.is_sparse = is_sparse

        def get_paths():

            if not os.path.exists(dataset_protocol_path):
                raise ValueError(f"The path `{dataset_protocol_path}` was not found")

            def path_discovery(option1, option2):

                # If the input is a directory
                if os.path.isdir(dataset_protocol_path):
                    option1 = os.path.join(dataset_protocol_path, option1)
                    option2 = os.path.join(dataset_protocol_path, option2)
                    if os.path.exists(option1):
                        return open(option1)
                    else:
                        return open(option2) if os.path.exists(option2) else None

                # If it's not a directory, it is a tarball
                op1 = find_element_in_tarball(dataset_protocol_path, option1)
                return (
                    op1
                    if op1
                    else find_element_in_tarball(dataset_protocol_path, option2)
                )

            # Handle the legacy `.lst` layout first, then fall back to `.csv`
            train_csv = path_discovery(
                os.path.join(protocol_name, "norm", "train_world.lst"),
                os.path.join(protocol_name, "norm", "train_world.csv"),
            )

            dev_enroll_csv = path_discovery(
                os.path.join(protocol_name, "dev", "for_models.lst"),
                os.path.join(protocol_name, "dev", "for_models.csv"),
            )

            legacy_probe = "for_scores.lst" if self.is_sparse else "for_probes.lst"
            dev_probe_csv = path_discovery(
                os.path.join(protocol_name, "dev", legacy_probe),
                os.path.join(protocol_name, "dev", "for_probes.csv"),
            )

            eval_enroll_csv = path_discovery(
                os.path.join(protocol_name, "eval", "for_models.lst"),
                os.path.join(protocol_name, "eval", "for_models.csv"),
            )

            eval_probe_csv = path_discovery(
                os.path.join(protocol_name, "eval", legacy_probe),
                os.path.join(protocol_name, "eval", "for_probes.csv"),
            )

            # The minimum required is to have `dev_enroll_csv` and `dev_probe_csv`

            # Dev
            if dev_enroll_csv is None:
                raise ValueError(
                    "The file `dev/for_models.csv` (or the legacy "
                    "`dev/for_models.lst`) is required and was not found"
                )

            if dev_probe_csv is None:
                raise ValueError(
                    "The file `dev/for_probes.csv` (or the legacy probe list) "
                    "is required and was not found"
                )

            return (
                train_csv,
                dev_enroll_csv,
                dev_probe_csv,
                eval_enroll_csv,
                eval_probe_csv,
            )

        (
            self.train_csv,
            self.dev_enroll_csv,
            self.dev_probe_csv,
            self.eval_enroll_csv,
            self.eval_probe_csv,
        ) = get_paths()

        def get_dict_cache():
            cache = dict()
            cache["train"] = None
            cache["dev_enroll_csv"] = None
            cache["dev_probe_csv"] = None
            cache["eval_enroll_csv"] = None
            cache["eval_probe_csv"] = None
            return cache

        self.cache = get_dict_cache()
        self.csv_to_sample_loader = csv_to_sample_loader

    def background_model_samples(self):
        if self.cache["train"] is None:
            self.cache["train"] = self.csv_to_sample_loader(self.train_csv)

        return self.cache["train"]

    def _get_samplesets(
        self, group="dev", purpose="enroll", group_by_reference_id=False
    ):

        fetching_probes = False
        if purpose == "enroll":
            cache_label = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"
        else:
            fetching_probes = True
            cache_label = "dev_probe_csv" if group == "dev" else "eval_probe_csv"

        if self.cache[cache_label] is not None:
            return self.cache[cache_label]

        # Getting samples from CSV
        samples = self.csv_to_sample_loader(getattr(self, cache_label))

        references = None
        if fetching_probes and self.is_sparse:

            # Checking if `is_sparse` was set properly
            if len(samples) > 0 and not hasattr(samples[0], "compare_reference_id"):
                raise ValueError(
                    f"Attribute `compare_reference_id` not found in `{samples[0]}`. "
                    "Make sure this attribute exists in your dataset if `is_sparse=True`"
                )

            # Aggregating, per probe key, the list of references the probe
            # should be compared against
            sparse_samples = dict()
            for s in samples:
                if s.key in sparse_samples:
                    sparse_samples[s.key].references.append(s.compare_reference_id)
                else:
                    s.references = [s.compare_reference_id]
                    sparse_samples[s.key] = s
            samples = sparse_samples.values()
        elif fetching_probes:
            references = list(
                {s.reference_id for s in self.references(group=group)}
            )

        sample_sets = self.csv_to_sample_loader.convert_samples_to_samplesets(
            samples, group_by_reference_id=group_by_reference_id, references=references,
        )

        self.cache[cache_label] = sample_sets

        return self.cache[cache_label]

    def references(self, group="dev"):
        return self._get_samplesets(
            group=group, purpose="enroll", group_by_reference_id=True
        )

    def probes(self, group="dev"):
        return self._get_samplesets(
            group=group, purpose="probe", group_by_reference_id=False
        )

    def all_samples(self, groups=None):
        """
        Reads and returns all the samples in `groups`.

        Parameters
        ----------
        groups: list or None
            Groups to consider ('train', 'dev', and/or 'eval'). If `None` is
            given, returns the samples from all groups.

        Returns
        -------
        samples: list
            List of :class:`bob.pipelines.Sample` objects.
        """
        valid_groups = ["train"]
        if self.dev_enroll_csv and self.dev_probe_csv:
            valid_groups.append("dev")
        if self.eval_enroll_csv and self.eval_probe_csv:
            valid_groups.append("eval")
        groups = check_parameters_for_validity(
            parameters=groups,
            parameter_description="groups",
            valid_parameters=valid_groups,
            default_parameters=valid_groups,
        )

        samples = []

        # Get train samples (background_model_samples returns a list of samples)
        if "train" in groups:
            samples = samples + self.background_model_samples()
            groups.remove("train")

        # Get enroll and probe samples
        for group in groups:
            for purpose in ("enroll", "probe"):
                label = f"{group}_{purpose}_csv"
                samples = samples + self.csv_to_sample_loader(getattr(self, label))
        return samples

    def groups(self):
        """This function returns the list of groups for this database.

        Returns
        -------

        [str]
          A list of groups
        """

        # We always have dev-set
        groups = ["dev"]

        if self.train_csv is not None:
            groups.append("train")

        if self.eval_enroll_csv is not None:
            groups.append("eval")

        return groups


class CSVDatasetCrossValidation:
    """
    Generic filelist dataset for the :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline that
    handles **CROSS VALIDATION**.

    Check :any:`vanilla_biometrics_features` for more details about the Vanilla
    Biometrics Dataset interface.


    This interface takes one `csv_file` as input and splits it into i) data for
    training and ii) data for testing.
    The data for testing is further split into data for enrollment and data for
    probing.
    The input CSV file should have the following format:

    .. code-block:: text

       PATH,reference_id
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
       ...

    Parameters
    ----------

    csv_file_name: str
      CSV file containing all the samples from your database

    random_state: int
      Pseudo-random number generator seed

    test_size: float
      Fraction (between 0 and 1) of the reference_ids reserved for testing

    samples_for_enrollment: int
      Number of samples used for enrollment

    csv_to_sample_loader: :any:`bob.bio.base.database.CSVBaseSampleLoader`
        Loader whose objective is to generate :any:`bob.pipelines.Sample`
        and/or :any:`bob.pipelines.SampleSet` from csv rows
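
    Example
    -------

    A minimal usage sketch (file name hypothetical); 80% of the reference_ids
    go to the test set and one sample per reference is used for enrollment:

    .. code-block:: python

       database = CSVDatasetCrossValidation(
           csv_file_name="metadata.csv",
           random_state=0,
           test_size=0.8,
           samples_for_enrollment=1,
       )
       train = database.background_model_samples()
       references = database.references()
       probes = database.probes()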

    """

    def __init__(
        self,
        csv_file_name="metadata.csv",
        random_state=0,
        test_size=0.8,
        samples_for_enrollment=1,
        csv_to_sample_loader=CSVToSampleLoader(
            data_loader=bob.io.base.load, dataset_original_directory="", extension=""
        ),
    ):
        def get_dict_cache():
            cache = dict()
            cache["train"] = None
            cache["dev_enroll_csv"] = None
            cache["dev_probe_csv"] = None
            return cache

        self.random_state = random_state
        self.cache = get_dict_cache()
        self.csv_to_sample_loader = csv_to_sample_loader
        self.csv_file_name = open(csv_file_name)
        self.samples_for_enrollment = samples_for_enrollment
        self.test_size = test_size

        if self.test_size < 0 or self.test_size > 1:
            raise ValueError(
                f"`test_size` should be between 0 and 1. {test_size} was provided"
            )

    def _do_cross_validation(self):

        # Shuffling samples by reference_id
        samples_by_reference_id = group_samples_by_reference_id(
            self.csv_to_sample_loader(self.csv_file_name)
        )
        reference_ids = list(samples_by_reference_id.keys())
        np.random.seed(self.random_state)
        np.random.shuffle(reference_ids)

        # Getting the training data
        n_samples_for_training = len(reference_ids) - int(
            self.test_size * len(reference_ids)
        )
        self.cache["train"] = list(
            itertools.chain(
                *[
                    samples_by_reference_id[s]
                    for s in reference_ids[0:n_samples_for_training]
                ]
            )
        )

        # Splitting enroll and probe
        self.cache["dev_enroll_csv"] = []
        self.cache["dev_probe_csv"] = []
        for s in reference_ids[n_samples_for_training:]:
            samples = samples_by_reference_id[s]
            if len(samples) < self.samples_for_enrollment:
                raise ValueError(
                    f"Not enough samples ({len(samples)}) for enrollment for the reference_id {s}"
                )

            # Enrollment samples
            self.cache["dev_enroll_csv"].append(
                self.csv_to_sample_loader.convert_samples_to_samplesets(
                    samples[0 : self.samples_for_enrollment]
                )[0]
            )

            self.cache[
                "dev_probe_csv"
            ] += self.csv_to_sample_loader.convert_samples_to_samplesets(
                samples[self.samples_for_enrollment :],
                group_by_reference_id=False,
                references=reference_ids[n_samples_for_training:],
            )

    def _load_from_cache(self, cache_key):
        if self.cache[cache_key] is None:
            self._do_cross_validation()
        return self.cache[cache_key]

    def background_model_samples(self):
        return self._load_from_cache("train")

    def references(self, group="dev"):
        return self._load_from_cache("dev_enroll_csv")

    def probes(self, group="dev"):
        return self._load_from_cache("dev_probe_csv")

    def all_samples(self, groups=None):
        """
        Reads and returns all the samples in `groups`.

        Parameters
        ----------
        groups: list or None
            Groups to consider ('train' and/or 'dev'). If `None` is given,
            returns the samples from all groups.

        Returns
        -------
        samples: list
            List of :class:`bob.pipelines.Sample` objects.
        """
        valid_groups = ["train", "dev"]
        groups = check_parameters_for_validity(
            parameters=groups,
            parameter_description="groups",
            valid_parameters=valid_groups,
            default_parameters=valid_groups,
        )

        samples = []

        # Get train samples (background_model_samples returns a list of samples)
        if "train" in groups:
            samples = samples + self.background_model_samples()
            groups.remove("train")

        # Get enroll and probe samples
        for group in groups:
            samples = samples + [s for s_set in self.references(group) for s in s_set]
            samples = samples + [s for s_set in self.probes(group) for s in s_set]
        return samples


def group_samples_by_reference_id(samples):
    """Groups a list of samples into a dictionary indexed by `reference_id`."""

    samples_by_reference_id = dict()
    for s in samples:
        if s.reference_id not in samples_by_reference_id:
            samples_by_reference_id[s.reference_id] = []
        samples_by_reference_id[s.reference_id].append(s)
    return samples_by_reference_id