# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import functools
import pathlib

import click

from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup

from .click import ConfigCommand

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


def reusable_options(f):
    """Options that can be re-used by top-level scripts (i.e. ``experiment```).

    This decorator equips the target function ``f`` with all (reusable)
    ``train`` script options.


    Parameters
    ----------
    f
        The target function to equip with options.  This function must have
        parameters that accept such options.


    Returns
    -------
        The decorated version of function ``f``
    """

    @click.option(
        "--output-folder",
        "-o",
        help="Path where to store results (created if does not exist)",
        required=True,
        type=click.Path(
            file_okay=False,
            dir_okay=True,
            writable=True,
            path_type=pathlib.Path,
        ),
        default="results",
        cls=ResourceOption,
    )
    @click.option(
        "--model",
        "-m",
        help="A lightining module instance implementing the network to be trained",
        required=True,
        cls=ResourceOption,
    )
    @click.option(
        "--datamodule",
        "-d",
        help="A lighting data module containing the training and validation sets.",
        required=True,
        cls=ResourceOption,
    )
    @click.option(
        "--batch-size",
        "-b",
        help="Number of samples in every batch (this parameter affects "
        "memory requirements for the network).  If the number of samples in "
        "the batch is larger than the total number of samples available for "
        "training, this value is truncated.  If this number is smaller, then "
        "batches of the specified size are created and fed to the network "
        "until there are no more new samples to feed (epoch is finished).  "
        "If the total number of training samples is not a multiple of the "
        "batch-size, the last batch will be smaller than the first, unless "
        "--drop-incomplete-batch is set, in which case this batch is not used.",
        required=True,
        show_default=True,
        default=1,
        type=click.IntRange(min=1),
        cls=ResourceOption,
    )
    @click.option(
        "--batch-chunk-count",
        "-c",
        help="Number of chunks in every batch (this parameter affects "
        "memory requirements for the network). The number of samples "
        "loaded for every iteration will be batch-size/batch-chunk-count. "
        "batch-size needs to be divisible by batch-chunk-count, otherwise an "
        "error will be raised. This parameter is used to reduce number of "
        "samples loaded in each iteration, in order to reduce the memory usage "
        "in exchange for processing time (more iterations).  This is specially "
        "interesting whe one is running with GPUs with limited RAM. The "
        "default of 1 forces the whole batch to be processed at once.  Otherwise "
        "the batch is broken into batch-chunk-count pieces, and gradients are "
        "accumulated to complete each batch.",
        required=True,
        show_default=True,
        default=1,
        type=click.IntRange(min=1),
        cls=ResourceOption,
    )
    @click.option(
        "--drop-incomplete-batch/--no-drop-incomplete-batch",
        "-D",
        help="If set, then may drop the last batch in an epoch, in case it is "
        "incomplete.  If you set this option, you should also consider "
        "increasing the total number of epochs of training, as the total number "
        "of training steps may be reduced",
        required=True,
        show_default=True,
        default=False,
        cls=ResourceOption,
    )
    @click.option(
        "--epochs",
        "-e",
        help="""Number of epochs (complete training set passes) to train for.
        If continuing from a saved checkpoint, ensure to provide a greater
        number of epochs than that saved on the checkpoint to be loaded.""",
        show_default=True,
        required=True,
        default=1000,
        type=click.IntRange(min=1),
        cls=ResourceOption,
    )
    @click.option(
        "--checkpoint-period",
        "-p",
        help="""Number of epochs after which a checkpoint is saved. A value of
        zero disables checkpointing. If checkpointing is enabled and
        training stops, it is automatically resumed from the last saved
        checkpoint if training is restarted with the same configuration.""",
        show_default=True,
        required=False,
        default=None,
        type=click.IntRange(min=0),
        cls=ResourceOption,
    )
    @click.option(
        "--device",
        "-x",
        help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
        show_default=True,
        required=True,
        default="cpu",
        cls=ResourceOption,
    )
    @click.option(
        "--cache-samples/--no-cache-samples",
        help="If set to True, loads the sample into memory, "
        "otherwise loads them at runtime.",
        required=True,
        show_default=True,
        default=False,
        cls=ResourceOption,
    )
    @click.option(
        "--seed",
        "-s",
        help="Seed to use for the random number generator",
        show_default=True,
        required=False,
        default=42,
        type=click.IntRange(min=0),
        cls=ResourceOption,
    )
    @click.option(
        "--parallel",
        "-P",
        help="""Use multiprocessing for data loading: if set to -1 (default),
        disables multiprocessing data loading.  Set to 0 to enable as many data
        loading instances as there are processing cores in the system.  Set to
        >= 1 to enable that many multiprocessing instances for data
        loading.""",
        type=click.IntRange(min=-1),
        show_default=True,
        required=True,
        default=-1,
        cls=ResourceOption,
    )
    @click.option(
        "--monitoring-interval",
        "-I",
        help="""Time between checks for the use of resources during each training
        epoch.  An interval of 5 seconds, for example, will lead to CPU and GPU
        resources being probed every 5 seconds during each training epoch.
        Values registered in the training logs correspond to averages (or maxima)
        observed through possibly many probes in each epoch.  Notice that setting a
        very small value may cause the probing process to become extremely busy,
        potentially biasing the overall perception of resource usage.""",
        type=click.FloatRange(min=0.1),
        show_default=True,
        required=True,
        default=5.0,
        cls=ResourceOption,
    )
    @click.option(
        "--resume-from",
        help="""Which checkpoint to resume training from. If set, can be one of
        `best`, `last`, or a path to a model checkpoint.""",
        type=click.STRING,
        required=False,
        default=None,
        cls=ResourceOption,
    )
    @click.option(
        "--balance-classes/--no-balance-classes",
        "-B/-N",
        help="""If set, then balances weights of the random sampler during
        training, so that samples from all sample classes are picked picked
        equitably.""",
        required=True,
        show_default=True,
        default=True,
        cls=ResourceOption,
    )
    @functools.wraps(f)
    def wrapper_reusable_options(*args, **kwargs):
        return f(*args, **kwargs)

    return wrapper_reusable_options
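
# A minimal sketch (illustrative only) of how a top-level command such as the
# ``experiment`` script mentioned above could reuse these options:
#
#   @click.command(cls=ConfigCommand, entry_point_group="ptbench.config")
#   @reusable_options
#   @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
#   def experiment(model, datamodule, output_folder, **kwargs):
#       ...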


@click.command(
    entry_point_group="ptbench.config",
    cls=ConfigCommand,
    epilog="""Examples:

1. Trains Pasa's model with the Montgomery dataset on a GPU (``cuda:0``):

   .. code:: sh

      ptbench train -vv pasa montgomery --batch-size=4 --device="cuda:0"
""",
)
@reusable_options
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def train(
    model,
    output_folder,
    epochs,
    batch_size,
    batch_chunk_count,
    drop_incomplete_batch,
    datamodule,
    checkpoint_period,
    device,
    cache_samples,
    seed,
    parallel,
    monitoring_interval,
    resume_from,
    balance_classes,
    **_,
) -> None:
    """Trains an CNN to perform image classification.

    Training is performed for a configurable number of epochs, and
    generates at least a final_model.ckpt.  It may also generate a
    number of intermediate checkpoints.  Checkpoints are model files
    (.ckpt files) that are stored during training and are useful for
    resuming the procedure in case it stops abruptly.
    """

    import torch

    from lightning.pytorch import seed_everything

    from ..engine.device import DeviceManager
    from ..engine.trainer import run
    from ..utils.checkpointer import get_checkpoint
    from .utils import save_sh_command

    save_sh_command(output_folder / "command.sh")
    seed_everything(seed)

    checkpoint_file = get_checkpoint(output_folder, resume_from)

    # reset datamodule with user configurable options
    datamodule.set_chunk_size(batch_size, batch_chunk_count)
    datamodule.drop_incomplete_batch = drop_incomplete_batch
    datamodule.cache_samples = cache_samples
    datamodule.parallel = parallel
    datamodule.model_transforms = model.model_transforms

    datamodule.prepare_data()
    datamodule.setup(stage="fit")

    # If asked, rebalances the random sampler used during training, based on
    # the relative proportion of class examples available in the training
    # set, so that all classes are sampled from equitably.
    if balance_classes:
        logger.info("Applying datamodule train sampler balancing...")
        datamodule.balance_sampler_by_class = True
        # logger.info("Applying train/valid loss balancing...")
        # model.balance_losses_by_class(datamodule)
    else:
        logger.info(
            "Skipping sample class/dataset ownership balancing on user request"
        )

    logger.info(f"Training for at most {epochs} epochs.")

    arguments = {}
    arguments["max_epoch"] = epochs
    arguments["epoch"] = 0

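    # When not resuming from a checkpoint (or when the model does not restore
    # its state via ``on_load_checkpoint``), the model normalizer is computed
    # from the training data; otherwise it is restored from the checkpoint.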
    if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"):
        # Sets the model normalizer with the unaugmented-train-subset.
        # This call may be a NOOP, if the model was pre-trained and expects
        # different weights for the normalisation layer.
        if hasattr(model, "set_normalizer"):
            model.set_normalizer(datamodule.unshuffled_train_dataloader())
        else:
            logger.warning(
                f"Model {model.name} has no 'set_normalizer' method. Skipping."
            )
    else:
        # Normalizer will be loaded during model.on_load_checkpoint
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint["epoch"]
        logger.info(f"Resuming from epoch {start_epoch}...")

    run(
        model=model,
        datamodule=datamodule,
        checkpoint_period=checkpoint_period,
        device_manager=DeviceManager(device),
        max_epochs=epochs,
        output_folder=output_folder,
        monitoring_interval=monitoring_interval,
        batch_chunk_count=batch_chunk_count,
        checkpoint=checkpoint_file,
    )