from bob.pipelines import (
    DelayedSample,
    SampleSet,
    Sample,
    DelayedSampleSet,
    DelayedSampleSetCached,
)
import bob.io.base
import os
import dask
import functools
from .score_writers import FourColumnsScoreWriter
from .abstract_classes import BioAlgorithm
import bob.pipelines
import numpy as np
import h5py
from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper
from .legacy import BioAlgorithmLegacy
from bob.bio.base.transformers import (
    PreprocessorTransformer,
    ExtractorTransformer,
    AlgorithmTransformer,
)
from bob.pipelines.wrappers import SampleWrapper, CheckpointWrapper
from bob.pipelines.distributed.sge import SGEMultipleQueuesCluster
import logging
from bob.pipelines.utils import isinstance_nested
import gc
import time
from . import pickle_compress, uncompress_unpickle

logger = logging.getLogger(__name__)

class BioAlgorithmCheckpointWrapper(BioAlgorithm):
    """Wrapper used to checkpoint enrolled and scored samples.

    Parameters
    ----------

        biometric_algorithm: :any:`bob.bio.base.pipelines.vanilla_biometrics.BioAlgorithm`
           An implemented :any:`bob.bio.base.pipelines.vanilla_biometrics.BioAlgorithm`

        base_dir: str
           Path to store biometric references and scores

        extension: str
            File extension

        force: bool
          If True, scores and biometric references will be recomputed even if a checkpoint file exists

        hash_fn
           Pointer to a hash function. This hash function maps
           `sample.key` to a hash code and this hash code corresponds to a relative directory
           where a single `sample` will be checkpointed.
           This is useful when it is desirable to keep file directories with fewer than
           a certain number of files.

    Examples
    --------

    >>> from bob.bio.base.pipelines.vanilla_biometrics import BioAlgorithmCheckpointWrapper, Distance
    >>> biometric_algorithm = BioAlgorithmCheckpointWrapper(Distance(), base_dir="./")
    >>> biometric_algorithm.enroll(sample) # doctest: +SKIP
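
    A minimal sketch of a custom ``hash_fn`` (the two-character bucketing below
    is purely illustrative, not part of the API):

    >>> hash_fn = lambda key: str(key)[:2]  # bucket checkpoints by key prefix
    >>> biometric_algorithm = BioAlgorithmCheckpointWrapper(Distance(), base_dir="./", hash_fn=hash_fn) # doctest: +SKIP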

    """

    def __init__(
        self,
        biometric_algorithm,
        base_dir,
        group=None,
        force=False,
        hash_fn=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.base_dir = base_dir
        self.set_score_references_path(group)

        self.biometric_algorithm = biometric_algorithm
        self.force = force
        self._biometric_reference_extension = ".hdf5"
        self._score_extension = ".pickle.gz"
        self.hash_fn = hash_fn

    def clear_caches(self):
        self.biometric_algorithm.clear_caches()

    def set_score_references_path(self, group):
        if group is None:
            self.biometric_reference_dir = os.path.join(
                self.base_dir, "biometric_references"
            )
            self.score_dir = os.path.join(self.base_dir, "scores")
        else:
            self.biometric_reference_dir = os.path.join(
                self.base_dir, group, "biometric_references"
            )
            self.score_dir = os.path.join(self.base_dir, group, "scores")

    def enroll(self, enroll_features):
        return self.biometric_algorithm.enroll(enroll_features)

    def score(self, biometric_reference, data):
        return self.biometric_algorithm.score(biometric_reference, data)

    def score_multiple_biometric_references(self, biometric_references, data):
        return self.biometric_algorithm.score_multiple_biometric_references(
            biometric_references, data
        )

    def write_biometric_reference(self, sample, path):
        return bob.io.base.save(sample.data, path, create_directories=True)

    def write_scores(self, samples, path):
        pickle_compress(path, samples)

    def _enroll_sample_set(self, sampleset):
        """
        Enroll a sample set with checkpointing
        """

        # Amending `models` directory
        hash_dir_name = (
            self.hash_fn(str(sampleset.key)) if self.hash_fn is not None else ""
        )

        path = os.path.join(
            self.biometric_reference_dir,
            hash_dir_name,
            str(sampleset.key) + self._biometric_reference_extension,
        )

        if self.force or not os.path.exists(path):

            enrolled_sample = self.biometric_algorithm._enroll_sample_set(sampleset)

            # saving the new sample
            self.write_biometric_reference(enrolled_sample, path)

        # This seems inefficient, but it's crucial for large datasets
        delayed_enrolled_sample = DelayedSample(
            functools.partial(bob.io.base.load, path), parent=sampleset
        )

        return delayed_enrolled_sample

    def _score_sample_set(
        self,
        sampleset,
        biometric_references,
        allow_scoring_with_all_biometric_references=False,
    ):
        """Given a sampleset for probing, compute the scores and return a sample set with the scores"""

        def _load(path):
            return uncompress_unpickle(path)

        def _make_name(sampleset, biometric_references):
            # The score file name is composed of the sampleset key and the
            # keys of the first 3 biometric references
            reference_id = str(sampleset.reference_id)
            name = str(sampleset.key)
            suffix = "_".join([str(s.key) for s in biometric_references[0:3]])
            return os.path.join(reference_id, name + suffix)

        # Amending `models` directory
        hash_dir_name = (
            self.hash_fn(str(sampleset.key)) if self.hash_fn is not None else ""
        )

        path = os.path.join(
            self.score_dir,
            hash_dir_name,
            _make_name(sampleset, biometric_references) + self._score_extension,
        )

        parent = sampleset
        if self.force or not os.path.exists(path):

            # Computing score
            scored_sample_set = self.biometric_algorithm._score_sample_set(
                sampleset,
                biometric_references,
                allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
            )
            self.write_scores(scored_sample_set.samples, path)
            parent = scored_sample_set

        scored_sample_set = DelayedSampleSetCached(
            functools.partial(_load, path), parent=parent
        )

        return scored_sample_set


class BioAlgorithmDaskWrapper(BioAlgorithm):
    """
    Wrap :any:`bob.bio.base.pipelines.vanilla_biometrics.BioAlgorithm` to work with Dask
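
    Examples
    --------

    A minimal sketch; ``algorithm`` is a hypothetical, already-constructed
    :any:`bob.bio.base.pipelines.vanilla_biometrics.BioAlgorithm`:

    >>> dask_algorithm = BioAlgorithmDaskWrapper(algorithm) # doctest: +SKIP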
    """

    def __init__(self, biometric_algorithm, **kwargs):
        self.biometric_algorithm = biometric_algorithm

    def clear_caches(self):
        self.biometric_algorithm.clear_caches()

    def enroll_samples(self, biometric_reference_features):

        biometric_references = biometric_reference_features.map_partitions(
            self.biometric_algorithm.enroll_samples
        )

        return biometric_references

    def score_samples(
        self,
        probe_features,
        biometric_references,
        allow_scoring_with_all_biometric_references=False,
    ):

        # TODO: Here, we are sending all computed biometric references to all
        # probes.  It would be more efficient if only the models related to each
        # probe are sent to the probing split.  An option would be to use caching
        # and allow the ``score`` function above to load the required data from
        # the disk, directly.  A second option would be to generate named delays
        # for each model and then associate them here.

        all_references = dask.delayed(list)(biometric_references)
        scores = probe_features.map_partitions(
            self.biometric_algorithm.score_samples,
            all_references,
            allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
        )
        return scores

    def enroll(self, data):
        return self.biometric_algorithm.enroll(data)

    def score(self, biometric_reference, data):
        return self.biometric_algorithm.score(biometric_reference, data)

    def score_multiple_biometric_references(self, biometric_references, data):
        return self.biometric_algorithm.score_multiple_biometric_references(
            biometric_references, data
        )

    def set_score_references_path(self, group):
        self.biometric_algorithm.set_score_references_path(group)

def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None):
    """
    Given a :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline`, wraps the :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` and
    its :any:`bob.bio.base.pipelines.vanilla_biometrics.BioAlgorithm` to be executed with Dask

    Parameters
    ----------

    pipeline: :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline`
       Vanilla Biometrics based pipeline to be dasked

    npartitions: int
       Number of partitions for the initial `dask.bag`

    partition_size: int
       Size of the partitions for the initial `dask.bag`
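
    Examples
    --------

    A minimal sketch, assuming ``pipeline``, ``background``, ``references`` and
    ``probes`` are already-defined (hypothetical) variables:

    >>> from bob.bio.base.pipelines.vanilla_biometrics import dask_vanilla_biometrics
    >>> pipeline = dask_vanilla_biometrics(pipeline, partition_size=200) # doctest: +SKIP
    >>> scores = pipeline(background, references, probes).compute() # doctest: +SKIP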
    """

    if isinstance(pipeline, ZTNormPipeline):
        # Dasking the first part of the pipeline
        pipeline.vanilla_biometrics_pipeline = dask_vanilla_biometrics(
            pipeline.vanilla_biometrics_pipeline,
            npartitions=npartitions,
            partition_size=partition_size,
        )
        pipeline.biometric_algorithm = (
            pipeline.vanilla_biometrics_pipeline.biometric_algorithm
        )
        pipeline.transformer = pipeline.vanilla_biometrics_pipeline.transformer

        pipeline.ztnorm_solver = ZTNormDaskWrapper(pipeline.ztnorm_solver)

    else:

        if partition_size is None:
            pipeline.transformer = bob.pipelines.wrap(
                ["dask"], pipeline.transformer, npartitions=npartitions
            )
        else:
            pipeline.transformer = bob.pipelines.wrap(
                ["dask"], pipeline.transformer, partition_size=partition_size
            )
        pipeline.biometric_algorithm = BioAlgorithmDaskWrapper(
            pipeline.biometric_algorithm
        )

        def _write_scores(scores):
            return scores.map_partitions(pipeline.write_scores_on_dask)

        pipeline.write_scores_on_dask = pipeline.write_scores
        pipeline.write_scores = _write_scores

    return pipeline


def checkpoint_vanilla_biometrics(
    pipeline, base_dir, biometric_algorithm_dir=None, hash_fn=None
):
    """
    Given a :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline`, wraps the :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` and
    its :any:`bob.bio.base.pipelines.vanilla_biometrics.BioAlgorithm` to be checkpointed

    Parameters
    ----------

    pipeline: :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline`
       Vanilla Biometrics based pipeline to be checkpointed

    base_dir: str
       Path to store transformed input data and possibly biometric references and scores

    biometric_algorithm_dir: str
       If set, biometric references and scores will be checkpointed to this path.
       If not, `base_dir` will be used.
       This is useful when it is suitable to have the transformed data, and the biometric
       references and scores, in different paths.

    hash_fn
       Pointer to a hash function. This hash function will map
       `sample.key` to a hash code and this hash code will be the
       relative directory where a single `sample` will be checkpointed.
       This is useful when it is desirable to keep file directories with fewer than
       a certain number of files.
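
    Examples
    --------

    A minimal sketch, assuming ``pipeline`` is an already-built (hypothetical)
    :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline`:

    >>> from bob.bio.base.pipelines.vanilla_biometrics import checkpoint_vanilla_biometrics
    >>> pipeline = checkpoint_vanilla_biometrics(pipeline, base_dir="./checkpoints") # doctest: +SKIP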
    """

    sk_pipeline = pipeline.transformer

    bio_ref_scores_dir = (
        base_dir if biometric_algorithm_dir is None else biometric_algorithm_dir
    )

    for i, name, estimator in sk_pipeline._iter():

        wrapped_estimator = bob.pipelines.wrap(
            ["checkpoint"],
            estimator,
            features_dir=os.path.join(base_dir, name),
            hash_fn=hash_fn,
        )

        sk_pipeline.steps[i] = (name, wrapped_estimator)

    if isinstance(pipeline.biometric_algorithm, BioAlgorithmLegacy):
        pipeline.biometric_algorithm.base_dir = bio_ref_scores_dir
    else:
        pipeline.biometric_algorithm = BioAlgorithmCheckpointWrapper(
            pipeline.biometric_algorithm, base_dir=bio_ref_scores_dir, hash_fn=hash_fn
        )

    return pipeline


def is_checkpointed(pipeline):
    """
371
    Check if :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` is checkpointed
372
373
374
375
376


    Parameters
    ----------

377
    pipeline: :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline`
378
379
380
381
       Vanilla Biometrics based pipeline to be checkpointed
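
    Examples
    --------

    A minimal sketch, with ``pipeline`` a hypothetical, already-built pipeline;
    a freshly checkpointed pipeline is expected to report ``True``:

    >>> is_checkpointed(checkpoint_vanilla_biometrics(pipeline, base_dir="./tmp")) # doctest: +SKIP
    True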

    """

382
383
    # We have to check if is BioAlgorithmCheckpointWrapper OR
    # If it BioAlgorithmLegacy and the transformer of BioAlgorithmLegacy is also checkpointable
384
385
    return isinstance_nested(
        pipeline, "biometric_algorithm", BioAlgorithmCheckpointWrapper
386
387
388
    ) or (
        isinstance_nested(pipeline, "biometric_algorithm", BioAlgorithmLegacy)
        and isinstance_nested(pipeline, "transformer", CheckpointWrapper)
389
    )