callbacks.py 1.71 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from pytorch_lightning.callbacks import Callback

import logging
import os
from bob.bio.base.script.vanilla_biometrics import vanilla_biometrics
import bob.bio.base
import bob.measure

logger = logging.getLogger(__name__)


class VanillaBiometricsCallback(Callback):
    """Run ``bob bio pipelines vanilla-biometrics`` at the end of every training epoch.

    After each pipeline run the development scores (``scores-dev.csv`` under
    ``output_path``) are loaded, the decision threshold at the requested FMR is
    computed, and the resulting FNMR is logged through the LightningModule
    being trained.
    """

    def __init__(self, config, output_path, name="vanilla-biometrics", fmr=0.001):
        """
        Parameters
        ----------

        config: str
            Path to the `bob bio pipelines vanilla-biometrics` input script.
            Please check :any:`bob.bio.base.vanilla_biometrics_intro` on how to
            set up that script.

        output_path: str
            Directory from which ``scores-dev.csv`` is read after each pipeline
            run. The pipeline configured in ``config`` is expected to write its
            development scores there.

        name: str
            Name of this callback. It is stored as ``self.name`` and used in log
            messages.

        fmr: float
            False match rate at which the decision threshold is computed and
            FNMR is reported.

        """
        super().__init__()
        self.config = config
        self.name = name
        self.fmr = fmr
        self.output_path = output_path
        self.scores_dev = os.path.join(output_path, "scores-dev.csv")

    def on_train_epoch_end(self, trainer, pl_module=None, *args, **kwargs):
        """Run the pipeline and log FNMR@FMR for the current epoch.

        Parameters
        ----------

        trainer:
            The Lightning ``Trainer`` (passed positionally by Lightning). A plain
            epoch number is also accepted for direct calls.

        pl_module:
            The ``LightningModule`` being trained. It is used to log the metric.
            If it is ``None``, the metric is only written to the Python logger.

        *args, **kwargs:
            Extra hook arguments passed by some Lightning versions (e.g.
            ``outputs``). They are ignored.
        """
        # Lightning passes the Trainer here; fall back to the raw value so a
        # direct call with an integer epoch (the previous signature) still works.
        epoch = getattr(trainer, "current_epoch", trainer)
        logger.info(
            "Run %s at epoch %s. Input script: %s", self.name, epoch, self.config
        )

        vanilla_biometrics.main(
            [self.config],
            prog_name="bob bio pipelines vanilla-biometrics",
            standalone_mode=False,
        )

        # `import bob.bio.base` does not guarantee the `score.load` submodule
        # is loaded, so import the function explicitly.
        from bob.bio.base.score.load import split_csv_scores

        neg, pos = split_csv_scores(self.scores_dev)
        threshold = bob.measure.far_threshold(neg, pos, self.fmr)
        _, fnmr = bob.measure.fprfnr(neg, pos, threshold)

        # Key the metric on the *configured* FMR. The achieved FMR returned by
        # `fprfnr` varies between epochs and would create a new metric name
        # every epoch.
        metric_name = f"validation/fnmr@fmr={self.fmr}"
        if pl_module is not None:
            # `Callback` has no `log` method; logging goes through the module.
            pl_module.log(metric_name, fnmr)
        else:
            logger.info("%s = %s", metric_name, fnmr)