# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import os

import click

from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup

logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


def setup_pytorch_device(name):
    """Sets-up the pytorch device to use.

    Parameters
    ----------

    name : str
        The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on).  If you
        set a specific cuda device such as ``cuda:1``, then we'll make sure it
        is currently set.


    Returns
    -------

    device : :py:class:`torch.device`
        The pytorch device to use, pre-configured (and checked)


    Raises
    ------

    RuntimeError
        If a CUDA device is requested but CUDA is not available, or if the
        generic ``cuda`` device is requested while the environment variable
        ``CUDA_VISIBLE_DEVICES`` is not set.
    """

    import torch

    if name.startswith("cuda:"):
        # In case one has multiple devices, we must first set the one
        # we would like to use so pytorch can find it.
        logger.info(f"User set device to '{name}' - trying to force device...")
        os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1]
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"CUDA is not currently available, but "
                f"you set device to '{name}'"
            )
        # Let pytorch auto-select from environment variable
        return torch.device("cuda")

    elif name.startswith("cuda"):
        # use default device
        logger.info(f"User set device to '{name}' - using default CUDA device")
        # Explicit check rather than ``assert``: assertions are stripped when
        # Python runs with -O, which would silently skip this validation.
        if os.environ.get("CUDA_VISIBLE_DEVICES") is None:
            raise RuntimeError(
                f"You set device to '{name}', but the environment variable "
                f"CUDA_VISIBLE_DEVICES is not set"
            )
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"CUDA is not currently available, but "
                f"you set device to '{name}'"
            )

    # cuda or cpu
    return torch.device(name)


def set_seeds(value, all_gpus):
    """Sets up all relevant random seeds (numpy, python, cuda)

    If running with multiple GPUs **at the same time**, set ``all_gpus`` to
    ``True`` to force all GPU seeds to be initialized.

    Reference: `PyTorch page for reproducibility
    <https://pytorch.org/docs/stable/notes/randomness.html>`_.


    Parameters
    ----------

    value : int
        The random seed value to use

    all_gpus : :py:class:`bool`, Optional
        If set, then reset the seed on all GPUs available at once.  This is
        normally **not** what you want if running on a single GPU
    """
    import random

    import numpy.random
    import torch
    import torch.cuda

    random.seed(value)
    numpy.random.seed(value)
    torch.manual_seed(value)
    torch.cuda.manual_seed(value)  # noop if cuda not available

    # set seeds for all gpus
    if all_gpus:
        torch.cuda.manual_seed_all(value)  # noop if cuda not available


def set_reproducible_cuda():
    """Turns-off all CUDA optimizations that would affect reproducibility.

    For full reproducibility, also ensure not to use multiple (parallel) data
    loaders.  That is setup ``num_workers=0``.

    Reference: `PyTorch page for reproducibility
    <https://pytorch.org/docs/stable/notes/randomness.html>`_.
    """
    import torch.backends.cudnn

    # ensure to use only optimization algos for cuda that are known to have
    # a deterministic effect (not random)
    torch.backends.cudnn.deterministic = True

    # turns off any optimization tricks
    torch.backends.cudnn.benchmark = False


