From 0284fbba98abcd0bcc96aed963a472f97b8389ba Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 10 Jul 2023 11:19:35 +0200
Subject: [PATCH] [ptbench.scripts] Improve docs, adapt to changes in the
 weight-balancing strategy

---
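
With the new ``--balance-classes`` flag enabled, the datamodule is expected
to weigh its random sampler by inverse class frequency (see
``balance_sampler_by_class`` below).  A minimal sketch of that idea, assuming
integer class targets; ``make_balanced_sampler`` is a hypothetical helper
shown for illustration only, not part of this patch:

    import torch
    from torch.utils.data import WeightedRandomSampler

    def make_balanced_sampler(targets: list[int]) -> WeightedRandomSampler:
        # count examples per class; rarer classes get larger weights
        counts = torch.bincount(torch.as_tensor(targets))
        weights = (1.0 / counts.double())[targets]
        # draw with replacement so every class is picked equitably
        return WeightedRandomSampler(weights, len(targets), replacement=True)

For reference, ``save_sh_command()`` (added below) writes a file along these
lines (all values illustrative):

    #!/usr/bin/env sh
    # date: Mon Jul 10 11:19:35 2023
    # version: 1.0.0 (ptbench)
    # platform: linux

    # conda activate ptbench-dev
    # cd /home/user/work/ptbench
    ptbench train -vv pasa montgomery --device=cuda:0
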
 src/ptbench/scripts/train.py | 63 ++++++++++++++++++++----------------
 src/ptbench/scripts/utils.py | 57 ++++++++++++++++++++++++++++++++
 2 files changed, 93 insertions(+), 27 deletions(-)
 create mode 100644 src/ptbench/scripts/utils.py

diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 9f770b59..ba45f184 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -16,7 +16,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     epilog="""Examples:
 
 \b
-    1. Trains PASA model with Montgomery dataset, on a GPU (``cuda:0``):
+    1. Trains Pasa's model with the Montgomery dataset, on a GPU (``cuda:0``):
 
        .. code:: sh
 
@@ -36,29 +36,14 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--model",
     "-m",
-    help="A torch.nn.Module instance implementing the network to be trained",
+    help="A lightining module instance implementing the network to be trained",
     required=True,
     cls=ResourceOption,
 )
 @click.option(
     "--datamodule",
     "-d",
-    help="A dictionary mapping string keys to "
-    "torch.utils.data.dataset.Dataset instances implementing datasets "
-    "to be used for training and validating the model, possibly including all "
-    "pre-processing pipelines required or, optionally, a dictionary mapping "
-    "string keys to torch.utils.data.dataset.Dataset instances.  At least "
-    "one key named ``train`` must be available.  This dataset will be used for "
-    "training the network model.  The dataset description must include all "
-    "required pre-processing, including eventual data augmentation.  If a "
-    "dataset named ``__train__`` is available, it is used prioritarily for "
-    "training instead of ``train``.  If a dataset named ``__valid__`` is "
-    "available, it is used for model validation (and automatic "
-    "check-pointing) at each epoch.  If a dataset list named "
-    "``__extra_valid__`` is available, then it will be tracked during the "
-    "validation process and its loss output at the training log as well, "
-    "in the format of an array occupying a single column.  All other keys "
-    "are considered test datasets and are ignored during training",
+    help="A lighting data module containing the training and validation sets.",
     required=True,
     cls=ResourceOption,
 )
@@ -149,7 +134,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--cache-samples",
+    "--cache-samples/--no-cache-samples",
     help="If set to True, loads the sample into memory, "
     "otherwise loads them at runtime.",
     required=True,
@@ -205,6 +190,18 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     default=None,
     cls=ResourceOption,
 )
+@click.option(
+    "--balance-classes/--no-balance-classes",
+    "-B/-N",
+    help="""If set, then balances weights of the random sampler during
+    training, so that samples from all sample classes are picked picked
+    equitably.  It also sets the training (and validation) losses to account
+    for the populations of each class.""",
+    required=True,
+    show_default=True,
+    default=True,
+    cls=ResourceOption,
+)
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def train(
     model,
@@ -221,8 +218,9 @@ def train(
     parallel,
     monitoring_interval,
     resume_from,
+    balance_classes,
     **_,
-):
+) -> None:
     """Trains an CNN to perform image classification.
 
     Training is performed for a configurable number of epochs, and
@@ -239,7 +237,9 @@ def train(
 
     from ..engine.trainer import run
     from ..utils.checkpointer import get_checkpoint
+    from .utils import save_sh_command
 
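+    # record the command line that launched this run, for reproducibility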
+    save_sh_command(output_folder)
     seed_everything(seed)
 
     checkpoint_file = get_checkpoint(output_folder, resume_from)
@@ -257,22 +257,31 @@ def train(
     # this call may be a NOOP, if the model was pre-trained and expects
     # different weights for the normalisation layer.
     if hasattr(model, "set_normalizer"):
-        model.set_normalizer(datamodule.train_dataloader())
+        model.set_normalizer(datamodule.unshuffled_train_dataloader())
     else:
         logger.warning(
-            f"Model {model.name} has no 'set_normalizer' method. No normalization will be applied."
+            f"Model {model.name} has no 'set_normalizer' method. Skipping."
         )
 
-    # Rebalances the loss criterion based on the relative proportion of class
-    # examples available in the training set.  Also affects the validation loss
-    # if a validation set is available on the data module.
-    model.set_bce_loss_weights(datamodule)
+    # If asked, rebalances the loss criterion based on the relative proportion
+    # of class examples available in the training set.  Also affects the
+    # validation loss if a validation set is available on the data module.
+    if balance_classes:
+        logger.info("Applying datamodule train sampler balancing...")
+        datamodule.balance_sampler_by_class = True
+        # logger.info("Applying train/valid loss balancing...")
+        # model.balance_losses_by_class(datamodule)
+    else:
+        logger.info(
+            "Skipping sample class/dataset ownership balancing on user request"
+        )
 
     arguments = {}
     arguments["max_epoch"] = epochs
     arguments["epoch"] = 0
 
-    # We only load the checkpoint to get some information about its state. The actual loading of the model is done in trainer.fit()
+    # We only load the checkpoint to get some information about its state. The
+    # actual loading of the model is done in trainer.fit()
     if checkpoint_file is not None:
         checkpoint = torch.load(checkpoint_file)
         arguments["epoch"] = checkpoint["epoch"]
diff --git a/src/ptbench/scripts/utils.py b/src/ptbench/scripts/utils.py
new file mode 100644
index 00000000..5553a6ed
--- /dev/null
+++ b/src/ptbench/scripts/utils.py
@@ -0,0 +1,57 @@
+import importlib.metadata
+import logging
+import os
+import pathlib
+import sys
+import time
+
+logger = logging.getLogger(__name__)
+
+
+def save_sh_command(output_folder: str | pathlib.Path) -> None:
+    """Records command-line to reproduce this script.
+
+    This function can record the current command-line used to call the script
+    being run.  It creates an executable ``bash`` script setting up the current
+    working directory and activating a conda environment, if needed.  It
+    records further information on the date and time the script was run and the
+    version of the package.
+
+
+    Parameters
+    ----------
+
+    output_folder : str
+        Path leading to the directory where the commands to reproduce the current
+        run will be recorded. A subdirectory will be created each time this function
+        is called to match lightning's versioning convention for loggers.
+    """
+
+    if isinstance(output_folder, str):
+        output_folder = pathlib.Path(output_folder)
+
+    destfile = output_folder / "command.sh"
+
+    logger.info(f"Writing command-line for reproduction at '{destfile}'...")
+    os.makedirs(output_folder, exist_ok=True)
+
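+    # derive the top-level package name and its installed version for the header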
+    package = __name__.split(".", 1)[0]
+    version = importlib.metadata.version(package)
+
+    with destfile.open("w") as f:
+        f.write("#!/usr/bin/env sh\n")
+        f.write(f"# date: {time.asctime()}\n")
+        f.write(f"# version: {version} ({package})\n")
+        f.write(f"# platform: {sys.platform}\n")
+        f.write("\n")
+        args = []
+        for k in sys.argv:
+            if " " in k:
+                args.append(f'"{k}"')
+            else:
+                args.append(k)
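+        # record the active conda environment, if any, as a commented hint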
+        if os.environ.get("CONDA_DEFAULT_ENV") is not None:
+            f.write(f"# conda activate {os.environ['CONDA_DEFAULT_ENV']}\n")
+        f.write(f"# cd {os.path.realpath(os.curdir)}\n")
+        f.write(" ".join(args) + "\n")
+    os.chmod(destfile, 0o755)
-- 
GitLab