"""Executes PAD pipeline"""


import click

from bob.extension.scripts.click_helper import ConfigCommand
from bob.extension.scripts.click_helper import ResourceOption
from bob.extension.scripts.click_helper import verbosity_option


@click.command(
    entry_point_group="bob.pad.config",
    cls=ConfigCommand,
    epilog="""\b
 Command line examples\n
 -----------------------


 $ bob pad vanilla-pad my_experiment.py -vv
""",
)
@click.option(
    "--pipeline",
    "-p",
    required=True,
    entry_point_group="sklearn.pipeline",
    help="Feature extraction algorithm",
    cls=ResourceOption,
)
@click.option(
    "--database",
    "-d",
    required=True,
    cls=ResourceOption,
    entry_point_group="bob.pad.database",
    help="PAD Database connector (class that implements the methods: `fit_samples`, `predict_samples`)",
)
@click.option(
    "--dask-client",
    "-l",
    required=False,
    cls=ResourceOption,
    help="Dask client for the execution of the pipeline.",
)
@click.option(
    "--group",
    "-g",
    "groups",
    type=click.Choice(["dev", "eval"]),
    multiple=True,
    default=("dev", "eval"),
    help="If given, this value will limit the experiments belonging to a particular group",
)
@click.option(
    "-o",
    "--output",
    show_default=True,
    default="results",
    help="Saves scores (and checkpoints) in this folder.",
)
@click.option(
    "--checkpoint",
    "-c",
    is_flag=True,
    help="If set, it will checkpoint all steps of the pipeline. Checkpoints will be saved in `--output`.",
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
@click.pass_context
def vanilla_pad(ctx, pipeline, database, dask_client, groups, output, checkpoint, **kwargs):
    """Runs the simplest PAD pipeline."""
    # NOTE: click uses the docstring above as the command's help text, so the
    # longer description is kept in comments rather than in the docstring.
    #
    # Workflow:
    #   1. create ``output`` and, when ``checkpoint`` is set, wrap the pipeline
    #      with the "checkpoint" wrapper pointing at ``output``;
    #   2. dump the command line, the database and the pipeline steps into
    #      ``<output>/Experiment_info.txt``;
    #   3. fit the pipeline on ``database.fit_samples()``;
    #   4. for every requested group, score
    #      ``database.predict_samples(group=group)`` and write one line per
    #      sample into ``<output>/scores-<group>``.
    #
    # ``ctx`` and ``**kwargs`` are accepted (click context, extra options such
    # as verbosity) but not used directly in this body.

    # Imports are kept local to the command, matching the rest of this file.
    import gzip
    import logging
    import os
    import shutil
    import sys
    from glob import escape as glob_escape
    from glob import glob

    import bob.pipelines as mario
    import dask.bag
    from bob.extension.scripts.click_helper import log_parameters

    logger = logging.getLogger(__name__)
    log_parameters(logger)

    os.makedirs(output, exist_ok=True)

    if checkpoint:
        pipeline = mario.wrap(
            ["checkpoint"], pipeline, features_dir=output, model_path=output
        )

    if dask_client is None:
        logger.warning("`dask_client` not set. Your pipeline will run locally")

    # create an experiment info file
    with open(os.path.join(output, "Experiment_info.txt"), "wt") as f:
        f.write(f"{sys.argv!r}\n")
        f.write(f"database={database!r}\n")
        f.write("Pipeline steps:\n")
        for i, name, estimator in pipeline._iter():
            f.write(f"Step {i}: {name}\n{estimator!r}\n")

    # train the pipeline
    fit_samples = database.fit_samples()
    pipeline.fit(fit_samples)

    for group in groups:

        logger.info("Running vanilla PAD for group %s", group)
        predict_samples = database.predict_samples(group=group)
        result = pipeline.decision_function(predict_samples)

        scores_path = os.path.join(output, f"scores-{group}")

        if isinstance(result, dask.bag.core.Bag):

            # write each partition into a zipped txt file
            result = result.map(pad_predicted_sample_to_score_line)
            name_prefix, postfix = f"scores-{group}-", ".txt.gz"
            prefix = f"{output}/scores/{name_prefix}"
            pattern = f"{prefix}*{postfix}"
            os.makedirs(os.path.dirname(prefix), exist_ok=True)
            logger.info("Writing bag results into files ...")
            result.to_textfiles(pattern, last_endline=True, scheduler=dask_client)

            # Escape the fixed part of the pattern so that glob metacharacters
            # in the output folder name are matched literally.
            partition_paths = glob(f"{glob_escape(prefix)}*{postfix}")

            def partition_index(path):
                # The partition number sits between the file-name prefix and
                # the ".txt.gz" suffix; parse it from the basename so the
                # directory part of the path cannot interfere.
                stem = os.path.basename(path)
                return int(stem[len(name_prefix):-len(postfix)])

            partition_paths.sort(key=partition_index)

            with open(scores_path, "w") as f:
                # concatenate scores into one score file
                for path in partition_paths:
                    with gzip.open(path, "rt") as f2:
                        # stream the partition instead of loading it whole
                        shutil.copyfileobj(f2, f)
                    # delete intermediate score files
                    os.remove(path)

        else:
            with open(scores_path, "w") as f:
                for sample in result:
                    f.write(pad_predicted_sample_to_score_line(sample, endl="\n"))


def pad_predicted_sample_to_score_line(sample, endl=""):
    """Format one scored sample as a space-separated score line.

    The line holds four columns, in this order: ``sample.subject``, the real
    identity (``sample.subject`` again when ``sample.is_bonafide`` is truthy,
    ``sample.attack_type`` otherwise), ``sample.key`` and ``sample.data``.
    ``endl`` is appended verbatim at the end of the line.
    """
    subject = sample.subject
    sample_key = sample.key
    score_value = sample.data

    # ``attack_type`` is only looked up for non bona fide samples.
    if sample.is_bonafide:
        true_identity = subject
    else:
        true_identity = sample.attack_type

    columns = (subject, true_identity, sample_key, score_value)
    line = " ".join(f"{column}" for column in columns)
    return f"{line}{endl}"