#!/usr/bin/env python
# vim: set fileencoding=utf-8 :


import os
from bob.pipelines import Sample, DelayedSample, SampleSet
from bob.db.base.utils import check_parameters_for_validity
import csv
import bob.io.base
import functools
from abc import ABCMeta, abstractmethod
import numpy as np
import itertools
import logging
import bob.db.base
from bob.extension.download import find_element_in_tarball
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database

logger = logging.getLogger(__name__)


#####
# ANNOTATIONS LOADERS
#####
class AnnotationsLoader:
    """
    Metadata loader that loads annotations in the Idiap format using the function
    :any:`bob.db.base.read_annotation_file`

    Parameters
    ----------

        annotation_directory: str
            Path where the annotation files are stored

        annotation_extension: str
            Extension of the annotation files

        annotation_type: str
            Annotation format, as supported by :any:`bob.db.base.read_annotation_file`
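
    Example
    -------

    A minimal usage sketch (the annotation directory and the row content are
    assumptions for illustration):

    .. code-block:: python

       loader = AnnotationsLoader(
           annotation_directory="/path/to/annotations",
           annotation_extension=".json",
           annotation_type="json",
       )
       # the first column of a CSV row is the sample path
       metadata = loader(["subject_1/sample_1", "subject_1"])
       # -> {"annotations": <content of subject_1/sample_1.json>}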
    """

    def __init__(
        self,
        annotation_directory=None,
        annotation_extension=".json",
        annotation_type="json",
    ):
        self.annotation_directory = annotation_directory
        self.annotation_extension = annotation_extension
        self.annotation_type = annotation_type

    def __call__(self, row, header=None):
        if self.annotation_directory is None:
            return None

        path = row[0]

        # since the file id is equal to the file name, we can simply use it
        annotation_file = os.path.join(
            self.annotation_directory, path + self.annotation_extension
        )

        # return the annotations as read from file
        annotation = {
            "annotations": bob.db.base.read_annotation_file(
                annotation_file, self.annotation_type
            )
        }
        return annotation


#######
# SAMPLE LOADERS
# CONVERT CSV LINES TO SAMPLES
#######


class CSVBaseSampleLoader(metaclass=ABCMeta):
    """    
    Base class that converts the lines of a CSV file, like the one below to
70
71
72
73
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`

    .. code-block:: text

       PATH,REFERENCE_ID
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
       ...

    .. note::
       This class should be extended

    Parameters
    ----------

        data_loader:
            A python function that takes the full path of a sample and loads it
            (e.g. :any:`bob.io.base.load`)

        metadata_loader:
            An optional metadata loader, such as :any:`AnnotationsLoader`

        dataset_original_directory: str
            Path where the raw data is stored

        extension: str
            Default file extension
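
    Example
    -------

    A sketch of the grouping helper, assuming `loader` is an instance of a
    concrete subclass (e.g. :any:`CSVToSampleLoader`) and `samples` were
    produced by it:

    .. code-block:: python

       # one SampleSet per reference_id, holding all samples of that id
       enroll_sets = loader.convert_samples_to_samplesets(
           samples, group_by_reference_id=True
       )
       # or one SampleSet per sample
       probe_sets = loader.convert_samples_to_samplesets(
           samples, group_by_reference_id=False
       )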

    """

    def __init__(
        self,
        data_loader,
        metadata_loader=None,
        dataset_original_directory="",
        extension="",
    ):
        self.data_loader = data_loader
        self.extension = extension
        self.dataset_original_directory = dataset_original_directory
        self.metadata_loader = metadata_loader

    @abstractmethod
    def __call__(self, filename):
        pass

    @abstractmethod
    def convert_row_to_sample(self, row, header):
        pass

    def convert_samples_to_samplesets(
        self, samples, group_by_reference_id=True, references=None
    ):
        """Groups :any:`bob.pipelines.Sample` into :any:`bob.pipelines.SampleSet`:
        one set per `reference_id` if `group_by_reference_id` is ``True``,
        otherwise one set per sample."""
        if group_by_reference_id:

            # Grouping sample sets
            sample_sets = dict()
            for s in samples:
                if s.reference_id not in sample_sets:
                    sample_sets[s.reference_id] = (
                        SampleSet([s], parent=s)
                        if references is None
                        else SampleSet([s], parent=s, references=references)
                    )
                else:
                    sample_sets[s.reference_id].append(s)
            return list(sample_sets.values())

        else:
            return (
                [SampleSet([s], parent=s) for s in samples]
                if references is None
                else [SampleSet([s], parent=s, references=references) for s in samples]
            )


class CSVToSampleLoader(CSVBaseSampleLoader):
    """
    Simple mechanism that converts the lines of a CSV file to
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
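
    Example
    -------

    A minimal usage sketch (paths are assumptions for illustration):

    .. code-block:: python

       import bob.io.base

       loader = CSVToSampleLoader(
           data_loader=bob.io.base.load,
           dataset_original_directory="/path/to/raw/data",
           extension=".hdf5",
       )
       with open("my_protocol/dev/for_models.csv") as f:
           samples = loader(f)  # one DelayedSample per CSV row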
    """

    def check_header(self, header):
        """
        A header must contain at least the fields "path" and "reference_id"
        """
        header = [h.lower() for h in header]
        if "reference_id" not in header:
            raise ValueError(
                "The field `reference_id` is not available in your dataset."
            )

        if "path" not in header:
            raise ValueError("The field `path` is not available in your dataset.")

    def __call__(self, f):
        f.seek(0)
        reader = csv.reader(f)
        header = next(reader)

        self.check_header(header)
        return [self.convert_row_to_sample(row, header) for row in reader]

    def convert_row_to_sample(self, row, header):
        path = row[0]
        reference_id = row[1]

        # extra columns are exposed as sample attributes (with lower-cased names)
        kwargs = {str(h).lower(): r for h, r in zip(header[2:], row[2:])}

        if self.metadata_loader is not None:
            metadata = self.metadata_loader(row, header=header)
            kwargs.update(metadata)

        return DelayedSample(
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
            key=path,
            reference_id=reference_id,
            **kwargs,
        )


class LSTToSampleLoader(CSVBaseSampleLoader):
    """
    Simple mechanism that converts the lines of an LST file to
    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
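
    Example
    -------

    Legacy LST files are whitespace-separated and may contain comment lines.
    A sketch of a two-column list (exact column semantics depend on the legacy
    protocol file, e.g. `for_models.lst` or `for_scores.lst`):

    .. code-block:: text

       # path reference_id
       path_1 reference_id_1
       path_2 reference_id_2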
    """

    def __call__(self, f):
        f.seek(0)
        reader = csv.reader(f, delimiter=" ")
        samples = []
        for row in reader:
            # skip empty lines and comments
            if len(row) == 0 or row[0].startswith("#"):
                continue
            samples.append(self.convert_row_to_sample(row))

        return samples

    def convert_row_to_sample(self, row, header=None):

        # Rows with four columns come from legacy score lists (`for_scores.lst`):
        # probe path, the reference_id to compare against, an unused field, and
        # the probe's own reference_id. Rows with two or three columns come from
        # regular legacy sample lists.
        if len(row) == 4:
            path = row[0]
            compare_reference_id = row[1]
            reference_id = str(row[3])
            kwargs = {"compare_reference_id": str(compare_reference_id)}
        else:
            path = row[0]
            reference_id = str(row[1])
            kwargs = dict()
            if len(row) == 3:
                subject = row[2]
                kwargs = {"subject": str(subject)}

        if self.metadata_loader is not None:
            metadata = self.metadata_loader(row, header=header)
            kwargs.update(metadata)

        return DelayedSample(
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
            key=path,
            reference_id=reference_id,
            **kwargs,
        )


#####
# DATABASE INTERFACES
#####


class CSVDatasetDevEval(Database):
    """
    Generic file-list dataset for the :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline.
    Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
    interface.

    To create a new dataset, you need to provide a directory structure similar to the one below:

    .. code-block:: text

       my_dataset/
       my_dataset/my_protocol/norm/train_world.csv
       my_dataset/my_protocol/dev/for_models.csv
       my_dataset/my_protocol/dev/for_probes.csv
       my_dataset/my_protocol/eval/for_models.csv
       my_dataset/my_protocol/eval/for_probes.csv
       ...


    In the directory structure above, `my_dataset` contains one sub-directory per
    evaluation protocol of the dataset.
    Each protocol directory (here, `my_protocol`) should contain at least two csv files:

     - for_models.csv
     - for_probes.csv


    Each row of those csv files contains i) the path to the raw data and ii) the
    reference_id label used for enrollment
    (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.references`) and
    probing (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.probes`).
    The structure of each CSV file should be as below:

    .. code-block:: text

       PATH,reference_id
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
       ...


    You may also want to ship metadata within your samples (e.g. gender, age,
    annotations, ...). Doing so is simple: just add the extra columns, as below:

    .. code-block:: text

       PATH,reference_id,METADATA_1,METADATA_2,METADATA_k
       path_1,reference_id_1,A,B,C
       path_2,reference_id_2,A,B,1
       path_i,reference_id_j,2,3,4
       ...
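
    Extra columns are exposed as attributes of the loaded samples. A sketch
    (assuming `database` is an instance of this class; attribute names follow
    the lower-cased CSV header):

    .. code-block:: python

       probe_set = database.probes(group="dev")[0]
       sample = probe_set[0]
       sample.metadata_1  # -> "A"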


    The files `my_dataset/my_protocol/eval/for_models.csv` and
    `my_dataset/my_protocol/eval/for_probes.csv` are optional; they are used only
    when a protocol contains data for evaluation.

    Finally, the content of the file `my_dataset/my_protocol/norm/train_world.csv` is used when a protocol
    contains data for training (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.background_model_samples`).

    Parameters
    ----------

        dataset_protocol_path: str
          Absolute path (or a tarball) of the dataset protocol description.

        protocol_name: str
          The name of the protocol

        csv_to_sample_loader: :any:`bob.bio.base.database.CSVBaseSampleLoader`
            Loader whose objective is to generate :any:`bob.pipelines.Sample`
            and/or :any:`bob.pipelines.SampleSet` from csv rows

        is_sparse: bool
            Set to ``True`` when the probe list is sparse, i.e. it contains one
            row per probe/model comparison (as in legacy `for_scores.lst` files)
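
    Example
    -------

    A minimal usage sketch (paths and the protocol name are assumptions for
    illustration):

    .. code-block:: python

       import bob.io.base

       database = CSVDatasetDevEval(
           dataset_protocol_path="/path/to/my_dataset",
           protocol_name="my_protocol",
           csv_to_sample_loader=CSVToSampleLoader(
               data_loader=bob.io.base.load,
               dataset_original_directory="/path/to/raw/data",
               extension=".hdf5",
           ),
       )
       train = database.background_model_samples()
       enroll_sets = database.references(group="dev")
       probe_sets = database.probes(group="dev")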

    """

    def __init__(
        self,
        dataset_protocol_path,
        protocol_name,
        csv_to_sample_loader=CSVToSampleLoader(
            data_loader=bob.io.base.load,
            metadata_loader=None,
            dataset_original_directory="",
            extension="",
        ),
        is_sparse=False,
    ):
        self.dataset_protocol_path = dataset_protocol_path
        self.is_sparse = is_sparse

        def get_paths():

            if not os.path.exists(dataset_protocol_path):
                raise ValueError(f"The path `{dataset_protocol_path}` was not found")

            def path_discovery(option1, option2):

                # If the input is a directory
                if os.path.isdir(dataset_protocol_path):
                    option1 = os.path.join(dataset_protocol_path, option1)
                    option2 = os.path.join(dataset_protocol_path, option2)
                    if os.path.exists(option1):
                        return open(option1)
                    else:
                        return open(option2) if os.path.exists(option2) else None

                # If it's not a directory, it's a tarball
                op1 = find_element_in_tarball(dataset_protocol_path, option1)
                return (
                    op1
                    if op1
                    else find_element_in_tarball(dataset_protocol_path, option2)
                )

            # Here we handle both the legacy (.lst) and the new (.csv) formats
            train_csv = path_discovery(
                os.path.join(protocol_name, "norm", "train_world.lst"),
                os.path.join(protocol_name, "norm", "train_world.csv"),
            )

            dev_enroll_csv = path_discovery(
                os.path.join(protocol_name, "dev", "for_models.lst"),
                os.path.join(protocol_name, "dev", "for_models.csv"),
            )

            legacy_probe = "for_scores.lst" if self.is_sparse else "for_probes.lst"
            dev_probe_csv = path_discovery(
                os.path.join(protocol_name, "dev", legacy_probe),
                os.path.join(protocol_name, "dev", "for_probes.csv"),
            )

            eval_enroll_csv = path_discovery(
                os.path.join(protocol_name, "eval", "for_models.lst"),
                os.path.join(protocol_name, "eval", "for_models.csv"),
            )

            eval_probe_csv = path_discovery(
                os.path.join(protocol_name, "eval", legacy_probe),
                os.path.join(protocol_name, "eval", "for_probes.csv"),
            )

            # The minimum required is to have `dev_enroll_csv` and `dev_probe_csv`

            # Dev
            if dev_enroll_csv is None:
                raise ValueError(
                    f"The file `{protocol_name}/dev/for_models.csv` (or `.lst`) is required and was not found"
                )

            if dev_probe_csv is None:
                raise ValueError(
                    f"The file `{protocol_name}/dev/for_probes.csv` (or `.lst`) is required and was not found"
                )

            return (
                train_csv,
                dev_enroll_csv,
                dev_probe_csv,
                eval_enroll_csv,
                eval_probe_csv,
            )

        (
            self.train_csv,
            self.dev_enroll_csv,
            self.dev_probe_csv,
            self.eval_enroll_csv,
            self.eval_probe_csv,
        ) = get_paths()

        def get_dict_cache():
            cache = dict()
            cache["train"] = None
            cache["dev_enroll_csv"] = None
            cache["dev_probe_csv"] = None
            cache["eval_enroll_csv"] = None
            cache["eval_probe_csv"] = None
            return cache

        self.cache = get_dict_cache()
        self.csv_to_sample_loader = csv_to_sample_loader

    def background_model_samples(self):
        self.cache["train"] = (
            self.csv_to_sample_loader(self.train_csv)
            if self.cache["train"] is None
            else self.cache["train"]
        )

        return self.cache["train"]

    def _get_samplesets(
        self, group="dev", purpose="enroll", group_by_reference_id=False
    ):

        fetching_probes = False
        if purpose == "enroll":
            cache_label = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"
        else:
            fetching_probes = True
            cache_label = "dev_probe_csv" if group == "dev" else "eval_probe_csv"

        if self.cache[cache_label] is not None:
            return self.cache[cache_label]

        # Getting samples from CSV
        samples = self.csv_to_sample_loader(getattr(self, cache_label))

        references = None
        if fetching_probes and self.is_sparse:

            # Checking if `is_sparse` was set properly
            if len(samples) > 0 and not hasattr(samples[0], "compare_reference_id"):
                raise ValueError(
                    f"Attribute `compare_reference_id` not found in `{samples[0]}`. "
                    "Make sure this attribute exists in your dataset if `is_sparse=True`"
                )

            # In the sparse format several rows share the same probe key; collapse
            # them into a single sample whose `references` attribute lists all
            # reference_ids it should be compared against
            sparse_samples = dict()
            for s in samples:
                if s.key in sparse_samples:
                    sparse_samples[s.key].references.append(s.compare_reference_id)
                else:
                    s.references = [s.compare_reference_id]
                    sparse_samples[s.key] = s
            samples = sparse_samples.values()
        elif fetching_probes:
            references = list(
                set([s.reference_id for s in self.references(group=group)])
            )

        sample_sets = self.csv_to_sample_loader.convert_samples_to_samplesets(
            samples, group_by_reference_id=group_by_reference_id, references=references,
        )

        self.cache[cache_label] = sample_sets

        return self.cache[cache_label]

    def references(self, group="dev"):
        return self._get_samplesets(
            group=group, purpose="enroll", group_by_reference_id=True
        )

    def probes(self, group="dev"):
        return self._get_samplesets(
            group=group, purpose="probe", group_by_reference_id=False
        )

    def all_samples(self, groups=None):
        """
        Reads and returns all the samples in `groups`.

        Parameters
        ----------
        groups: list or None
            Groups to consider ('train', 'dev', and/or 'eval'). If `None` is
            given, returns the samples from all groups.

        Returns
        -------
        samples: list
            List of :class:`bob.pipelines.Sample` objects.
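
        Example
        -------

        A sketch (assuming `database` is an instance of this class):

        .. code-block:: python

           # every sample of the dev group, enrollment and probing alike
           dev_samples = database.all_samples(groups=["dev"])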
        """
        valid_groups = ["train"]
        if self.dev_enroll_csv and self.dev_probe_csv:
            valid_groups.append("dev")
        if self.eval_enroll_csv and self.eval_probe_csv:
            valid_groups.append("eval")
        groups = check_parameters_for_validity(
            parameters=groups,
            parameter_description="groups",
            valid_parameters=valid_groups,
            default_parameters=valid_groups,
        )

        samples = []

        # Get train samples (background_model_samples returns a list of samples)
        if "train" in groups:
            samples = samples + self.background_model_samples()
            groups.remove("train")

        # Get enroll and probe samples
        for group in groups:
            for purpose in ("enroll", "probe"):
                label = f"{group}_{purpose}_csv"
                samples = samples + self.csv_to_sample_loader(getattr(self, label))
        return samples

    def groups(self):
        """This function returns the list of groups for this database.

        Returns
        -------

        [str]
          A list of groups
        """

        # We always have dev-set
        groups = ["dev"]

        if self.train_csv is not None:
            groups.append("train")

        if self.eval_enroll_csv is not None:
            groups.append("eval")

        return groups


class CSVDatasetCrossValidation:
    """
    Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline that
    handles **CROSS VALIDATION**.

    Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
    interface.


    This interface takes one `csv_file` as input and splits it into i) data for
    training and ii) data for testing.
    The data for testing is further split into data for enrollment and data for probing.
    The input CSV file should follow the format below:

    .. code-block:: text

       PATH,reference_id
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
       ...

    Parameters
    ----------

    csv_file_name: str
      CSV file containing all the samples from your database

    random_state: int
      Pseudo-random number generator seed

    test_size: float
      Fraction (between 0 and 1) of the reference_ids used for testing

    samples_for_enrollment: float
      Number of samples used for enrollment

    csv_to_sample_loader: :any:`bob.bio.base.database.CSVBaseSampleLoader`
        Loader whose objective is to generate :any:`bob.pipelines.Sample`
        and/or :any:`bob.pipelines.SampleSet` from csv rows
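
    Example
    -------

    A minimal usage sketch (the CSV file name and paths are assumptions for
    illustration):

    .. code-block:: python

       import bob.io.base

       database = CSVDatasetCrossValidation(
           csv_file_name="metadata.csv",
           random_state=0,
           test_size=0.8,
           samples_for_enrollment=1,
           csv_to_sample_loader=CSVToSampleLoader(
               data_loader=bob.io.base.load,
               dataset_original_directory="/path/to/raw/data",
               extension=".hdf5",
           ),
       )
       train = database.background_model_samples()
       enroll_sets = database.references()
       probe_sets = database.probes()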

    """

    def __init__(
        self,
        csv_file_name="metadata.csv",
        random_state=0,
        test_size=0.8,
        samples_for_enrollment=1,
        csv_to_sample_loader=CSVToSampleLoader(
            data_loader=bob.io.base.load, dataset_original_directory="", extension=""
        ),
    ):
        def get_dict_cache():
            cache = dict()
            cache["train"] = None
            cache["dev_enroll_csv"] = None
            cache["dev_probe_csv"] = None
            return cache

        self.random_state = random_state
        self.cache = get_dict_cache()
        self.csv_to_sample_loader = csv_to_sample_loader
        self.csv_file_name = open(csv_file_name)
        self.samples_for_enrollment = samples_for_enrollment
        self.test_size = test_size

        if self.test_size < 0 or self.test_size > 1:
            raise ValueError(
                f"`test_size` should be between 0 and 1. {test_size} was provided"
            )

    def _do_cross_validation(self):

        # Shuffling samples by reference_id
        samples_by_reference_id = group_samples_by_reference_id(
            self.csv_to_sample_loader(self.csv_file_name)
        )
        reference_ids = list(samples_by_reference_id.keys())
        np.random.seed(self.random_state)
        np.random.shuffle(reference_ids)

        # Getting the training data
        n_samples_for_training = len(reference_ids) - int(
            self.test_size * len(reference_ids)
        )
        self.cache["train"] = list(
            itertools.chain(
                *[
                    samples_by_reference_id[s]
                    for s in reference_ids[0:n_samples_for_training]
                ]
            )
        )

        # Splitting enroll and probe
        self.cache["dev_enroll_csv"] = []
        self.cache["dev_probe_csv"] = []
        for s in reference_ids[n_samples_for_training:]:
            samples = samples_by_reference_id[s]
            if len(samples) < self.samples_for_enrollment:
                raise ValueError(
                    f"Not enough samples ({len(samples)}) for enrollment for the reference_id {s}"
                )

            # Enrollment samples
            self.cache["dev_enroll_csv"].append(
                self.csv_to_sample_loader.convert_samples_to_samplesets(
                    samples[0 : self.samples_for_enrollment]
                )[0]
            )

            self.cache[
                "dev_probe_csv"
            ] += self.csv_to_sample_loader.convert_samples_to_samplesets(
                samples[self.samples_for_enrollment :],
                group_by_reference_id=False,
                references=reference_ids[n_samples_for_training:],
            )

    def _load_from_cache(self, cache_key):
        if self.cache[cache_key] is None:
            self._do_cross_validation()
        return self.cache[cache_key]

    def background_model_samples(self):
        return self._load_from_cache("train")

    def references(self, group="dev"):
        return self._load_from_cache("dev_enroll_csv")

    def probes(self, group="dev"):
        return self._load_from_cache("dev_probe_csv")

    def all_samples(self, groups=None):
        """
        Reads and returns all the samples in `groups`.

        Parameters
        ----------
        groups: list or None
            Groups to consider ('train' and/or 'dev'). If `None` is given,
            returns the samples from all groups.

        Returns
        -------
        samples: list
            List of :class:`bob.pipelines.Sample` objects.
        """
        valid_groups = ["train", "dev"]
        groups = check_parameters_for_validity(
            parameters=groups,
            parameter_description="groups",
            valid_parameters=valid_groups,
            default_parameters=valid_groups,
        )

        samples = []

        # Get train samples (background_model_samples returns a list of samples)
        if "train" in groups:
            samples = samples + self.background_model_samples()
            groups.remove("train")

        # Get enroll and probe samples
        for group in groups:
            samples = samples + [s for s_set in self.references(group) for s in s_set]
            samples = samples + [s for s_set in self.probes(group) for s in s_set]
        return samples


def group_samples_by_reference_id(samples):
    """Groups a list of samples into a dictionary keyed by `reference_id`."""

    # Grouping sample sets
    samples_by_reference_id = dict()
    for s in samples:
        if s.reference_id not in samples_by_reference_id:
            samples_by_reference_id[s.reference_id] = []
        samples_by_reference_id[s.reference_id].append(s)
    return samples_by_reference_id