#!/usr/bin/env python
# vim: set fileencoding=utf-8 :


import os
from bob.pipelines import Sample, DelayedSample, SampleSet
from bob.db.base.utils import check_parameters_for_validity

import csv
import bob.io.base
import functools
from abc import ABCMeta, abstractmethod

import numpy as np
import itertools

import logging
import bob.db.base

from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database

logger = logging.getLogger(__name__)


#####
# ANNOTATIONS LOADERS
#####
class AnnotationsLoader:
    """
    Load annotations in the Idiap format
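
    A minimal usage sketch (the directory path is hypothetical)::

        loader = AnnotationsLoader(
            annotation_directory="/path/to/annotations",
            annotation_extension=".json",
            annotation_type="json",
        )
        # `row` is a CSV row whose first column is the file path (no extension)
        annotations = loader(["subject_1/sample_1", "reference_id_1"])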
    """

    def __init__(
        self,
        annotation_directory=None,
        annotation_extension=".json",
        annotation_type="json",
    ):
        self.annotation_directory = annotation_directory
        self.annotation_extension = annotation_extension
        self.annotation_type = annotation_type

    def __call__(self, row, header=None):
        if self.annotation_directory is None:
            return None

        path = row[0]

        # since the file id is equal to the file name, we can simply use it
        annotation_file = os.path.join(
            self.annotation_directory, path + self.annotation_extension
        )

        # return the annotations as read from file
        annotation = {
            "annotations": bob.db.base.read_annotation_file(
                annotation_file, self.annotation_type
            )
        }
        return annotation


#######
# SAMPLE LOADERS
# CONVERT CSV LINES TO SAMPLES
#######


class CSVBaseSampleLoader(metaclass=ABCMeta):
    """
    Convert CSV files in the format below to either a list of
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`

    .. code-block:: text

       PATH,REFERENCE_ID
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
       ...

    .. note::
       This class should be extended

    Parameters
    ----------

        data_loader:
            A python function that takes a file path and returns the loaded
            data (e.g. :any:`bob.io.base.load`)

        metadata_loader:
            Optional callable that extracts extra metadata (e.g. annotations)
            from a CSV row

        dataset_original_directory: str
            Path prepended to every file path read from the CSV file

        extension: str
            The file extension appended to each file path

    """

    def __init__(
        self,
        data_loader,
        metadata_loader=None,
        dataset_original_directory="",
        extension="",
    ):
        self.data_loader = data_loader
        self.extension = extension
        self.dataset_original_directory = dataset_original_directory
        self.metadata_loader = metadata_loader

    @abstractmethod
    def __call__(self, filename):
        pass

    @abstractmethod
    def convert_row_to_sample(self, row, header):
        pass

    def convert_samples_to_samplesets(
        self, samples, group_by_reference_id=True, references=None
    ):
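        """Convert a list of samples into a list of
        :any:`bob.pipelines.SampleSet`, grouping together samples that share
        the same ``reference_id`` when ``group_by_reference_id`` is True."""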
        if group_by_reference_id:

            # Grouping sample sets
            sample_sets = dict()
            for s in samples:
                if s.reference_id not in sample_sets:
                    sample_sets[s.reference_id] = (
                        SampleSet([s], parent=s)
                        if references is None
                        else SampleSet([s], parent=s, references=references)
                    )
                else:
                    sample_sets[s.reference_id].append(s)
            return list(sample_sets.values())

        else:
            return (
                [SampleSet([s], parent=s) for s in samples]
                if references is None
                else [SampleSet([s], parent=s, references=references) for s in samples]
            )


class CSVToSampleLoader(CSVBaseSampleLoader):
    """
    Simple mechanism to convert CSV files (in the format described in
    :any:`bob.bio.base.database.CSVBaseSampleLoader`) to either a list of
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
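
    A minimal usage sketch (directory, extension and file names are hypothetical)::

        loader = CSVToSampleLoader(
            data_loader=bob.io.base.load,
            dataset_original_directory="/path/to/raw/data",
            extension=".hdf5",
        )
        samples = loader("my_dataset/my_protocol/dev/for_models.csv")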
    """

    def check_header(self, header):
        """
        A header must contain at least the fields "reference_id" and "path".
        """
        header = [h.lower() for h in header]
        if "reference_id" not in header:
            raise ValueError(
                "The field `reference_id` is not available in your dataset."
            )

        if "path" not in header:
            raise ValueError("The field `path` is not available in your dataset.")

    def __call__(self, filename):

        with open(filename) as cf:
            reader = csv.reader(cf)
            header = next(reader)

            self.check_header(header)
            return [self.convert_row_to_sample(row, header) for row in reader]

    def convert_row_to_sample(self, row, header):
        path = row[0]
        reference_id = row[1]

        kwargs = {str(h).lower(): r for h, r in zip(header[2:], row[2:])}

        if self.metadata_loader is not None:
            metadata = self.metadata_loader(row, header=header)
            kwargs.update(metadata)

        return DelayedSample(
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
            key=path,
            reference_id=reference_id,
            **kwargs,
        )


class LSTToSampleLoader(CSVBaseSampleLoader):
    """
    Simple mechanism to convert LST files in the format below to either a list of
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
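
    The row layouts below are inferred from ``convert_row_to_sample`` (two,
    three, or four space-separated columns; lines starting with ``#`` are
    skipped):

    .. code-block:: text

       path_1 reference_id_1
       path_2 reference_id_2 subject_2
       path_3 compare_reference_id_3 unused reference_id_3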
    """

    def __call__(self, filename):

        with open(filename) as cf:
            reader = csv.reader(cf, delimiter=" ")
            samples = []
            for row in reader:
                if row[0].startswith("#"):
                    continue
                samples.append(self.convert_row_to_sample(row))

            return samples

    def convert_row_to_sample(self, row, header=None):

        if len(row) == 4:
            path = row[0]
            compare_reference_id = row[1]
            reference_id = str(row[3])
            kwargs = {"compare_reference_id": str(compare_reference_id)}
        else:
            path = row[0]
            reference_id = str(row[1])
            kwargs = dict()
            if len(row) == 3:
                subject = row[2]
                kwargs = {"subject": str(subject)}

        if self.metadata_loader is not None:
            metadata = self.metadata_loader(row, header=header)
            kwargs.update(metadata)

        return DelayedSample(
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
            key=path,
            reference_id=reference_id,
            **kwargs,
        )


#####
# DATABASE INTERFACES
#####


class CSVDatasetDevEval(Database):
    """
    Generic filelist dataset for the :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline.
    Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
    interface.

    To create a new dataset, you need to provide a directory structure similar to the one below:

    .. code-block:: text

       my_dataset/
       my_dataset/my_protocol/norm/train_world.csv
       my_dataset/my_protocol/dev/for_models.csv
       my_dataset/my_protocol/dev/for_probes.csv
       my_dataset/my_protocol/eval/for_models.csv
       my_dataset/my_protocol/eval/for_probes.csv
       ...


    In the directory structure above, `my_dataset` contains one subdirectory for each
    evaluation protocol this dataset provides.
    Each `my_protocol` directory contains at least two CSV files:

     - for_models.csv
     - for_probes.csv


    Each row of these CSV files contains i) the path to the raw data and ii) the reference_id label used
    for enrollment (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.references`) and
    probing (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.probes`).
    The structure of each CSV file should be as below:

    .. code-block:: text

       PATH,reference_id
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
       ...

    You might want to ship metadata within your samples (e.g. gender, age, annotations, ...).
    To do so, simply add extra columns, as below:

    .. code-block:: text

       PATH,reference_id,METADATA_1,METADATA_2,METADATA_k
       path_1,reference_id_1,A,B,C
       path_2,reference_id_2,A,B,1
       path_i,reference_id_j,2,3,4
       ...


    The files `my_dataset/my_protocol/eval/for_models.csv` and `my_dataset/my_protocol/eval/for_probes.csv`
    are optional; they are used in case a protocol contains data for evaluation.

    Finally, the content of the file `my_dataset/my_protocol/norm/train_world.csv` is used in case a protocol
    contains data for training (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.background_model_samples`).

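    A minimal usage sketch (the dataset path and protocol name are hypothetical)::

        database = CSVDatasetDevEval(
            dataset_protocol_path="/path/to/my_dataset",
            protocol_name="my_protocol",
        )
        train_samples = database.background_model_samples()
        references = database.references(group="dev")
        probes = database.probes(group="dev")
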
    Parameters
    ----------

        dataset_protocol_path: str
          Absolute path of the dataset protocol description

        protocol_name: str
          The name of the protocol

        csv_to_sample_loader: :any:`bob.bio.base.database.CSVBaseSampleLoader`
            Base class whose objective is to generate :any:`bob.pipelines.Sample`
            and/or :any:`bob.pipelines.SampleSet` from csv rows
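
        is_sparse: bool
            If True, probes are treated as sparse: each probe row also carries
            the identifier of the reference it should be compared against
            (exposed as `compare_reference_id`)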

    """

    def __init__(
        self,
        dataset_protocol_path,
        protocol_name,
        csv_to_sample_loader=CSVToSampleLoader(
            data_loader=bob.io.base.load,
            metadata_loader=None,
            dataset_original_directory="",
            extension="",
        ),
        is_sparse=False,
    ):
        self.dataset_protocol_path = dataset_protocol_path
        self.is_sparse = is_sparse
        def get_paths():

            if not os.path.exists(dataset_protocol_path):
                raise ValueError(f"The path `{dataset_protocol_path}` was not found")

            # TODO: Unzip file if dataset path is a zip
            protocol_path = os.path.join(dataset_protocol_path, protocol_name)
            if not os.path.exists(protocol_path):
                raise ValueError(f"The protocol `{protocol_name}` was not found")
            def path_discovery(option1, option2):
                return option1 if os.path.exists(option1) else option2

            # Handle the legacy filelist layout: prefer the legacy `.lst`
            # files when present, otherwise fall back to the `.csv` files
            train_csv = path_discovery(
                os.path.join(protocol_path, "norm", "train_world.lst"),
                os.path.join(protocol_path, "norm", "train_world.csv"),
            )

            dev_enroll_csv = path_discovery(
                os.path.join(protocol_path, "dev", "for_models.lst"),
                os.path.join(protocol_path, "dev", "for_models.csv"),
            )

            legacy_probe = "for_scores.lst" if self.is_sparse else "for_probes.lst"
            dev_probe_csv = path_discovery(
                os.path.join(protocol_path, "dev", legacy_probe),
                os.path.join(protocol_path, "dev", "for_probes.csv"),
            )

            eval_enroll_csv = path_discovery(
                os.path.join(protocol_path, "eval", "for_models.lst"),
                os.path.join(protocol_path, "eval", "for_models.csv"),
            )

            eval_probe_csv = path_discovery(
                os.path.join(protocol_path, "eval", legacy_probe),
                os.path.join(protocol_path, "eval", "for_probes.csv"),
            )

            # The minimum required is to have `dev_enroll_csv` and `dev_probe_csv`
            train_csv = train_csv if os.path.exists(train_csv) else None

            # Eval
            eval_enroll_csv = (
                eval_enroll_csv if os.path.exists(eval_enroll_csv) else None
            )
            eval_probe_csv = eval_probe_csv if os.path.exists(eval_probe_csv) else None

            # Dev
            if not os.path.exists(dev_enroll_csv):
                raise ValueError(
                    f"The file `{dev_enroll_csv}` is required and it was not found"
                )

            if not os.path.exists(dev_probe_csv):
                raise ValueError(
                    f"The file `{dev_probe_csv}` is required and it was not found"
                )

            return (
                train_csv,
                dev_enroll_csv,
                dev_probe_csv,
                eval_enroll_csv,
                eval_probe_csv,
            )

        (
            self.train_csv,
            self.dev_enroll_csv,
            self.dev_probe_csv,
            self.eval_enroll_csv,
            self.eval_probe_csv,
        ) = get_paths()

        def get_dict_cache():
            cache = dict()
            cache["train"] = None
            cache["dev_enroll_csv"] = None
            cache["dev_probe_csv"] = None
            cache["eval_enroll_csv"] = None
            cache["eval_probe_csv"] = None
            return cache

        self.cache = get_dict_cache()
        self.csv_to_sample_loader = csv_to_sample_loader

    def background_model_samples(self):
        self.cache["train"] = (
            self.csv_to_sample_loader(self.train_csv)
            if self.cache["train"] is None
            else self.cache["train"]
        )

        return self.cache["train"]

    def _get_samplesets(
        self, group="dev", purpose="enroll", group_by_reference_id=False
    ):
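        """Load the enrollment or probe :any:`bob.pipelines.SampleSet` objects
        of a given group, reading the corresponding CSV file and caching the
        result."""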

        fetching_probes = False
        if purpose == "enroll":
            cache_label = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"
        else:
            fetching_probes = True
            cache_label = "dev_probe_csv" if group == "dev" else "eval_probe_csv"

        if self.cache[cache_label] is not None:
            return self.cache[cache_label]

        # Getting samples from CSV
        samples = self.csv_to_sample_loader(self.__dict__[cache_label])

        references = None
        if fetching_probes and self.is_sparse:
            # Checking if `is_sparse` was set properly
            if len(samples) > 0 and not hasattr(samples[0], "compare_reference_id"):
                raise ValueError(
                    f"Attribute `compare_reference_id` not found in `{samples[0]}`. "
                    "Make sure this attribute exists in your dataset if `is_sparse=True`"
                )

            sparse_samples = dict()
            for s in samples:
                if s.key in sparse_samples:
                    sparse_samples[s.key].references.append(s.compare_reference_id)
                else:
                    s.references = [s.compare_reference_id]
                    sparse_samples[s.key] = s
            samples = sparse_samples.values()
        else:
            if fetching_probes:
                references = list(
                    set([s.reference_id for s in self.references(group=group)])
                )

        sample_sets = self.csv_to_sample_loader.convert_samples_to_samplesets(
            samples, group_by_reference_id=group_by_reference_id, references=references,
        )

        self.cache[cache_label] = sample_sets

        return self.cache[cache_label]

    def references(self, group="dev"):
        return self._get_samplesets(
            group=group, purpose="enroll", group_by_reference_id=True
        )

    def probes(self, group="dev"):
        return self._get_samplesets(
            group=group, purpose="probe", group_by_reference_id=False
        )
    def all_samples(self, groups=None):
        """
        Reads and returns all the samples in `groups`.

        Parameters
        ----------
        groups: list or None
            Groups to consider ('train', 'dev', and/or 'eval'). If `None` is
            given, returns the samples from all groups.

        Returns
        -------
        samples: list
            List of :class:`bob.pipelines.Sample` objects.
        """
        valid_groups = ["train"]
        if self.dev_enroll_csv and self.dev_probe_csv:
            valid_groups.append("dev")
        if self.eval_enroll_csv and self.eval_probe_csv:
            valid_groups.append("eval")
        groups = check_parameters_for_validity(
            parameters=groups,
            parameter_description="groups",
            valid_parameters=valid_groups,
            default_parameters=valid_groups,
        )

        samples = []

        # Get train samples (background_model_samples returns a list of samples)
        if "train" in groups:
            samples = samples + self.background_model_samples()
            groups.remove("train")

        # Get enroll and probe samples
        for group in groups:
            for purpose in ("enroll", "probe"):
                label = f"{group}_{purpose}_csv"
                samples = samples + self.csv_to_sample_loader(self.__dict__[label])
        return samples

    def groups(self):
        """This function returns the list of groups for this database.

        Returns
        -------

        [str]
          A list of groups
        """

        # We always have a dev set
        groups = ["dev"]

        if self.train_csv is not None:
            groups.append("train")

        if self.eval_enroll_csv is not None:
            groups.append("eval")

        return groups


class CSVDatasetCrossValidation:
    """
    Generic filelist dataset for the :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline that
    handles **CROSS VALIDATION**.

    Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
    interface.


    This interface takes one `csv_file` as input and splits it into i) data for training and
    ii) data for testing.
    The data for testing is further split into data for enrollment and data for probing.
    The input CSV file should have the following format:

    .. code-block:: text

       PATH,reference_id
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
       ...
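
    A minimal usage sketch (the CSV file name is hypothetical)::

        database = CSVDatasetCrossValidation(
            csv_file_name="metadata.csv",
            test_size=0.8,
            samples_for_enrollment=1,
        )
        train_samples = database.background_model_samples()
        references = database.references()
        probes = database.probes()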

    Parameters
    ----------

    csv_file_name: str
      CSV file containing all the samples from your database

    random_state: int
      Pseudo-random number generator seed

    test_size: float
      Fraction (between 0 and 1) of the reference_ids reserved for testing

    samples_for_enrollment: int
      Number of samples used for enrollment

    csv_to_sample_loader: :any:`bob.bio.base.database.CSVBaseSampleLoader`
        Base class whose objective is to generate :any:`bob.pipelines.Sample`
        and/or :any:`bob.pipelines.SampleSet` from csv rows

    """

    def __init__(
        self,
        csv_file_name="metadata.csv",
        random_state=0,
        test_size=0.8,
        samples_for_enrollment=1,
        csv_to_sample_loader=CSVToSampleLoader(
            data_loader=bob.io.base.load, dataset_original_directory="", extension=""
        ),
    ):
        def get_dict_cache():
            cache = dict()
            cache["train"] = None
            cache["dev_enroll_csv"] = None
            cache["dev_probe_csv"] = None
            return cache

        self.random_state = random_state
        self.cache = get_dict_cache()
        self.csv_to_sample_loader = csv_to_sample_loader
        self.csv_file_name = csv_file_name
        self.samples_for_enrollment = samples_for_enrollment
        self.test_size = test_size

        if self.test_size < 0 or self.test_size > 1:
            raise ValueError(
                f"`test_size` should be between 0 and 1. {test_size} is provided"
            )

    def _do_cross_validation(self):

        # Shuffling samples by reference_id
        samples_by_reference_id = group_samples_by_reference_id(
            self.csv_to_sample_loader(self.csv_file_name)
        )
        reference_ids = list(samples_by_reference_id.keys())
        np.random.seed(self.random_state)
        np.random.shuffle(reference_ids)

        # Getting the training data
        n_samples_for_training = len(reference_ids) - int(
            self.test_size * len(reference_ids)
        )
        self.cache["train"] = list(
            itertools.chain(
                *[
                    samples_by_reference_id[s]
                    for s in reference_ids[0:n_samples_for_training]
                ]
            )
        )

        # Splitting enroll and probe
        self.cache["dev_enroll_csv"] = []
        self.cache["dev_probe_csv"] = []
        for s in reference_ids[n_samples_for_training:]:
            samples = samples_by_reference_id[s]
            if len(samples) < self.samples_for_enrollment:
                raise ValueError(
                    f"Not enough samples ({len(samples)}) for enrollment for the reference_id {s}"
                )

            # Enrollment samples
            self.cache["dev_enroll_csv"].append(
                self.csv_to_sample_loader.convert_samples_to_samplesets(
                    samples[0 : self.samples_for_enrollment]
                )[0]
            )

            self.cache[
                "dev_probe_csv"
            ] += self.csv_to_sample_loader.convert_samples_to_samplesets(
                samples[self.samples_for_enrollment :],
                group_by_reference_id=False,
                references=reference_ids[n_samples_for_training:],
            )

    def _load_from_cache(self, cache_key):
        if self.cache[cache_key] is None:
            self._do_cross_validation()
        return self.cache[cache_key]

    def background_model_samples(self):
        return self._load_from_cache("train")

    def references(self, group="dev"):
        return self._load_from_cache("dev_enroll_csv")

    def probes(self, group="dev"):
        return self._load_from_cache("dev_probe_csv")

    def all_samples(self, groups=None):
        """
        Reads and returns all the samples in `groups`.

        Parameters
        ----------
        groups: list or None
            Groups to consider ('train' and/or 'dev'). If `None` is given,
            returns the samples from all groups.

        Returns
        -------
        samples: list
            List of :class:`bob.pipelines.Sample` objects.
        """
        valid_groups = ["train", "dev"]
        groups = check_parameters_for_validity(
            parameters=groups,
            parameter_description="groups",
            valid_parameters=valid_groups,
            default_parameters=valid_groups,
        )

        samples = []

        # Get train samples (background_model_samples returns a list of samples)
        if "train" in groups:
            samples = samples + self.background_model_samples()
            groups.remove("train")

        # Get enroll and probe samples
        for group in groups:
            samples = samples + [s for s_set in self.references(group) for s in s_set]
            samples = samples + [s for s_set in self.probes(group) for s in s_set]
        return samples

def group_samples_by_reference_id(samples):
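    """Group a flat list of samples into a dict mapping each ``reference_id``
    to the list of samples of that reference."""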

    # Grouping sample sets
    samples_by_reference_id = dict()
    for s in samples:
        if s.reference_id not in samples_by_reference_id:
            samples_by_reference_id[s.reference_id] = []
        samples_by_reference_id[s.reference_id].append(s)
    return samples_by_reference_id