csv_dataset.py 21.1 KB
Newer Older
1
2
3
4
5
6
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :


import os
from bob.pipelines import Sample, DelayedSample, SampleSet
7
from bob.db.base.utils import check_parameters_for_validity
8
9
10
11
import csv
import bob.io.base
import functools
from abc import ABCMeta, abstractmethod
12
13
import numpy as np
import itertools
14
import logging
15
16
17
import bob.db.base

from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database
18

19
logger = logging.getLogger(__name__)
20

21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65

#####
# ANNOTATIONS LOADERS
####
class IdiapAnnotationsLoader:
    """
    Load annotations in the Idiap format
    """

    def __init__(
        self,
        annotation_directory=None,
        annotation_extension=".pos",
        annotation_type="eyecenter",
    ):
        self.annotation_directory = annotation_directory
        self.annotation_extension = annotation_extension
        self.annotation_type = annotation_type

    def __call__(self, row, header=None):
        if self.annotation_directory is None:
            return None

        path = row[0]

        # since the file id is equal to the file name, we can simply use it
        annotation_file = os.path.join(
            self.annotation_directory, path + self.annotation_extension
        )

        # return the annotations as read from file
        annotation = {
            "annotations": bob.db.base.read_annotation_file(
                annotation_file, self.annotation_type
            )
        }
        return annotation


#######
# SAMPLE LOADERS
# CONVERT CSV LINES TO SAMPLES
#######


Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
66
class CSVBaseSampleLoader(metaclass=ABCMeta):
67
68
69
70
71
72
    """
    Convert CSV files in the format below to either a list of
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`

    .. code-block:: text

73
74
75
76
       PATH,REFERENCE_ID
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
       ...

    .. note::
       This class should be extended

    Parameters
    ----------

        data_loader:
            A python function that can be called parameterlessly, to load the
            sample in question from whatever medium

        extension:
            The file extension

    """

94
95
96
97
98
99
100
    def __init__(
        self,
        data_loader,
        metadata_loader=None,
        dataset_original_directory="",
        extension="",
    ):
101
102
        self.data_loader = data_loader
        self.extension = extension
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
103
        self.dataset_original_directory = dataset_original_directory
104
        self.metadata_loader = metadata_loader
105
106
107
108
109
110
111
112
113

    @abstractmethod
    def __call__(self, filename):
        pass

    @abstractmethod
    def convert_row_to_sample(self, row, header):
        pass

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
114
    def convert_samples_to_samplesets(
115
        self, samples, group_by_reference_id=True, references=None
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
116
    ):
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
        if group_by_reference_id:

            # Grouping sample sets
            sample_sets = dict()
            for s in samples:
                if s.reference_id not in sample_sets:
                    sample_sets[s.reference_id] = SampleSet(
                        [s], parent=s, references=references
                    )
                else:
                    sample_sets[s.reference_id].append(s)
            return list(sample_sets.values())

        else:
            return [SampleSet([s], parent=s, references=references) for s in samples]
132
133


Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
134
class CSVToSampleLoader(CSVBaseSampleLoader):
135
136
137
138
139
    """
    Simple mechanism to convert CSV files in the format below to either a list of
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
    """

140
141
    def check_header(self, header):
        """
142
        A header should have at least "reference_id" AND "PATH"
143
144
        """
        header = [h.lower() for h in header]
145
146
147
148
        if not "reference_id" in header:
            raise ValueError(
                "The field `reference_id` is not available in your dataset."
            )
149

150
151
152
153
        if not "path" in header:
            raise ValueError("The field `path` is not available in your dataset.")

    def __call__(self, filename):
154
155
156
157
158

        with open(filename) as cf:
            reader = csv.reader(cf)
            header = next(reader)

159
            self.check_header(header)
160
161
162
163
            return [self.convert_row_to_sample(row, header) for row in reader]

    def convert_row_to_sample(self, row, header):
        path = row[0]
164
165
        reference_id = row[1]

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
166
        kwargs = dict([[h, r] for h, r in zip(header[2:], row[2:])])
