# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import click
from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup
from lightning.pytorch import seed_everything

from ..utils.checkpointer import get_checkpoint

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

@click.command(
entry_point_group="ptbench.config",
cls=ConfigCommand,
epilog="""Examples:
\b
1. Trains PASA model with Montgomery dataset, on a GPU (``cuda:0``):
.. code:: sh
ptbench train -vv pasa montgomery --batch-size=4 --device="cuda:0"
""",
)
@click.option(
"--output-folder",
"-o",
help="Path where to store the generated model (created if does not exist)",
required=True,
type=click.Path(),
default="results",
cls=ResourceOption,
)
@click.option(
"--model",
"-m",
help="A torch.nn.Module instance implementing the network to be trained",
required=True,
cls=ResourceOption,
)
@click.option(
"--datamodule",
"-d",
help="A dictionary mapping string keys to "
"torch.utils.data.dataset.Dataset instances implementing datasets "
"to be used for training and validating the model, possibly including all "
"pre-processing pipelines required or, optionally, a dictionary mapping "
"string keys to torch.utils.data.dataset.Dataset instances. At least "
"one key named ``train`` must be available. This dataset will be used for "
"training the network model. The dataset description must include all "
"required pre-processing, including eventual data augmentation. If a "
"dataset named ``__train__`` is available, it is used prioritarily for "
"training instead of ``train``. If a dataset named ``__valid__`` is "
"available, it is used for model validation (and automatic "
"check-pointing) at each epoch. If a dataset list named "
"``__extra_valid__`` is available, then it will be tracked during the "
"validation process and its loss output at the training log as well, "
"in the format of an array occupying a single column. All other keys "
"are considered test datasets and are ignored during training",
required=True,
cls=ResourceOption,
)
@click.option(
"--criterion",
help="A loss function to compute the CNN error for every sample "
"respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
required=True,
cls=ResourceOption,
)
@click.option(
"--criterion-valid",
help="A specific loss function for the validation set to compute the CNN"
"error for every sample respecting the PyTorch API for loss functions"
"(see torch.nn.modules.loss)",
required=False,
cls=ResourceOption,
)
@click.option(
"--batch-size",
"-b",
help="Number of samples in every batch (this parameter affects "
"memory requirements for the network). If the number of samples in "
"the batch is larger than the total number of samples available for "
"training, this value is truncated. If this number is smaller, then "
"batches of the specified size are created and fed to the network "
"until there are no more new samples to feed (epoch is finished). "
"If the total number of training samples is not a multiple of the "
"batch-size, the last batch will be smaller than the first, unless "
"--drop-incomplete-batch is set, in which case this batch is not used.",
required=True,
show_default=True,
default=1,
type=click.IntRange(min=1),
cls=ResourceOption,
)
@click.option(
"--batch-chunk-count",
"-c",
help="Number of chunks in every batch (this parameter affects "
"memory requirements for the network). The number of samples "
"loaded for every iteration will be batch-size/batch-chunk-count. "
"batch-size needs to be divisible by batch-chunk-count, otherwise an "
"error will be raised. This parameter is used to reduce number of "
"samples loaded in each iteration, in order to reduce the memory usage "
"in exchange for processing time (more iterations). This is specially "
"interesting whe one is running with GPUs with limited RAM. The "
"default of 1 forces the whole batch to be processed at once. Otherwise "
"the batch is broken into batch-chunk-count pieces, and gradients are "
"accumulated to complete each batch.",
required=True,
show_default=True,
default=1,
type=click.IntRange(min=1),
cls=ResourceOption,
)
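# Illustration (not from the original help text): with --batch-size=16 and
# --batch-chunk-count=4, each iteration loads 16/4 = 4 samples and gradients
# are accumulated over 4 iterations to complete the batch, so the effective
# batch size stays at 16 while peak memory scales with only 4 samples.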
@click.option(
"--drop-incomplete-batch/--no-drop-incomplete-batch",
"-D",
help="If set, then may drop the last batch in an epoch, in case it is "
"incomplete. If you set this option, you should also consider "
"increasing the total number of epochs of training, as the total number "
"of training steps may be reduced",
required=True,
show_default=True,
default=False,
cls=ResourceOption,
)
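# Illustration: with 10 training samples and --batch-size=4, an epoch yields
# batches of 4, 4 and 2 samples; --drop-incomplete-batch discards the final
# 2-sample batch, so the epoch runs 2 steps instead of 3.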
@click.option(
"--epochs",
"-e",
help="Number of epochs (complete training set passes) to train for. "
"If continuing from a saved checkpoint, ensure to provide a greater "
"number of epochs than that saved on the checkpoint to be loaded. ",
show_default=True,
required=True,
default=1000,
type=click.IntRange(min=1),
cls=ResourceOption,
)
@click.option(
"--checkpoint-period",
"-p",
help="Number of epochs after which a checkpoint is saved. "
"A value of zero will disable check-pointing. If checkpointing is "
"enabled and training stops, it is automatically resumed from the "
"last saved checkpoint if training is restarted with the same "
"configuration.",
show_default=True,
required=False,
default=None,
type=click.IntRange(min=0),
cls=ResourceOption,
)
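# Illustration: -p 5 saves a checkpoint every 5 epochs; -p 0 disables
# periodic check-pointing altogether.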
@click.option(
"--accelerator",
"-a",
help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device may also be specified (e.g. "gpu:0").',
show_default=True,
required=True,
default="cpu",
cls=ResourceOption,
)
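# Illustration, per the help above: -a cpu (default), -a gpu, or -a gpu:0 to
# select a specific GPU device.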
@click.option(
"--cache-samples",
help="If set to True, loads the sample into memory, otherwise loads them at runtime.",
required=True,
show_default=True,
default=False,
cls=ResourceOption,
)
@click.option(
"--seed",
"-s",
help="Seed to use for the random number generator",
show_default=True,
required=False,
default=42,
type=click.IntRange(min=0),
cls=ResourceOption,
)
@click.option(
"--parallel",
"-P",
help="""Use multiprocessing for data loading: if set to -1 (default),
disables multiprocessing data loading. Set to 0 to enable as many data
loading instances as processing cores as available in the system. Set to
>= 1 to enable that many multiprocessing instances for data loading.""",
type=click.IntRange(min=-1),
show_default=True,
required=True,
default=-1,
cls=ResourceOption,
)
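# Illustration: -P -1 (default) loads data in the main process; -P 0 spawns
# one data-loading worker per available core; -P 8 spawns exactly 8 workers.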
@click.option(
"--normalization",
"-n",
help="Z-Normalization of input images: 'imagenet' for ImageNet parameters,"
" 'current' for parameters of the current trainset, "
"'none' for no normalization.",
required=False,
default="none",
cls=ResourceOption,
)
@click.option(
"--monitoring-interval",
"-I",
help="""Time between checks for the use of resources during each training
epoch. An interval of 5 seconds, for example, will lead to CPU and GPU
resources being probed every 5 seconds during each training epoch.
Values registered in the training logs correspond to averages (or maxima)
observed through possibly many probes in each epoch. Notice that setting a
very small value may cause the probing process to become extremely busy,
potentially biasing the overall perception of resource usage.""",
type=click.FloatRange(min=0.1),
show_default=True,
required=True,
default=5.0,
cls=ResourceOption,
)
@click.option(
"--resume-from",
help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a model checkpoint.",
type=str,
required=False,
default=None,
cls=ResourceOption,
)
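# Note: this value is resolved below via get_checkpoint(output_folder,
# resume_from); presumably 'best' and 'last' are looked up inside
# --output-folder, while an explicit path is used verbatim.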
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def train(
model,
output_folder,
epochs,
batch_size,
batch_chunk_count,
drop_incomplete_batch,
criterion,
criterion_valid,
datamodule,
checkpoint_period,
accelerator,
cache_samples,
seed,
parallel,
normalization,
monitoring_interval,
resume_from,
**_,
):
"""Trains an CNN to perform tuberculosis detection.
Training is performed for a configurable number of epochs, and
generates at least a final_model.pth. It may also generate a number
of intermediate checkpoints. Checkpoints are model files (.pth
files) that are stored during the training and useful to resume the
procedure in case it stops abruptly.
"""
import torch.cuda
import torch.nn
from ..data.dataset import normalize_data, reweight_BCEWithLogitsLoss
from ..engine.trainer import run
seed_everything(seed)
checkpoint_file = get_checkpoint(output_folder, resume_from)
datamodule = datamodule(
batch_size=batch_size,
batch_chunk_count=batch_chunk_count,
drop_incomplete_batch=drop_incomplete_batch,
cache_samples=cache_samples,
parallel=parallel,
)
datamodule.prepare_data()
datamodule.setup(stage="fit")
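# Presumably adjusts the BCEWithLogitsLoss class weighting from the label
# balance of the training (and validation) subsets; inferred from the
# function name rather than from its documentation.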
reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid)
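# Applies the z-normalization policy selected via --normalization
# ('imagenet', 'current' or 'none') to the model and datamodule, per the
# option help above.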
normalize_data(normalization, model, datamodule)
arguments = {}
arguments["max_epoch"] = epochs
arguments["epoch"] = 0
# We only load the checkpoint to get some information about its state.
# The actual loading of the model is done in trainer.fit().
if checkpoint_file is not None:
checkpoint = torch.load(checkpoint_file)
arguments["epoch"] = checkpoint["epoch"]
logger.info("Training for {} epochs".format(arguments["max_epoch"]))
logger.info("Continuing from epoch {}".format(arguments["epoch"]))
run(
model=model,
datamodule=datamodule,
checkpoint_period=checkpoint_period,
accelerator=accelerator,
arguments=arguments,
output_folder=output_folder,
monitoring_interval=monitoring_interval,
batch_chunk_count=batch_chunk_count,
checkpoint=checkpoint_file,
)
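# Minimal usage sketches (assuming the ``pasa`` model and ``montgomery``
# datamodule are available through the ``ptbench.config`` entry points, as
# in the epilog above):
#
#   ptbench train -vv pasa montgomery --batch-size=4 --epochs=100
#
# Resuming from the most recent checkpoint in the same output folder:
#
#   ptbench train -vv pasa montgomery --batch-size=4 --epochs=100 \
#       --resume-from=last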