"""Executes PAD pipeline"""


from bob.bio.base.script.vanilla_biometrics import VALID_DASK_CLIENT_STRINGS
import click
from bob.extension.scripts.click_helper import ConfigCommand
from bob.extension.scripts.click_helper import ResourceOption
from bob.extension.scripts.click_helper import verbosity_option


@click.command(
    entry_point_group="bob.pad.config",
    cls=ConfigCommand,
    epilog="""\b
 Command line examples\n
 -----------------------


 $ bob pad vanilla-pad my_experiment.py -vv
""",
)
@click.option(
    "--pipeline",
    "-p",
    required=True,
    entry_point_group="sklearn.pipeline",
    help="Feature extraction algorithm",
    cls=ResourceOption,
)
@click.option(
    "--database",
    "-d",
    required=True,
    cls=ResourceOption,
    entry_point_group="bob.pad.database",
    help="PAD Database connector (class that implements the methods: `fit_samples`, `predict_samples`)",
)
@click.option(
    "--dask-client",
    "-l",
    entry_point_group="dask.client",
    string_exceptions=VALID_DASK_CLIENT_STRINGS,
    default="single-threaded",
    help="Dask client for the execution of the pipeline.",
    cls=ResourceOption,
)
@click.option(
    "--group",
    "-g",
    "groups",
    type=click.Choice(["dev", "eval"]),
    multiple=True,
    default=("dev", "eval"),
    help="If given, this value will limit the experiments belonging to a particular group",
)
@click.option(
    "-o",
    "--output",
    show_default=True,
    default="results",
    help="Saves scores (and checkpoints) in this folder.",
)
@click.option(
    "--checkpoint",
    "-c",
    is_flag=True,
    help="If set, it will checkpoint all steps of the pipeline. Checkpoints will be saved in `--output`.",
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
@click.pass_context
def vanilla_pad(
    ctx, pipeline, database, dask_client, groups, output, checkpoint, **kwargs
):
    """Runs the simplest PAD pipeline."""

    import gzip
    import logging
    import os
    import sys
    from glob import glob

    import bob.pipelines as mario
    import dask.bag
    from bob.extension.scripts.click_helper import log_parameters
    from bob.pipelines.distributed.sge import get_resource_requirements

    logger = logging.getLogger(__name__)
    log_parameters(logger)

    os.makedirs(output, exist_ok=True)

    # optionally checkpoint every pipeline step: features and fitted models are
    # cached inside the output folder and reused on subsequent runs
    if checkpoint:
        pipeline = mario.wrap(
            ["checkpoint"], pipeline, features_dir=output, model_path=output
        )

    if dask_client is None:
        logger.warning("`dask_client` not set. Your pipeline will run locally")

    # create an experiment info file
    with open(os.path.join(output, "Experiment_info.txt"), "wt") as f:
        f.write(f"{sys.argv!r}\n")
        f.write(f"database={database!r}\n")
        f.write("Pipeline steps:\n")
        for i, name, estimator in pipeline._iter():
            f.write(f"Step {i}: {name}\n{estimator!r}\n")

    # train the pipeline
    fit_samples = database.fit_samples()
    pipeline.fit(fit_samples)

    for group in groups:

        logger.info(f"Running vanilla biometrics for group {group}")
        predict_samples = database.predict_samples(group=group)
        result = pipeline.decision_function(predict_samples)

        scores_path = os.path.join(output, f"scores-{group}")

        # a dask-wrapped pipeline returns a lazy dask bag instead of a list of samples
        if isinstance(result, dask.bag.core.Bag):

            # write each partition into a zipped txt file
            result = result.map(pad_predicted_sample_to_score_line)
            prefix, postfix = f"{output}/scores/scores-{group}-", ".txt.gz"
            pattern = f"{prefix}*{postfix}"
            os.makedirs(os.path.dirname(prefix), exist_ok=True)
            logger.info("Writing bag results into files ...")
            resources = get_resource_requirements(pipeline)
            result.to_textfiles(
                pattern,
                last_endline=True,
                scheduler=dask_client,
                resources=resources,
            )

            with open(scores_path, "w") as f:
                # concatenate scores into one score file
                for path in sorted(
                    glob(pattern),
                    key=lambda l: int(l.replace(prefix, "").replace(postfix, "")),
                ):
                    with gzip.open(path, "rt") as f2:
                        f.write(f2.read())
                    # delete intermediate score files
                    os.remove(path)

        else:
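            # results are plain in-memory samples; write them out directly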
            with open(scores_path, "w") as f:
                for sample in result:
                    f.write(pad_predicted_sample_to_score_line(sample, endl="\n"))


def pad_predicted_sample_to_score_line(sample, endl=""):
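    """Converts a scored PAD sample into a 4-column score line.

    Columns: claimed id (``sample.subject``), real id (the attack type for
    presentation attacks), test label (``sample.key``) and the PAD score.
    """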
    claimed_id, test_label, score = sample.subject, sample.key, sample.data

    # # use the model_label field to indicate frame number
    # model_label = None
    # if hasattr(sample, "frame_id"):
    #     model_label = sample.frame_id

    real_id = claimed_id if sample.is_bonafide else sample.attack_type

    return f"{claimed_id} {real_id} {test_label} {score}{endl}"
    # return f"{claimed_id} {model_label} {real_id} {test_label} {score}{endl}"