167
168
169
170
171

        if self.metadata_loader is not None:
            metadata = self.metadata_loader(row)
            kwargs.update(metadata)

172
        return DelayedSample(
173
174
175
176
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
177
            key=path,
178
            reference_id=reference_id,
179
180
181
182
            **kwargs,
        )


183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
class LSTToSampleLoader(CSVBaseSampleLoader):
    """
    Simple mechanism to convert LST files in the format below to either a list of
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
    """

    def __call__(self, filename):

        with open(filename) as cf:
            reader = csv.reader(cf, delimiter=" ")
            return [self.convert_row_to_sample(row) for row in reader]

    def convert_row_to_sample(self, row, header=None):

        path = row[0]
        reference_id = str(row[1])
        kwargs = dict()
        if len(row) == 3:
            subject = row[2]
            kwargs = {"subject": str(subject)}

        if self.metadata_loader is not None:
            metadata = self.metadata_loader(row)
            kwargs.update(metadata)

        return DelayedSample(
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
            key=path,
            reference_id=reference_id,
            **kwargs,
        )
217
218


219
220
221
#####
# DATABASE INTERFACES
#####
222

223
224

class CSVDatasetDevEval(Database):
225
    """
226
227
    Generic filelist dataset for :any:` bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline.
    Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
    interface.

    To create a new dataset, you need to provide a directory structure similar to the one below:

    .. code-block:: text

       my_dataset/
       my_dataset/my_protocol/
       my_dataset/my_protocol/train.csv
       my_dataset/my_protocol/train.csv/dev_enroll.csv
       my_dataset/my_protocol/train.csv/dev_probe.csv
       my_dataset/my_protocol/train.csv/eval_enroll.csv
       my_dataset/my_protocol/train.csv/eval_probe.csv
       ...


    In the above directory structure, inside of `my_dataset` should contain the directories with all
    evaluation protocols this dataset might have.
    Inside of the `my_protocol` directory should contain at least two csv files:

     - dev_enroll.csv
     - dev_probe.csv


252
    Those csv files should contain in each row i-) the path to raw data and ii-) the reference_id label
253
254
    for enrollment (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.references`) and
    probing (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.probes`).
255
256
257
258
    The structure of each CSV file should be as below:

    .. code-block:: text

259
260
261
262
       PATH,reference_id
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
263
264
       ...

265

266
267
268
269
270
    You might want to ship metadata within your Samples (e.g gender, age, annotation, ...)
    To do so is simple, just do as below:

    .. code-block:: text

271
272
273
274
       PATH,reference_id,METADATA_1,METADATA_2,METADATA_k
       path_1,reference_id_1,A,B,C
       path_2,reference_id_2,A,B,1
       path_i,reference_id_j,2,3,4
275
276
277
278
279
       ...


    The files `my_dataset/my_protocol/train.csv/eval_enroll.csv` and `my_dataset/my_protocol/train.csv/eval_probe.csv`
    are optional and it is used in case a protocol contains data for evaluation.
280

281
    Finally, the content of the file `my_dataset/my_protocol/train.csv` is used in the case a protocol
282
    contains data for training (`bob.bio.base.pipelines.vanilla_biometrics.Database.background_model_samples`)
283
284
285
286
287
288
289

    Parameters
    ----------

        dataset_path: str
          Absolute path of the dataset protocol description

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
290
        protocol_na,e: str
291
292
          The name of the protocol

293
294
        csv_to_sample_loader: :any:`bob.bio.base.database.CSVBaseSampleLoader`
            Base class that whose objective is to generate :any:`bob.pipelines.Sample`
295
            and/or :any:`bob.pipelines.SampleSet` from csv rows
296
297
298
299
300

    """

    def __init__(
        self,
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
301
302
        dataset_protocol_path,
        protocol_name,
303
        csv_to_sample_loader=CSVToSampleLoader(
304
305
306
307
            data_loader=bob.io.base.load,
            metadata_loader=None,
            dataset_original_directory="",
            extension="",
308
309
        ),
    ):
