"""Executes PAD pipeline"""

import click

from bob.extension.scripts.click_helper import ConfigCommand
from bob.extension.scripts.click_helper import ResourceOption
from bob.extension.scripts.click_helper import verbosity_option
from bob.pipelines.distributed import VALID_DASK_CLIENT_STRINGS
@click.command(
    entry_point_group="bob.pad.config",
    cls=ConfigCommand,
    epilog="""\b
 Command line examples\n
 -----------------------


 $ bob pad vanilla-pad my_experiment.py -vv
""",
)
@click.option(
    "--pipeline",
    "-p",
    required=True,
    entry_point_group="sklearn.pipeline",
    help="Feature extraction algorithm",
    cls=ResourceOption,
)
@click.option(
    "--database",
    "-d",
    required=True,
    cls=ResourceOption,
    entry_point_group="bob.pad.database",
    help="PAD Database connector (class that implements the methods: `fit_samples`, `predict_samples`)",
)
@click.option(
    "--dask-client",
    "-l",
    entry_point_group="dask.client",
    string_exceptions=VALID_DASK_CLIENT_STRINGS,
    default="single-threaded",
    help="Dask client for the execution of the pipeline.",
    cls=ResourceOption,
)
@click.option(
    "--group",
    "-g",
    "groups",
    type=click.Choice(["dev", "eval"]),
    multiple=True,
    default=("dev", "eval"),
    help="If given, this value will limit the experiments belonging to a particular group",
)
@click.option(
    "-o",
    "--output",
    show_default=True,
    default="results",
    help="Saves scores (and checkpoints) in this folder.",
    # consistency: every other option is a ResourceOption so it can also be
    # set from a config file through ConfigCommand; --output was the only
    # plain click.Option.
    cls=ResourceOption,
)
@click.option(
    "--checkpoint",
    "-c",
    is_flag=True,
    help="If set, it will checkpoint all steps of the pipeline. Checkpoints will be saved in `--output`.",
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
@click.pass_context
def vanilla_pad(
    ctx, pipeline, database, dask_client, groups, output, checkpoint, **kwargs
):
    """Runs the simplest PAD pipeline.

    Fits the ``pipeline`` on the database's fit samples, then for each
    requested group scores the group's predict samples with
    ``pipeline.decision_function`` and writes one ``scores-<group>`` text
    file (4-column score lines) into ``output``.
    """

    import gzip
    import logging
    import os
    import sys
    from glob import glob

    import bob.pipelines as mario
    import dask.bag
    from bob.extension.scripts.click_helper import log_parameters
    from bob.pipelines.distributed.sge import get_resource_requirements

    logger = logging.getLogger(__name__)
    log_parameters(logger)

    os.makedirs(output, exist_ok=True)

    if checkpoint:
        # Wrap every estimator so intermediate features and fitted models are
        # stored under `output` and reused on re-runs.
        pipeline = mario.wrap(
            ["checkpoint"], pipeline, features_dir=output, model_path=output
        )

    if dask_client is None:
        logger.warning("`dask_client` not set. Your pipeline will run locally")

    # create an experiment info file
    with open(os.path.join(output, "Experiment_info.txt"), "wt") as f:
        f.write(f"{sys.argv!r}\n")
        f.write(f"database={database!r}\n")
        f.write("Pipeline steps:\n")
        for i, name, estimator in pipeline._iter():
            f.write(f"Step {i}: {name}\n{estimator!r}\n")

    # train the pipeline
    fit_samples = database.fit_samples()
    pipeline.fit(fit_samples)

    for group in groups:

        # fix: message used to say "vanilla biometrics" (copy-paste from the
        # bob.bio command); this is the PAD pipeline.
        logger.info(f"Running vanilla PAD for group {group}")
        predict_samples = database.predict_samples(group=group)
        result = pipeline.decision_function(predict_samples)

        scores_path = os.path.join(output, f"scores-{group}")

        if isinstance(result, dask.bag.core.Bag):

            # write each partition into a zipped txt file
            result = result.map(pad_predicted_sample_to_score_line)
            prefix, postfix = f"{output}/scores/scores-{group}-", ".txt.gz"
            pattern = f"{prefix}*{postfix}"
            os.makedirs(os.path.dirname(prefix), exist_ok=True)
            logger.info("Writing bag results into files ...")
            resources = get_resource_requirements(pipeline)
            result.to_textfiles(
                pattern, last_endline=True, scheduler=dask_client, resources=resources
            )

            with open(scores_path, "w") as f:
                # concatenate scores into one score file, in partition order
                # (the numeric suffix between prefix and postfix)
                for path in sorted(
                    glob(pattern),
                    key=lambda l: int(l.replace(prefix, "").replace(postfix, "")),
                ):
                    with gzip.open(path, "rt") as f2:
                        f.write(f2.read())
                    # delete intermediate score files
                    os.remove(path)

        else:
            with open(scores_path, "w") as f:
                for sample in result:
                    f.write(pad_predicted_sample_to_score_line(sample, endl="\n"))


def pad_predicted_sample_to_score_line(sample, endl=""):
    """Format one predicted sample as a 4-column PAD score line.

    Produces ``"<claimed_id> <real_id> <test_label> <score>"`` where
    ``real_id`` equals the claimed id for bona-fide samples and the
    sample's attack type otherwise.

    Parameters
    ----------
    sample : object
        Must expose ``subject``, ``key``, ``data``, ``is_bonafide`` and
        ``attack_type`` attributes.
    endl : str
        Appended verbatim at the end of the line (e.g. ``"\\n"``).

    Returns
    -------
    str
        The formatted score line.
    """
    claimed_id, test_label, score = sample.subject, sample.key, sample.data

    # bona-fide samples keep the claimed id; attacks report their attack type
    real_id = claimed_id if sample.is_bonafide else sample.attack_type

    return f"{claimed_id} {real_id} {test_label} {score}{endl}"