Commit 0284fbba authored by André Anjos

[ptbench.scripts] Improved docs, adapt changes from weight-balancing strategy

parent 427c1cf1
Merge request !6: Making use of LightningDataModule and simplification of data loading
@@ -16,7 +16,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
epilog="""Examples:
\b
-1. Trains PASA model with Montgomery dataset, on a GPU (``cuda:0``):
+1. Trains Pasa's model with Montgomery dataset, on a GPU (``cuda:0``):
.. code:: sh
@@ -36,29 +36,14 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@click.option(
"--model",
"-m",
help="A torch.nn.Module instance implementing the network to be trained",
help="A lightining module instance implementing the network to be trained",
required=True,
cls=ResourceOption,
)
@click.option(
"--datamodule",
"-d",
help="A dictionary mapping string keys to "
"torch.utils.data.dataset.Dataset instances implementing datasets "
"to be used for training and validating the model, possibly including all "
"pre-processing pipelines required or, optionally, a dictionary mapping "
"string keys to torch.utils.data.dataset.Dataset instances. At least "
"one key named ``train`` must be available. This dataset will be used for "
"training the network model. The dataset description must include all "
"required pre-processing, including eventual data augmentation. If a "
"dataset named ``__train__`` is available, it is used prioritarily for "
"training instead of ``train``. If a dataset named ``__valid__`` is "
"available, it is used for model validation (and automatic "
"check-pointing) at each epoch. If a dataset list named "
"``__extra_valid__`` is available, then it will be tracked during the "
"validation process and its loss output at the training log as well, "
"in the format of an array occupying a single column. All other keys "
"are considered test datasets and are ignored during training",
help="A lighting data module containing the training and validation sets.",
required=True,
cls=ResourceOption,
)
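As context for reviewers, a data module along the lines below would satisfy this option. This is a minimal sketch assuming only the standard lightning API; the class name, attributes and batch size are illustrative, not part of ptbench:

.. code:: python

   # Minimal sketch of a lightning data module exposing training and
   # validation sets.  All names and defaults here are illustrative.
   import lightning
   import torch.utils.data

   class ExampleDataModule(lightning.LightningDataModule):
       def __init__(self, train_set, valid_set, batch_size: int = 4):
           super().__init__()
           self.train_set = train_set
           self.valid_set = valid_set
           self.batch_size = batch_size

       def train_dataloader(self):
           # shuffle only the training data
           return torch.utils.data.DataLoader(
               self.train_set, batch_size=self.batch_size, shuffle=True
           )

       def val_dataloader(self):
           return torch.utils.data.DataLoader(
               self.valid_set, batch_size=self.batch_size
           )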
@@ -149,7 +134,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ResourceOption,
)
@click.option(
"--cache-samples",
"--cache-samples/--no-cache-samples",
help="If set to True, loads the sample into memory, "
"otherwise loads them at runtime.",
required=True,
@@ -205,6 +190,18 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
default=None,
cls=ResourceOption,
)
+@click.option(
+"--balance-classes/--no-balance-classes",
+"-B/-N",
+help="""If set, then balances weights of the random sampler during
+training, so that samples from all sample classes are picked
+equitably. It also sets the training (and validation) losses to account
+for the populations of each class.""",
+required=True,
+show_default=True,
+default=True,
+cls=ResourceOption,
+)
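The sampler balancing that ``--balance-classes`` turns on typically follows the inverse-frequency recipe sketched below. This illustrates the technique only and is not necessarily how the ptbench data module implements it; the helper name and the integer-target assumption are hypothetical:

.. code:: python

   # Sketch: class-balanced random sampling via inverse-frequency weights,
   # assuming integer class targets.  The helper name is hypothetical.
   import torch
   from torch.utils.data import WeightedRandomSampler

   def make_balanced_sampler(targets: list[int]) -> WeightedRandomSampler:
       t = torch.as_tensor(targets)
       counts = torch.bincount(t)          # samples available per class
       weights = 1.0 / counts[t].float()   # rarer classes weigh more
       # an epoch still draws len(targets) samples, now class-equitably
       return WeightedRandomSampler(
           weights, num_samples=len(targets), replacement=True
       )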
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def train(
model,
@@ -221,8 +218,9 @@ def train(
parallel,
monitoring_interval,
resume_from,
+balance_classes,
**_,
-):
+) -> None:
"""Trains an CNN to perform image classification.
Training is performed for a configurable number of epochs, and
@@ -239,7 +237,9 @@ def train(
from ..engine.trainer import run
from ..utils.checkpointer import get_checkpoint
+from .utils import save_sh_command
+save_sh_command(output_folder)
seed_everything(seed)
checkpoint_file = get_checkpoint(output_folder, resume_from)
@@ -257,22 +257,31 @@ def train(
# this call may be a NOOP, if the model was pre-trained and expects
# different weights for the normalisation layer.
if hasattr(model, "set_normalizer"):
-model.set_normalizer(datamodule.train_dataloader())
+model.set_normalizer(datamodule.unshuffled_train_dataloader())
else:
logger.warning(
f"Model {model.name} has no 'set_normalizer' method. No normalization will be applied."
f"Model {model.name} has no 'set_normalizer' method. Skipping."
)
-# Rebalances the loss criterion based on the relative proportion of class
-# examples available in the training set. Also affects the validation loss
-# if a validation set is available on the data module.
-model.set_bce_loss_weights(datamodule)
+# If asked, rebalances the loss criterion based on the relative proportion
+# of class examples available in the training set. Also affects the
+# validation loss if a validation set is available on the data module.
+if balance_classes:
+logger.info("Applying datamodule train sampler balancing...")
+datamodule.balance_sampler_by_class = True
+# logger.info("Applying train/valid loss balancing...")
+# model.balance_losses_by_class(datamodule)
+else:
+logger.info(
+"Skipping sample class/dataset ownership balancing on user request"
+)
arguments = {}
arguments["max_epoch"] = epochs
arguments["epoch"] = 0
-# We only load the checkpoint to get some information about its state. The actual loading of the model is done in trainer.fit()
+# We only load the checkpoint to get some information about its state. The
+# actual loading of the model is done in trainer.fit()
if checkpoint_file is not None:
checkpoint = torch.load(checkpoint_file)
arguments["epoch"] = checkpoint["epoch"]
The second file in this commit is the ``utils`` module of ``ptbench.scripts``, which provides the ``save_sh_command`` helper imported above:
import importlib.metadata
import logging
import os
import pathlib
import sys
import time
logger = logging.getLogger(__name__)
def save_sh_command(output_folder: str | pathlib.Path) -> None:
    """Records the command-line used to call this script, for reproduction.

    The function creates an executable ``sh`` script that notes, as comments,
    the working directory and the conda environment to activate, if one was
    in use.  It also records the date and time the script was run and the
    version of the package.

    Parameters
    ----------
    output_folder : str | pathlib.Path
        Path leading to the directory where the command to reproduce the
        current run will be recorded.  The directory is created if it does
        not yet exist.
    """

    if isinstance(output_folder, str):
        output_folder = pathlib.Path(output_folder)

    destfile = output_folder / "command.sh"
    logger.info(f"Writing command-line for reproduction at '{destfile}'...")
    os.makedirs(output_folder, exist_ok=True)

    package = __name__.split(".", 1)[0]
    version = importlib.metadata.version(package)

    with destfile.open("w") as f:
        f.write("#!/usr/bin/env sh\n")
        f.write(f"# date: {time.asctime()}\n")
        f.write(f"# version: {version} ({package})\n")
        f.write(f"# platform: {sys.platform}\n")
        f.write("\n")
        # quote arguments containing spaces so the recorded line re-runs
        args = []
        for k in sys.argv:
            if " " in k:
                args.append(f'"{k}"')
            else:
                args.append(k)
        if os.environ.get("CONDA_DEFAULT_ENV") is not None:
            f.write(f"# conda activate {os.environ['CONDA_DEFAULT_ENV']}\n")
        f.write(f"# cd {os.path.realpath(os.curdir)}\n")
        f.write(" ".join(args) + "\n")

    os.chmod(destfile, 0o755)
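A usage illustration for the helper above (the import path follows the relative import in the train script; the output folder is hypothetical):

.. code:: python

   # Hypothetical usage, e.g. at the start of a training run.
   from ptbench.scripts.utils import save_sh_command

   save_sh_command("results/experiment-1")
   # Writes results/experiment-1/command.sh: an executable sh script whose
   # comments record the date, package version, platform, conda environment
   # and working directory, followed by the exact command-line used.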