310
311
        self.dataset_protocol_path = dataset_protocol_path

312
313
        def get_paths():

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
314
315
            if not os.path.exists(dataset_protocol_path):
                raise ValueError(f"The path `{dataset_protocol_path}` was not found")
316
317

            # TODO: Unzip file if dataset path is a zip
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
318
            protocol_path = os.path.join(dataset_protocol_path, protocol_name)
319
            if not os.path.exists(protocol_path):
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
320
                raise ValueError(f"The protocol `{protocol_name}` was not found")
321

322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
            def path_discovery(option1, option2):
                return option1 if os.path.exists(option1) else option2

            # Here we are handling the legacy
            train_csv = path_discovery(
                os.path.join(protocol_path, "norm", "train_world.lst"),
                os.path.join(protocol_path, "norm", "train_world.csv"),
            )

            dev_enroll_csv = path_discovery(
                os.path.join(protocol_path, "dev", "for_models.lst"),
                os.path.join(protocol_path, "dev", "for_models.csv"),
            )

            dev_probe_csv = path_discovery(
                os.path.join(protocol_path, "dev", "for_probes.lst"),
                os.path.join(protocol_path, "dev", "for_probes.csv"),
            )

            eval_enroll_csv = path_discovery(
                os.path.join(protocol_path, "eval", "for_models.lst"),
                os.path.join(protocol_path, "eval", "for_models.csv"),
            )

            eval_probe_csv = path_discovery(
                os.path.join(protocol_path, "eval", "for_probes.lst"),
                os.path.join(protocol_path, "eval", "for_probes.csv"),
            )
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369

            # The minimum required is to have `dev_enroll_csv` and `dev_probe_csv`
            train_csv = train_csv if os.path.exists(train_csv) else None

            # Eval
            eval_enroll_csv = (
                eval_enroll_csv if os.path.exists(eval_enroll_csv) else None
            )
            eval_probe_csv = eval_probe_csv if os.path.exists(eval_probe_csv) else None

            # Dev
            if not os.path.exists(dev_enroll_csv):
                raise ValueError(
                    f"The file `{dev_enroll_csv}` is required and it was not found"
                )

            if not os.path.exists(dev_probe_csv):
                raise ValueError(
                    f"The file `{dev_probe_csv}` is required and it was not found"
                )
370
371
            dev_enroll_csv = dev_enroll_csv
            dev_probe_csv = dev_probe_csv
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409

            return (
                train_csv,
                dev_enroll_csv,
                dev_probe_csv,
                eval_enroll_csv,
                eval_probe_csv,
            )

        (
            self.train_csv,
            self.dev_enroll_csv,
            self.dev_probe_csv,
            self.eval_enroll_csv,
            self.eval_probe_csv,
        ) = get_paths()

        def get_dict_cache():
            cache = dict()
            cache["train"] = None
            cache["dev_enroll_csv"] = None
            cache["dev_probe_csv"] = None
            cache["eval_enroll_csv"] = None
            cache["eval_probe_csv"] = None
            return cache

        self.cache = get_dict_cache()
        self.csv_to_sample_loader = csv_to_sample_loader

    def background_model_samples(self):
        self.cache["train"] = (
            self.csv_to_sample_loader(self.train_csv)
            if self.cache["train"] is None
            else self.cache["train"]
        )

        return self.cache["train"]

410
411
412
    def _get_samplesets(
        self, group="dev", purpose="enroll", group_by_reference_id=False
    ):
413

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
414
        fetching_probes = False
415
416
417
        if purpose == "enroll":
            cache_label = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"
        else:
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
418
            fetching_probes = True
419
420
421
422
423
            cache_label = "dev_probe_csv" if group == "dev" else "eval_probe_csv"

        if self.cache[cache_label] is not None:
            return self.cache[cache_label]

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
424
425
        references = None
        if fetching_probes:
426
427
428
            references = list(
                set([s.reference_id for s in self.references(group=group)])
            )
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
429
430

        samples = self.csv_to_sample_loader(self.__dict__[cache_label])
