"""Executes the vanilla PAD pipeline."""


import click
from bob.extension.scripts.click_helper import ConfigCommand
from bob.extension.scripts.click_helper import ResourceOption
from bob.extension.scripts.click_helper import verbosity_option


@click.command(
    entry_point_group="bob.pad.config",
    cls=ConfigCommand,
    epilog="""\b
 Command line examples\n
 -----------------------


 $ bob pad vanilla-pad my_experiment.py -vv
""",
)
@click.option(
    "--pipeline",
    "-p",
    required=True,
    entry_point_group="sklearn.pipeline",
    help="Scikit-learn pipeline that is fitted and then used to score the samples",
    cls=ResourceOption,
)
@click.option(
    "--database",
    "-d",
    required=True,
    cls=ResourceOption,
    entry_point_group="bob.pad.database",
    help="PAD Database connector (class that implements the methods: `fit_samples`, `predict_samples`)",
)
@click.option(
    "--dask-client",
    "-l",
    required=False,
    cls=ResourceOption,
    help="Dask client for the execution of the pipeline.",
)
@click.option(
    "--group",
    "-g",
    "groups",
    type=click.Choice(["dev", "eval"]),
    multiple=True,
    default=("dev", "eval"),
    help="If given, this value will limit the experiments belonging to a particular group",
)
@click.option(
    "-o",
    "--output",
    show_default=True,
    default="results",
    help="Saves scores (and checkpoints) in this folder.",
)
@click.option(
    "--checkpoint",
    "-c",
    is_flag=True,
    help="If set, it will checkpoint all steps of the pipeline. Checkpoints will be saved in `--output`.",
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
@click.pass_context
def vanilla_pad(ctx, pipeline, database, dask_client, groups, output, checkpoint, **kwargs):
    """Runs the simplest PAD pipeline."""
    # Flow:
    #   1. optionally wrap `pipeline` with checkpointing (under `output`)
    #      and/or dask (when a dask client is given);
    #   2. record the command line, database and pipeline steps in
    #      `<output>/Experiment_info.txt`;
    #   3. fit the pipeline on `database.fit_samples()`;
    #   4. for every requested group, score `database.predict_samples(group)`
    #      with `decision_function` and write one line per sample to
    #      `<output>/scores-<group>`.
    # Heavy imports are kept local so the CLI stays fast to load.

    import gzip
    import logging
    import os
    import shutil
    import sys

    import bob.pipelines as mario
    import dask.bag

    from bob.extension.scripts.click_helper import log_parameters

    logger = logging.getLogger(__name__)
    log_parameters(logger)

    os.makedirs(output, exist_ok=True)

    if checkpoint:
        pipeline = mario.wrap(
            ["checkpoint"], pipeline, features_dir=output, model_path=output
        )

    if dask_client is not None:
        pipeline = mario.wrap(["dask"], pipeline)
    else:
        logger.warning("`dask_client` not set. Your pipeline will run locally")

    # create an experiment info file
    with open(os.path.join(output, "Experiment_info.txt"), "wt") as f:
        f.write(f"{sys.argv!r}\n")
        f.write(f"database={database!r}\n")
        f.write("Pipeline steps:\n")
        for i, name, estimator in pipeline._iter():
            f.write(f"Step {i}: {name}\n{estimator!r}\n")

    # train the pipeline
    fit_samples = database.fit_samples()
    pipeline.fit(fit_samples)

    for group in groups:

        logger.info(f"Running vanilla PAD for group {group}")
        predict_samples = database.predict_samples(group=group)
        result = pipeline.decision_function(predict_samples)

        scores_path = os.path.join(output, f"scores-{group}")

        if isinstance(result, dask.bag.core.Bag):

            # write each partition into a zipped txt file
            result = result.map(pad_predicted_sample_to_score_line)
            pattern = os.path.join(output, "scores", f"scores-{group}-*.txt.gz")
            os.makedirs(os.path.dirname(pattern), exist_ok=True)
            logger.info("Writing bag results into files ...")
            # With compute=True (the default) `to_textfiles` returns the
            # written file paths in partition order. Using them directly
            # avoids globbing (which could also pick up stale files left by an
            # interrupted earlier run) and parsing indices out of file names.
            partition_paths = result.to_textfiles(
                pattern, last_endline=True, scheduler=dask_client
            )

            with open(scores_path, "w") as f:
                # concatenate scores into one score file, streaming each
                # partition instead of loading it fully into memory
                for path in partition_paths:
                    with gzip.open(path, "rt") as f2:
                        shutil.copyfileobj(f2, f)
                    # delete intermediate score files
                    os.remove(path)

        else:
            with open(scores_path, "w") as f:
                for sample in result:
                    f.write(pad_predicted_sample_to_score_line(sample, endl="\n"))


def pad_predicted_sample_to_score_line(sample, endl=""):
    """Format one scored sample as a four-column score line.

    The columns, separated by single spaces, are::

        <subject> <real_id> <key> <score>

    where ``subject`` is ``sample.subject``, ``key`` is ``sample.key`` and
    ``score`` is ``sample.data``. ``real_id`` repeats the subject for bona
    fide samples and holds ``sample.attack_type`` otherwise.

    Parameters
    ----------
    sample
        Scored sample exposing ``subject``, ``key``, ``data`` and
        ``is_bonafide``. ``attack_type`` is read only when ``is_bonafide``
        is falsy.
    endl : str
        Text appended right after the score (empty by default).

    Returns
    -------
    str
        The formatted line.
    """
    subject = sample.subject
    label = sample.key
    value = sample.data

    if sample.is_bonafide:
        identity = subject
    else:
        identity = sample.attack_type

    return "{} {} {} {}{}".format(subject, identity, label, value, endl)