@click.command(
    entry_point_group="ptbench.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Trains PASA model with Montgomery dataset, on a GPU (``cuda:0``):

       .. code:: sh

          ptbench train -vv pasa montgomery --batch-size=4 --device="cuda:0"
""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the generated model (created if does not exist)",
    required=True,
    type=click.Path(),
    default="results",
    cls=ResourceOption,
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be trained",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--dataset",
    "-d",
    help="A dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances implementing datasets "
    "to be used for training and validating the model, possibly including all "
    "pre-processing pipelines required or, optionally, a dictionary mapping "
    "string keys to torch.utils.data.dataset.Dataset instances.  At least "
    "one key named ``train`` must be available.  This dataset will be used for "
    "training the network model.  The dataset description must include all "
    "required pre-processing, including any data augmentation.  If a "
    "dataset named ``__train__`` is available, it is used prioritarily for "
    "training instead of ``train``.  If a dataset named ``__valid__`` is "
    "available, it is used for model validation (and automatic "
    "check-pointing) at each epoch.  If a dataset list named "
    "``__extra_valid__`` is available, then it will be tracked during the "
    "validation process and its loss output at the training log as well, "
    "in the format of an array occupying a single column.  All other keys "
    "are considered test datasets and are ignored during training",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--optimizer",
    help="A torch.optim.Optimizer that will be used to train the network",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion",
    help="A loss function to compute the CNN error for every sample "
    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion-valid",
    help="A specific loss function for the validation set to compute the CNN "
    "error for every sample respecting the PyTorch API for loss functions "
    "(see torch.nn.modules.loss)",
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects "
    "memory requirements for the network).  If the number of samples in "
    "the batch is larger than the total number of samples available for "
    "training, this value is truncated.  If this number is smaller, then "
    "batches of the specified size are created and fed to the network "
    "until there are no more new samples to feed (epoch is finished).  "
    "If the total number of training samples is not a multiple of the "
    "batch-size, the last batch will be smaller than the first, unless "
    "--drop-incomplete-batch is set, in which case this batch is not used.",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--batch-chunk-count",
    "-c",
    help="Number of chunks in every batch (this parameter affects "
    "memory requirements for the network). The number of samples "
    "loaded for every iteration will be batch-size/batch-chunk-count. "
    "batch-size needs to be divisible by batch-chunk-count, otherwise an "
    "error will be raised. This parameter is used to reduce number of "
    "samples loaded in each iteration, in order to reduce the memory usage "
    "in exchange for processing time (more iterations).  This is especially "
    "interesting when one is running with GPUs with limited RAM. The "
    "default of 1 forces the whole batch to be processed at once.  Otherwise "
    "the batch is broken into batch-chunk-count pieces, and gradients are "
    "accumulated to complete each batch.",
    required=True,
    show_default=True,
    default=1,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--drop-incomplete-batch/--no-drop-incomplete-batch",
    "-D",
    help="If set, then may drop the last batch in an epoch, in case it is "
    "incomplete.  If you set this option, you should also consider "
    "increasing the total number of epochs of training, as the total number "
    "of training steps may be reduced",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--epochs",
    "-e",
    help="Number of epochs (complete training set passes) to train for. "
    "If continuing from a saved checkpoint, ensure to provide a greater "
    "number of epochs than that saved on the checkpoint to be loaded. ",
    show_default=True,
    required=True,
    default=1000,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--checkpoint-period",
    "-p",
    help="Number of epochs after which a checkpoint is saved. "
    "A value of zero will disable check-pointing. If checkpointing is "
    "enabled and training stops, it is automatically resumed from the "
    "last saved checkpoint if training is restarted with the same "
    "configuration.",
    show_default=True,
    required=True,
    default=0,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--device",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--seed",
    "-s",
    help="Seed to use for the random number generator",
    show_default=True,
    required=False,
    default=42,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--parallel",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading.  Set to 0 to enable as many data
    loading instances as processing cores as available in the system.  Set to
    >= 1 to enable that many multiprocessing instances for data loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="Path or URL to pretrained model file (.pth extension)",
    required=False,
    cls=ResourceOption,
)
@click.option(
    "--normalization",
    "-n",
    help="Z-Normalization of input images: 'imagenet' for ImageNet parameters,"
    " 'current' for parameters of the current trainset, "
    "'none' for no normalization.",
    required=False,
    default="none",
    cls=ResourceOption,
)
@click.option(
    "--monitoring-interval",
    "-I",
    help="""Time between checks for the use of resources during each training
    epoch.  An interval of 5 seconds, for example, will lead to CPU and GPU
    resources being probed every 5 seconds during each training epoch.
    Values registered in the training logs correspond to averages (or maxima)
    observed through possibly many probes in each epoch.  Notice that setting a
    very small value may cause the probing process to become extremely busy,
    potentially biasing the overall perception of resource usage.""",
    type=click.FloatRange(min=0.1),
    show_default=True,
    required=True,
    default=5.0,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def train(
    model,
    optimizer,
    output_folder,
    epochs,
    batch_size,
    batch_chunk_count,
    drop_incomplete_batch,
    criterion,
    criterion_valid,
    dataset,
    checkpoint_period,
    device,
    seed,
    parallel,
    weight,
    normalization,
    monitoring_interval,
    **_,
):
    """Trains a CNN to perform tuberculosis detection.

    Training is performed for a configurable number of epochs, and generates
    at least a final_model.pth.  It may also generate a number of intermediate
    checkpoints.  Checkpoints are model files (.pth files) that are stored
    during the training and useful to resume the procedure in case it stops
    abruptly.
    """

    import multiprocessing
    import sys

    import torch.cuda
    import torch.nn

    from torch.nn import BCEWithLogitsLoss
    from torch.utils.data import DataLoader, WeightedRandomSampler

    from ..configs.datasets import get_positive_weights, get_samples_weights
    from ..engine.trainer import run
    from ..utils.checkpointer import Checkpointer
    from ..utils.download import download_to_tempfile

    device = setup_pytorch_device(device)

    set_seeds(seed, all_gpus=False)

    # Split the configured dataset(s) into training, validation and
    # extra-validation sets.  A non-dict ``dataset`` is used for training only.
    use_dataset = dataset
    validation_dataset = None
    extra_validation_datasets = []

    if isinstance(dataset, dict):
        if "__train__" in dataset:
            logger.info("Found (dedicated) '__train__' set for training")
            use_dataset = dataset["__train__"]
        else:
            use_dataset = dataset["train"]

        if "__valid__" in dataset:
            logger.info("Found (dedicated) '__valid__' set for validation")
            logger.info("Will checkpoint lowest loss model on validation set")
            validation_dataset = dataset["__valid__"]

        if "__extra_valid__" in dataset:
            if not isinstance(dataset["__extra_valid__"], list):
                raise RuntimeError(
                    f"If present, dataset['__extra_valid__'] must be a list, "
                    f"but you passed a {type(dataset['__extra_valid__'])}, "
                    f"which is invalid."
                )
            logger.info(
                f"Found {len(dataset['__extra_valid__'])} extra validation "
                f"set(s) to be tracked during training"
            )
            logger.info(
                "Extra validation sets are NOT used for model checkpointing!"
            )
            extra_validation_datasets = dataset["__extra_valid__"]

    # PyTorch dataloader: negative ``parallel`` disables worker processes,
    # zero means "one worker per CPU core", positive is an explicit count.
    multiproc_kwargs = {}
    if parallel < 0:
        multiproc_kwargs["num_workers"] = 0
    else:
        multiproc_kwargs["num_workers"] = (
            parallel or multiprocessing.cpu_count()
        )

    if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
        multiproc_kwargs[
            "multiprocessing_context"
        ] = multiprocessing.get_context("spawn")

    # Each batch is processed as ``batch_chunk_count`` chunks of equal size
    # (gradients are accumulated by the trainer), so the split must be exact.
    if batch_size % batch_chunk_count != 0:
        raise RuntimeError(
            f"--batch-size ({batch_size}) must be divisible by "
            f"--batch-chunk-count ({batch_chunk_count})."
        )
    batch_chunk_size = batch_size // batch_chunk_count

    # Create weighted random sampler
    train_samples_weights = get_samples_weights(use_dataset)
    train_samples_weights = train_samples_weights.to(
        device=device, non_blocking=torch.cuda.is_available()
    )
    train_sampler = WeightedRandomSampler(
        train_samples_weights, len(train_samples_weights), replacement=True
    )

    # Redefine a weighted criterion if possible
    if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
        positive_weights = get_positive_weights(use_dataset)
        positive_weights = positive_weights.to(
            device=device, non_blocking=torch.cuda.is_available()
        )
        criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
    else:
        logger.warning("Weighted criterion not supported")

    # PyTorch dataloader
    data_loader = DataLoader(
        dataset=use_dataset,
        batch_size=batch_chunk_size,
        drop_last=drop_incomplete_batch,
        pin_memory=torch.cuda.is_available(),
        sampler=train_sampler,
        **multiproc_kwargs,
    )

    valid_loader = None
    if validation_dataset is not None:
        # Redefine a weighted valid criterion if possible.  When no validation
        # criterion was given, a weighted BCEWithLogitsLoss is used.
        if (
            isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
            or criterion_valid is None
        ):
            positive_weights = get_positive_weights(validation_dataset)
            positive_weights = positive_weights.to(
                device=device, non_blocking=torch.cuda.is_available()
            )
            criterion_valid = BCEWithLogitsLoss(pos_weight=positive_weights)
        else:
            logger.warning("Weighted valid criterion not supported")

        valid_loader = DataLoader(
            dataset=validation_dataset,
            batch_size=batch_chunk_size,
            shuffle=False,
            drop_last=False,
            pin_memory=torch.cuda.is_available(),
            **multiproc_kwargs,
        )

    extra_valid_loaders = [
        DataLoader(
            dataset=k,
            batch_size=batch_chunk_size,
            shuffle=False,
            drop_last=False,
            pin_memory=torch.cuda.is_available(),
            **multiproc_kwargs,
        )
        for k in extra_validation_datasets
    ]

    # Create z-normalization model layer if needed
    if normalization == "imagenet":
        model.normalizer.set_mean_std(
            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        )
        logger.info("Z-normalization with ImageNet mean and std")
    elif normalization == "current":
        # Compute mean/std of current train subset.  The whole subset is
        # loaded as a single batch.
        temp_dl = DataLoader(dataset=use_dataset, batch_size=len(use_dataset))

        data = next(iter(temp_dl))
        # NOTE(review): assumes element 1 of a batch is the image tensor laid
        # out as (N, C, H, W), so statistics are per channel -- confirm.
        mean = data[1].mean(dim=[0, 2, 3])
        std = data[1].std(dim=[0, 2, 3])

        model.normalizer.set_mean_std(mean, std)

        # Format mean and std for logging
        mean = str(
            [
                round(x, 3)
                for x in ((mean * 10**3).round() / (10**3)).tolist()
            ]
        )
        std = str(
            [
                round(x, 3)
                for x in ((std * 10**3).round() / (10**3)).tolist()
            ]
        )
        logger.info(f"Z-normalization with mean {mean} and std {std}")
    elif normalization not in (None, "none"):
        # Previously a typo here silently disabled normalization.
        logger.warning(
            f"Unknown normalization '{normalization}' - no z-normalization "
            f"will be applied (choose 'imagenet', 'current' or 'none')"
        )

    # Checkpointer
    checkpointer = Checkpointer(model, optimizer, path=output_folder)

    # Initialize epoch information
    arguments = {"epoch": 0}

    # Load pretrained weights if needed.  An existing checkpoint in the
    # output folder always takes precedence over ``--weight``.
    if weight is not None:
        if checkpointer.has_checkpoint():
            logger.warning(
                "Weights are being ignored because a checkpoint already exists. "
                "Weights from checkpoint will be loaded instead."
            )
            extra_checkpoint_data = checkpointer.load()
        else:
            if weight.startswith("http"):
                logger.info(f"Temporarily downloading '{weight}'...")
                # keep ``f`` referenced while loading so the temporary file
                # object is not garbage-collected before it is read
                f = download_to_tempfile(weight, progress=True)
                weight_fullpath = os.path.abspath(f.name)
            else:
                weight_fullpath = os.path.abspath(weight)
            extra_checkpoint_data = checkpointer.load(
                weight_fullpath, strict=False
            )
    else:
        extra_checkpoint_data = checkpointer.load()

    # Update epoch information with checkpoint data
    arguments.update(extra_checkpoint_data)
    arguments["max_epoch"] = epochs

    logger.info(f"Training for {arguments['max_epoch']} epochs")
    logger.info(f"Continuing from epoch {arguments['epoch']}")

    run(
        model=model,
        data_loader=data_loader,
        valid_loader=valid_loader,
        extra_valid_loaders=extra_valid_loaders,
        optimizer=optimizer,
        criterion=criterion,
        checkpointer=checkpointer,
        checkpoint_period=checkpoint_period,
        device=device,
        arguments=arguments,
        output_folder=output_folder,
        monitoring_interval=monitoring_interval,
        batch_chunk_count=batch_chunk_count,
        criterion_valid=criterion_valid,
    )