From 0284fbba98abcd0bcc96aed963a472f97b8389ba Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 10 Jul 2023 11:19:35 +0200
Subject: [PATCH] [ptbench.scripts] Improve docs, adapt to changes from the
 weight-balancing strategy

---
 src/ptbench/scripts/train.py | 63 ++++++++++++++++++++----------------
 src/ptbench/scripts/utils.py | 57 ++++++++++++++++++++++++++++++++
 2 files changed, 93 insertions(+), 27 deletions(-)
 create mode 100644 src/ptbench/scripts/utils.py

diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 9f770b59..ba45f184 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -16,7 +16,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     epilog="""Examples:
 
 \b
-  1. Trains PASA model with Montgomery dataset, on a GPU (``cuda:0``):
+  1. Trains Pasa's model with Montgomery dataset, on a GPU (``cuda:0``):
 
      .. code:: sh
 
@@ -36,29 +36,14 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--model",
     "-m",
-    help="A torch.nn.Module instance implementing the network to be trained",
+    help="A lightning module instance implementing the network to be trained",
     required=True,
     cls=ResourceOption,
 )
 @click.option(
     "--datamodule",
     "-d",
-    help="A dictionary mapping string keys to "
-    "torch.utils.data.dataset.Dataset instances implementing datasets "
-    "to be used for training and validating the model, possibly including all "
-    "pre-processing pipelines required or, optionally, a dictionary mapping "
-    "string keys to torch.utils.data.dataset.Dataset instances. At least "
-    "one key named ``train`` must be available. This dataset will be used for "
-    "training the network model. The dataset description must include all "
-    "required pre-processing, including eventual data augmentation. If a "
-    "dataset named ``__train__`` is available, it is used prioritarily for "
-    "training instead of ``train``. If a dataset named ``__valid__`` is "
-    "available, it is used for model validation (and automatic "
-    "check-pointing) at each epoch. If a dataset list named "
-    "``__extra_valid__`` is available, then it will be tracked during the "
-    "validation process and its loss output at the training log as well, "
-    "in the format of an array occupying a single column. All other keys "
-    "are considered test datasets and are ignored during training",
+    help="A lightning data module containing the training and validation sets.",
     required=True,
     cls=ResourceOption,
 )
@@ -149,7 +134,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--cache-samples",
+    "--cache-samples/--no-cache-samples",
     help="If set to True, loads the sample into memory, "
     "otherwise loads them at runtime.",
     required=True,
@@ -205,6 +190,18 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     default=None,
     cls=ResourceOption,
 )
+@click.option(
+    "--balance-classes/--no-balance-classes",
+    "-B/-N",
+    help="""If set, then balances weights of the random sampler during
+    training, so that samples from all classes are picked equitably.
+    It also sets the training (and validation) losses to account
+    for the populations of each class.""",
+    required=True,
+    show_default=True,
+    default=True,
+    cls=ResourceOption,
+)
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def train(
     model,
@@ -221,8 +218,9 @@ def train(
     parallel,
     monitoring_interval,
     resume_from,
+    balance_classes,
     **_,
-):
+) -> None:
     """Trains a CNN to perform image classification.
 
     Training is performed for a configurable number of epochs, and
@@ -239,7 +237,9 @@ def train(
 
     from ..engine.trainer import run
     from ..utils.checkpointer import get_checkpoint
+    from .utils import save_sh_command
 
+    save_sh_command(output_folder)
     seed_everything(seed)
 
     checkpoint_file = get_checkpoint(output_folder, resume_from)
@@ -257,22 +257,31 @@ def train(
     # this call may be a NOOP, if the model was pre-trained and expects
     # different weights for the normalisation layer.
     if hasattr(model, "set_normalizer"):
-        model.set_normalizer(datamodule.train_dataloader())
+        model.set_normalizer(datamodule.unshuffled_train_dataloader())
     else:
         logger.warning(
-            f"Model {model.name} has no 'set_normalizer' method. No normalization will be applied."
+            f"Model {model.name} has no 'set_normalizer' method. Skipping."
         )
 
-    # Rebalances the loss criterion based on the relative proportion of class
-    # examples available in the training set. Also affects the validation loss
-    # if a validation set is available on the data module.
-    model.set_bce_loss_weights(datamodule)
+    # If asked, rebalances the loss criterion based on the relative proportion
+    # of class examples available in the training set. Also affects the
+    # validation loss if a validation set is available on the data module.
+    if balance_classes:
+        logger.info("Applying datamodule train sampler balancing...")
+        datamodule.balance_sampler_by_class = True
+        # logger.info("Applying train/valid loss balancing...")
+        # model.balance_losses_by_class(datamodule)
+    else:
+        logger.info(
+            "Skipping sample class/dataset ownership balancing on user request"
+        )
 
     arguments = {}
     arguments["max_epoch"] = epochs
     arguments["epoch"] = 0
 
-    # We only load the checkpoint to get some information about its state. The actual loading of the model is done in trainer.fit()
+    # We only load the checkpoint to get some information about its state. The
+    # actual loading of the model is done in trainer.fit()
     if checkpoint_file is not None:
         checkpoint = torch.load(checkpoint_file)
         arguments["epoch"] = checkpoint["epoch"]
diff --git a/src/ptbench/scripts/utils.py b/src/ptbench/scripts/utils.py
new file mode 100644
index 00000000..5553a6ed
--- /dev/null
+++ b/src/ptbench/scripts/utils.py
@@ -0,0 +1,57 @@
+import importlib.metadata
+import logging
+import os
+import pathlib
+import sys
+import time
+
+logger = logging.getLogger(__name__)
+
+
+def save_sh_command(output_folder: str | pathlib.Path) -> None:
+    """Records the command line needed to reproduce this run.
+
+    This function records the current command line used to call the script
+    being run.  It creates an executable shell script noting, as comments,
+    the working directory and the active conda environment, if any, together
+    with the date and time the script was run and the version of the package.
+
+    Parameters
+    ----------
+
+    output_folder : str | pathlib.Path
+        Path leading to the directory where the commands to reproduce the
+        current run will be recorded.  The directory is created if it does
+        not yet exist.
+    """
+
+    if isinstance(output_folder, str):
+        output_folder = pathlib.Path(output_folder)
+
+    destfile = output_folder / "command.sh"
+
+    logger.info(f"Writing command-line for reproduction at '{destfile}'...")
+    os.makedirs(output_folder, exist_ok=True)
+
+    package = __name__.split(".", 1)[0]
+    version = importlib.metadata.version(package)
+
+    with destfile.open("w") as f:
+        f.write("#!/usr/bin/env sh\n")
+        f.write(f"# date: {time.asctime()}\n")
+        f.write(f"# version: {version} ({package})\n")
+        f.write(f"# platform: {sys.platform}\n")
+        f.write("\n")
+        # quote arguments containing spaces, so the recorded command line
+        # can be re-run verbatim
+        args = []
+        for k in sys.argv:
+            if " " in k:
+                args.append(f'"{k}"')
+            else:
+                args.append(k)
+        if os.environ.get("CONDA_DEFAULT_ENV") is not None:
+            f.write(f"# conda activate {os.environ['CONDA_DEFAULT_ENV']}\n")
+        f.write(f"# cd {os.path.realpath(os.curdir)}\n")
+        f.write(" ".join(args) + "\n")
+    os.chmod(destfile, 0o755)
-- 
GitLab
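
For reference, a sketch of what ``save_sh_command()`` would leave in the
output folder after a run.  Assuming a hypothetical invocation such as
``ptbench train -vv pasa montgomery``, the generated ``command.sh`` would
look along these lines (the date, version, conda environment, and working
directory shown are illustrative values, not output from a real run):

.. code:: sh

   #!/usr/bin/env sh
   # date: Mon Jul 10 11:19:35 2023
   # version: 1.1.0b0 (ptbench)  (illustrative version number)
   # platform: linux

   # conda activate ptbench
   # cd /home/user/work/ptbench
   ptbench train -vv pasa montgomery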