431
432

        sample_sets = self.csv_to_sample_loader.convert_samples_to_samplesets(
433
            samples, group_by_reference_id=group_by_reference_id, references=references
434
435
436
437
438
439
440
441
        )

        self.cache[cache_label] = sample_sets

        return self.cache[cache_label]

    def references(self, group="dev"):
        return self._get_samplesets(
442
            group=group, purpose="enroll", group_by_reference_id=True
443
444
445
446
        )

    def probes(self, group="dev"):
        return self._get_samplesets(
447
            group=group, purpose="probe", group_by_reference_id=False
448
        )
449

450
451
452
453
454
455
456
    def all_samples(self, groups=None):
        """
        Reads and returns all the samples in `groups`.

        Parameters
        ----------
        groups: list or None
457
458
459
460
461
462
463
            Groups to consider ('train', 'dev', and/or 'eval'). If `None` is
            given, returns the samples from all groups.

        Returns
        -------
        samples: list
            List of :class:`bob.pipelines.Sample` objects.
464
        """
465
466
467
468
469
470
471
472
473
474
475
476
477
478
        valid_groups = ["train"]
        if self.dev_enroll_csv and self.dev_probe_csv:
            valid_groups.append("dev")
        if self.eval_enroll_csv and self.eval_probe_csv:
            valid_groups.append("eval")
        groups = check_parameters_for_validity(
            parameters=groups,
            parameter_description="groups",
            valid_parameters=valid_groups,
            default_parameters=valid_groups,
        )

        samples = []

479
        # Get train samples (background_model_samples returns a list of samples)
480
481
482
        if "train" in groups:
            samples = samples + self.background_model_samples()
            groups.remove("train")
483
484
485
486
487

        # Get enroll and probe samples
        for group in groups:
            for purpose in ("enroll", "probe"):
                label = f"{group}_{purpose}_csv"
488
                samples = samples + self.csv_to_sample_loader(self.__dict__[label])
489
490
        return samples

491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
    def groups(self):
        """This function returns the list of groups for this database.

        Returns
        -------

        [str]
          A list of groups
        """

        # We always have dev-set
        groups = ["dev"]

        if self.train_csv is not None:
            groups.append("train")

        if self.eval_enroll_csv is not None:
            groups.append("eval")

        return groups

512
513
514

class CSVDatasetCrossValidation:
    """
515
    Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline that
516
517
    handles **CROSS VALIDATION**.

518
    Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
519
520
521
522
523
524
525
526
527
528
    interface.


    This interface will take one `csv_file` as input and split into i-) data for training and
    ii-) data for testing.
    The data for testing will be further split in data for enrollment and data for probing.
    The input CSV file should be casted in the following format:

    .. code-block:: text

529
530
531
532
       PATH,reference_id
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
533
534
535
536
537
538
539
540
541
542
543
544
       ...

    Parameters
    ----------

    csv_file_name: str
      CSV file containing all the samples from your database

    random_state: int
      Pseudo-random number generator seed

    test_size: float
545
      Percentage of the reference_ids used for testing
546
547
548
549

    samples_for_enrollment: float
      Number of samples used for enrollment

550
551
    csv_to_sample_loader: :any:`bob.bio.base.database.CSVBaseSampleLoader`
        Base class that whose objective is to generate :any:`bob.pipelines.Sample`
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
        and/or :any:`bob.pipelines.SampleSet` from csv rows

    """

    def __init__(
        self,
        csv_file_name="metadata.csv",
        random_state=0,
        test_size=0.8,
        samples_for_enrollment=1,
        csv_to_sample_loader=CSVToSampleLoader(
            data_loader=bob.io.base.load, dataset_original_directory="", extension=""
        ),
    ):
        def get_dict_cache():
            cache = dict()
            cache["train"] = None
            cache["dev_enroll_csv"] = None
            cache["dev_probe_csv"] = None
            return cache

        self.random_state = random_state
        self.cache = get_dict_cache()
        self.csv_to_sample_loader = csv_to_sample_loader
        self.csv_file_name = csv_file_name
        self.samples_for_enrollment = samples_for_enrollment
        self.test_size = test_size

        if self.test_size < 0 and self.test_size > 1:
            raise ValueError(
                f"`test_size` should be between 0 and 1. {test_size} is provided"
            )

    def _do_cross_validation(self):

587
588
        # Shuffling samples by reference_id
        samples_by_reference_id = group_samples_by_reference_id(
589
590
            self.csv_to_sample_loader(self.csv_file_name)
        )
591
        reference_ids = list(samples_by_reference_id.keys())
592
        np.random.seed(self.random_state)
593
        np.random.shuffle(reference_ids)
594
595

        # Getting the training data
596
597
598
        n_samples_for_training = len(reference_ids) - int(
            self.test_size * len(reference_ids)
        )
599
600
        self.cache["train"] = list(
            itertools.chain(
601
602
603
604
                *[
                    samples_by_reference_id[s]
                    for s in reference_ids[0:n_samples_for_training]
                ]
605
606
607
608
609
610
            )
        )

        # Splitting enroll and probe
        self.cache["dev_enroll_csv"] = []
        self.cache["dev_probe_csv"] = []
611
612
        for s in reference_ids[n_samples_for_training:]:
            samples = samples_by_reference_id[s]
613
614
            if len(samples) < self.samples_for_enrollment:
                raise ValueError(
615
                    f"Not enough samples ({len(samples)}) for enrollment for the reference_id {s}"
616
617
618
619
620
621
622
623
624
625
626
627
628
                )

            # Enrollment samples
            self.cache["dev_enroll_csv"].append(
                self.csv_to_sample_loader.convert_samples_to_samplesets(
                    samples[0 : self.samples_for_enrollment]
                )[0]
            )

            self.cache[
                "dev_probe_csv"
            ] += self.csv_to_sample_loader.convert_samples_to_samplesets(
                samples[self.samples_for_enrollment :],
629
630
                group_by_reference_id=False,
                references=reference_ids[n_samples_for_training:],
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
            )

    def _load_from_cache(self, cache_key):
        if self.cache[cache_key] is None:
            self._do_cross_validation()
        return self.cache[cache_key]

    def background_model_samples(self):
        return self._load_from_cache("train")

    def references(self, group="dev"):
        return self._load_from_cache("dev_enroll_csv")

    def probes(self, group="dev"):
        return self._load_from_cache("dev_probe_csv")

647
648
649
650
651
652
653
    def all_samples(self, groups=None):
        """
        Reads and returns all the samples in `groups`.

        Parameters
        ----------
        groups: list or None
654
655
656
657
658
659
660
            Groups to consider ('train' and/or 'dev'). If `None` is given,
            returns the samples from all groups.

        Returns
        -------
        samples: list
            List of :class:`bob.pipelines.Sample` objects.
661
        """
662
663
664
665
666
667
668
669
670
671
        valid_groups = ["train", "dev"]
        groups = check_parameters_for_validity(
            parameters=groups,
            parameter_description="groups",
            valid_parameters=valid_groups,
            default_parameters=valid_groups,
        )

        samples = []

672
        # Get train samples (background_model_samples returns a list of samples)
673
674
675
        if "train" in groups:
            samples = samples + self.background_model_samples()
            groups.remove("train")
676
677
678

        # Get enroll and probe samples
        for group in groups:
679
680
            samples = samples + [s for s_set in self.references(group) for s in s_set]
            samples = samples + [s for s_set in self.probes(group) for s in s_set]
681
682
        return samples

683

684
def group_samples_by_reference_id(samples):
685
686

    # Grouping sample sets
687
    samples_by_reference_id = dict()
688
    for s in samples:
689
690
691
692
        if s.reference_id not in samples_by_reference_id:
            samples_by_reference_id[s.reference_id] = []
        samples_by_reference_id[s.reference_id].append(s)
    return samples_by_reference_id