Commit 820251d5 authored by Daniel CARRON, committed by André Anjos

[scripts] move training and prediction scripts to common package

parent c309fb41
1 merge request: !46 Create common library
Showing changes with 64 additions and 516 deletions
@@ -100,7 +100,7 @@ Reusable auxiliary functions.
mednet.utils.checkpointer
mednet.utils.gitlab
mednet.utils.rc
-mednet.utils.resources
+mednet.libs.common.utils.resources
mednet.utils.tensorboard
......
@@ -12,7 +12,7 @@ import torch.utils.data
import torchvision.models as models
import torchvision.transforms
from mednet.libs.common.data.typing import TransformSequence
-from .model import Model
+from mednet.libs.common.models.model import Model
from .separate import separate
from .transforms import RGB, SquareCenterPad
......
@@ -11,8 +11,8 @@ import torch.nn.functional as F # noqa: N812
import torch.optim.optimizer
import torch.utils.data
-from ..data.typing import TransformSequence
-from .model import Model
+from ...common.data.typing import TransformSequence
+from ...common.models.model import Model
from .separate import separate
logger = logging.getLogger(__name__)
......
@@ -12,7 +12,7 @@ import torch.utils.data
import torchvision.models as models
import torchvision.transforms
from mednet.libs.common.data.typing import TransformSequence
-from .model import Model
+from mednet.libs.common.models.model import Model
from .separate import separate
from .transforms import RGB, SquareCenterPad
......
@@ -12,7 +12,7 @@ import torch.optim.optimizer
import torch.utils.data
import torchvision.transforms
from mednet.libs.common.data.typing import TransformSequence
-from .model import Model
+from mednet.libs.common.models.model import Model
from .separate import separate
from .transforms import Grayscale, SquareCenterPad
......
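The model-file hunks above all make the same mechanical change: model modules stop reaching the shared ``Model`` base class and typing helpers through package-relative imports and import them from the common library instead. As a minimal sketch of what a model module header looks like after the move (the subclass and its constructor are hypothetical; the import paths are the ones shown in the hunks):

from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model


class ExampleClassifier(Model):  # hypothetical subclass for illustration
    """Classifier relying on the relocated common base class."""

    def __init__(self, transforms: TransformSequence = []):
        super().__init__()
        # transforms the datamodule will apply to every input sample
        self.model_transforms = transforms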
@@ -5,9 +5,6 @@
import typing
Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
"""Definition of a lightning checkpoint."""
-BinaryPrediction: typing.TypeAlias = tuple[str, int, float]
-"""The sample name, the target, and the predicted value."""
......
@@ -23,12 +23,6 @@ classification.add_command(
classification.add_command(
    importlib.import_module("..database", package=__name__).database,
)
-classification.add_command(
-    importlib.import_module("..evaluate", package=__name__).evaluate,
-)
-classification.add_command(
-    importlib.import_module("..experiment", package=__name__).experiment,
-)
classification.add_command(
    importlib.import_module("..predict", package=__name__).predict
)
@@ -37,10 +31,16 @@ classification.add_command(
)
classification.add_command(
    importlib.import_module(
-        "..train_analysis",
+        "mednet.libs.common.scripts.train_analysis",
        package=__name__,
    ).train_analysis,
)
+classification.add_command(
+    importlib.import_module("..evaluate", package=__name__).evaluate,
+)
+classification.add_command(
+    importlib.import_module("..experiment", package=__name__).experiment,
+)
@click.group(
......
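For context, the registration pattern used throughout this file loads each command module by path with importlib and attaches its command object to the click group. A minimal sketch (the group and command names mirror the diff; the absolute/relative duality is exactly what the hunk above changes):

import importlib

import click


@click.group()
def classification():
    """Commands for classification tasks."""


# Absolute module path: importable from anywhere, which is why the shared
# train_analysis command can now live in mednet.libs.common.
classification.add_command(
    importlib.import_module("mednet.libs.common.scripts.train_analysis").train_analysis,
)

# Relative module path: resolved against this package, so it only works for
# commands that stay inside the classification library:
# classification.add_command(
#     importlib.import_module("..evaluate", package=__name__).evaluate,
# )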
@@ -7,8 +7,7 @@ import pathlib
import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
-from .click import ConfigCommand
+from mednet.libs.common.scripts.click import ConfigCommand
# avoids X11/graphical desktop requirement when creating plots
__import__("matplotlib").use("agg")
@@ -116,6 +115,10 @@ def evaluate(
    import typing
    from matplotlib.backends.backend_pdf import PdfPages
+    from mednet.libs.common.scripts.utils import (
+        execution_metadata,
+        save_json_with_backup,
+    )
    from ..engine.evaluator import (
        NumpyJSONEncoder,
@@ -125,7 +128,6 @@
        score_plot,
        tabulate_results,
    )
-    from .utils import execution_metadata, save_json_with_backup
    evaluation_filename = "evaluation.json"
    evaluation_file = pathlib.Path(output_folder) / evaluation_filename
......
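The helper imported here, ``save_json_with_backup``, now lives in the common library. Judging from the (removed) predict body later in this commit, which backs an existing output up to ``<name>~`` before overwriting it, its behaviour is plausibly equivalent to the sketch below; the real implementation in mednet.libs.common.scripts.utils may differ:

import json
import pathlib
import shutil
import typing


def save_json_with_backup(path: pathlib.Path, data: typing.Any) -> None:
    """Save ``data`` as JSON at ``path``, backing up any existing file first."""
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists():
        # keep the previous version around as "<name>~"
        shutil.copy(path, path.parent / (path.name + "~"))
    with path.open("w") as f:
        json.dump(data, f, indent=2)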
@@ -98,7 +98,7 @@ def experiment(
    logger.info(f"Training runtime: {train_stop_timestamp-train_start_timestamp}")
    logger.info("Started train analysis")
-    from .train_analysis import train_analysis
+    from mednet.libs.common.scripts.train_analysis import train_analysis
    ctx.invoke(
        train_analysis,
......
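``ctx.invoke`` is the standard click mechanism for calling one command's logic from another without spawning a subprocess, which is how ``experiment`` chains the now-shared ``train_analysis`` step. A self-contained sketch with hypothetical commands and options:

import click


@click.command()
@click.option("--logdir", default="logs")
def train_analysis(logdir):
    click.echo(f"analyzing {logdir}")


@click.command()
@click.pass_context
def experiment(ctx):
    # runs train_analysis in-process, passing its parameters directly
    ctx.invoke(train_analysis, logdir="results/logs")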
@@ -2,13 +2,13 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import pathlib
import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
-from .click import ConfigCommand
+from mednet.libs.common.scripts.click import ConfigCommand
+from mednet.libs.common.scripts.predict import predict as predict_script
+from mednet.libs.common.scripts.predict import reusable_options
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -32,92 +32,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
""",
)
-@click.option(
-    "--output",
-    "-o",
-    help="""Path to a JSON file in which to save predictions for all samples in the
-    input DataModule (leading directories are created if they do not
-    exist).""",
-    required=True,
-    default="predictions.json",
-    cls=ResourceOption,
-    type=click.Path(
-        file_okay=True,
-        dir_okay=False,
-        writable=True,
-        path_type=pathlib.Path,
-    ),
-)
-@click.option(
-    "--model",
-    "-m",
-    help="""A lightning module instance implementing the network architecture
-    (not necessarily the weights) to be used for prediction.""",
-    required=True,
-    type=click.UNPROCESSED,
-    cls=ResourceOption,
-)
-@click.option(
-    "--datamodule",
-    "-d",
-    help="""A lightning DataModule that will be asked for prediction data
-    loaders. Typically, this includes all configured splits in a DataModule,
-    however this is not a requirement. A DataModule that returns a single
-    dataloader for prediction (wrapped in a dictionary) is acceptable.""",
-    required=True,
-    type=click.UNPROCESSED,
-    cls=ResourceOption,
-)
-@click.option(
-    "--batch-size",
-    "-b",
-    help="""Number of samples in every batch (this parameter affects memory
-    requirements for the network).""",
-    required=True,
-    show_default=True,
-    default=1,
-    type=click.IntRange(min=1),
-    cls=ResourceOption,
-)
-@click.option(
-    "--device",
-    "-d",
-    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
-    show_default=True,
-    required=True,
-    default="cpu",
-    cls=ResourceOption,
-)
-@click.option(
-    "--weight",
-    "-w",
-    help="""Path or URL to a pretrained model file (`.ckpt` extension),
-    corresponding to the architecture set with `--model`. Optionally, you may
-    also pass a directory containing the result of a training session, in which
-    case either the best (lowest validation loss) or the latest model will be
-    loaded.""",
-    required=True,
-    cls=ResourceOption,
-    type=click.Path(
-        exists=True,
-        file_okay=True,
-        dir_okay=True,
-        readable=True,
-        path_type=pathlib.Path,
-    ),
-)
-@click.option(
-    "--parallel",
-    "-P",
-    help="""Use multiprocessing for data loading: if set to -1 (default),
-    disables multiprocessing data loading. Set to 0 to enable as many data
-    loading instances as processing cores available in the system. Set to
-    >= 1 to enable that many multiprocessing instances for data loading.""",
-    type=click.IntRange(min=-1),
-    show_default=True,
-    required=True,
-    default=-1,
-    cls=ResourceOption,
-)
+@reusable_options
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def predict(
    output,
@@ -130,63 +45,13 @@ def predict(
    **_,
) -> None: # numpydoc ignore=PR01
    """Run inference (generates scores) on all input images, using a pre-trained model."""
-    import json
-    import shutil
-    import typing
-    from mednet.libs.classification.engine.predictor import run
-    from mednet.libs.common.engine.device import DeviceManager
-    from mednet.libs.common.utils.checkpointer import (
-        get_checkpoint_to_run_inference,
-    )
-    from .utils import (
-        device_properties,
-        execution_metadata,
-        model_summary,
-        save_json_with_backup,
-    )
+    predict_script(
+        output,
+        model,
+        datamodule,
+        batch_size,
+        device,
+        weight,
+        parallel,
+        **_,
+    )
-    datamodule.batch_size = batch_size
-    datamodule.parallel = parallel
-    datamodule.model_transforms = model.model_transforms
-    datamodule.prepare_data()
-    datamodule.setup(stage="predict")
-    if weight.is_dir():
-        weight = get_checkpoint_to_run_inference(weight)
-    logger.info(f"Loading checkpoint from `{weight}`...")
-    model = type(model).load_from_checkpoint(weight, strict=False)
-    device_manager = DeviceManager(device)
-    # register metadata
-    json_data: dict[str, typing.Any] = execution_metadata()
-    json_data.update(device_properties(device_manager.device_type))
-    json_data.update(
-        dict(
-            database_name=datamodule.database_name,
-            database_split=datamodule.split_name,
-            model_name=model.name,
-        ),
-    )
-    json_data.update(model_summary(model))
-    json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
-    save_json_with_backup(output.with_suffix(".meta.json"), json_data)
-    predictions = run(model, datamodule, device_manager)
-    output.parent.mkdir(parents=True, exist_ok=True)
-    if output.exists():
-        backup = output.parent / (output.name + "~")
-        logger.warning(
-            f"Output predictions file `{str(output)}` exists - "
-            f"backing it up to `{str(backup)}`...",
-        )
-        shutil.copy(output, backup)
-    with output.open("w") as f:
-        json.dump(predictions, f, indent=2)
-    logger.info(f"Predictions saved to `{str(output)}`")
@@ -8,9 +8,9 @@ import typing
import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
+from mednet.libs.common.scripts.click import ConfigCommand
from ...models.typing import SaliencyMapAlgorithm
-from ..click import ConfigCommand
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
......
@@ -8,9 +8,9 @@ import typing
import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
+from mednet.libs.common.scripts.click import ConfigCommand
from ...models.typing import SaliencyMapAlgorithm
-from ..click import ConfigCommand
# avoids X11/graphical desktop requirement when creating plots
__import__("matplotlib").use("agg")
......
@@ -8,9 +8,9 @@ import typing
import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
+from mednet.libs.common.scripts.click import ConfigCommand
from ...models.typing import SaliencyMapAlgorithm
-from ..click import ConfigCommand
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
......
@@ -7,8 +7,7 @@ import pathlib
import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
-from ..click import ConfigCommand
+from mednet.libs.common.scripts.click import ConfigCommand
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import functools
import pathlib
import typing
import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
+from mednet.libs.common.scripts.click import ConfigCommand
+from mednet.libs.common.scripts.train import reusable_options
+from mednet.libs.common.scripts.train import train as train_script
-from .click import ConfigCommand
# logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
logger = setup("mednet", format="%(levelname)s: %(message)s")
-def reusable_options(f):
-    """Wrap reusable training script options (for ``experiment``).
-
-    This decorator equips the target function ``f`` with all (reusable)
-    ``train`` script options.
-
-    Parameters
-    ----------
-    f
-        The target function to equip with options. This function must have
-        parameters that accept such options.
-
-    Returns
-    -------
-        The decorated version of function ``f``.
-    """
-
-    @click.option(
-        "--output-folder",
-        "-o",
-        help="Directory in which to store results (created if it does not exist)",
-        required=True,
-        type=click.Path(
-            file_okay=False,
-            dir_okay=True,
-            writable=True,
-            path_type=pathlib.Path,
-        ),
-        default="results",
-        cls=ResourceOption,
-    )
-    @click.option(
-        "--model",
-        "-m",
-        help="A lightning module instance implementing the network to be trained",
-        required=True,
-        type=click.UNPROCESSED,
-        cls=ResourceOption,
-    )
-    @click.option(
-        "--datamodule",
-        "-d",
-        help="A lightning DataModule containing the training and validation sets.",
-        required=True,
-        type=click.UNPROCESSED,
-        cls=ResourceOption,
-    )
-    @click.option(
-        "--batch-size",
-        "-b",
-        help="Number of samples in every batch (this parameter affects "
-        "memory requirements for the network). If the number of samples in "
-        "the batch is larger than the total number of samples available for "
-        "training, this value is truncated. If this number is smaller, then "
-        "batches of the specified size are created and fed to the network "
-        "until there are no more new samples to feed (epoch is finished). "
-        "If the total number of training samples is not a multiple of the "
-        "batch-size, the last batch will be smaller than the first, unless "
-        "--drop-incomplete-batch is set, in which case this batch is not used.",
-        required=True,
-        show_default=True,
-        default=1,
-        type=click.IntRange(min=1),
-        cls=ResourceOption,
-    )
-    @click.option(
-        "--accumulate-grad-batches",
-        "-a",
-        help="Number of accumulations for backward propagation to accumulate "
-        "gradients over k batches before stepping the optimizer. This "
-        "parameter, used in conjunction with the batch-size, may be used to "
-        "reduce the number of samples loaded in each iteration, to affect memory "
-        "usage in exchange for processing time (more iterations). This is "
-        "especially useful when one is training on GPUs with a limited amount "
-        "of onboard RAM. The default of 1 forces the whole batch to be "
-        "processed at once. Otherwise the batch is split into "
-        "accumulate-grad-batches pieces, and gradients are accumulated "
-        "to complete each training step.",
-        required=True,
-        show_default=True,
-        default=1,
-        type=click.IntRange(min=1),
-        cls=ResourceOption,
-    )
-    @click.option(
-        "--drop-incomplete-batch/--no-drop-incomplete-batch",
-        "-D",
-        help="If set, the last batch in an epoch will be dropped if "
-        "incomplete. If you set this option, you should also consider "
-        "increasing the total number of epochs of training, as the total number "
-        "of training steps may be reduced.",
-        required=True,
-        show_default=True,
-        default=False,
-        cls=ResourceOption,
-    )
-    @click.option(
-        "--epochs",
-        "-e",
-        help="""Number of epochs (complete training set passes) to train for.
-        If resuming from a saved checkpoint, provide a number of epochs
-        greater than the one recorded in the checkpoint to be loaded.""",
-        show_default=True,
-        required=True,
-        default=1000,
-        type=click.IntRange(min=1),
-        cls=ResourceOption,
-    )
-    @click.option(
-        "--validation-period",
-        "-p",
-        help="""Number of epochs after which validation happens. By default,
-        we run validation after every training epoch (period=1). You can
-        change this to make validation more sparse, by increasing the
-        validation period. Notice that this affects checkpoint saving. While
-        checkpoints are created after every training step (the last training
-        step always triggers overriding the latest checkpoint), and that
-        process is independent of validation runs, the evaluation of the
-        'best' model obtained so far is influenced by this setting.""",
-        show_default=True,
-        required=True,
-        default=1,
-        type=click.IntRange(min=1),
-        cls=ResourceOption,
-    )
-    @click.option(
-        "--device",
-        "-x",
-        help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
-        show_default=True,
-        required=True,
-        default="cpu",
-        cls=ResourceOption,
-    )
-    @click.option(
-        "--cache-samples/--no-cache-samples",
-        help="If set to True, loads the samples into memory; "
-        "otherwise loads them at runtime.",
-        required=True,
-        show_default=True,
-        default=False,
-        cls=ResourceOption,
-    )
-    @click.option(
-        "--seed",
-        "-s",
-        help="Seed to use for the random number generator",
-        show_default=True,
-        required=False,
-        default=42,
-        type=click.IntRange(min=0),
-        cls=ResourceOption,
-    )
-    @click.option(
-        "--parallel",
-        "-P",
-        help="""Use multiprocessing for data loading: if set to -1 (default),
-        disables multiprocessing data loading. Set to 0 to enable as many data
-        loading instances as processing cores available in the system. Set to
-        >= 1 to enable that many multiprocessing instances for data
-        loading.""",
-        type=click.IntRange(min=-1),
-        show_default=True,
-        required=True,
-        default=-1,
-        cls=ResourceOption,
-    )
-    @click.option(
-        "--monitoring-interval",
-        "-I",
-        help="""Time between checks for the use of resources during each training
-        epoch, in seconds. An interval of 5 seconds, for example, will lead to
-        CPU and GPU resources being probed every 5 seconds during each training
-        epoch. Values registered in the training logs correspond to averages
-        (or maxima) observed through possibly many probes in each epoch.
-        Notice that setting a very small value may cause the probing process to
-        become extremely busy, potentially biasing the overall perception of
-        resource usage.""",
-        type=click.FloatRange(min=0.1),
-        show_default=True,
-        required=True,
-        default=5.0,
-        cls=ResourceOption,
-    )
-    @click.option(
-        "--balance-classes/--no-balance-classes",
-        "-B/-N",
-        help="""If set, balances weights of the random sampler during
-        training so that samples from all sample classes are picked
-        equitably.""",
-        required=True,
-        show_default=True,
-        default=True,
-        cls=ResourceOption,
-    )
-    @click.option(
-        "--augmentations",
-        "-A",
-        help="""Models that can be trained in this package are shipped without
-        explicit data augmentations. This option allows you to define a list of
-        data augmentations to use for training the selected model.""",
-        type=click.UNPROCESSED,
-        default=[],
-        cls=ResourceOption,
-    )
-    @functools.wraps(f)
-    def wrapper_reusable_options(*args, **kwargs):
-        return f(*args, **kwargs)
-
-    return wrapper_reusable_options
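The removed decorator above (now provided by mednet.libs.common.scripts.train) follows a standard click pattern worth noting: shared options are stacked onto a ``functools.wraps`` pass-through so that several commands (here ``train`` and ``experiment``) can declare them once. A minimal, self-contained sketch of the same pattern with two hypothetical options:

import functools

import click


def reusable_options(f):
    """Equip ``f`` with the option set shared by several commands."""

    @click.option("--epochs", "-e", default=1000, type=click.IntRange(min=1))
    @click.option("--batch-size", "-b", default=1, type=click.IntRange(min=1))
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)

    return wrapper


@click.command()
@reusable_options
def train(epochs, batch_size):
    # click collects the stacked options from the wrapper when building the command
    click.echo(f"training for {epochs} epochs, batch size {batch_size}")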
@click.command(
    entry_point_group="mednet.libs.classification.config",
    cls=ConfigCommand,
@@ -269,122 +48,20 @@ def train(
    extension that are used in subsequent tasks or from which training
    can be resumed.
    """
-    import torch
-    from lightning.pytorch import seed_everything
-    from mednet.libs.common.engine.device import DeviceManager
-    from mednet.libs.common.engine.trainer import run
-    from mednet.libs.common.utils.checkpointer import (
-        get_checkpoint_to_resume_training,
-    )
-    from .utils import (
-        device_properties,
-        execution_metadata,
-        model_summary,
-        save_json_with_backup,
-    )
-    checkpoint_file = None
-    if output_folder.is_dir():
-        try:
-            checkpoint_file = get_checkpoint_to_resume_training(output_folder)
-        except FileNotFoundError:
-            logger.info(
-                f"Folder {output_folder} already exists, but I did not"
-                f" find any usable checkpoint file to resume training"
-                f" from. Starting from scratch...",
-            )
-    seed_everything(seed)
-    # report model/transforms options - set data augmentations
-    logger.info(f"Network model: {type(model).__module__}.{type(model).__name__}")
-    model.augmentation_transforms = augmentations
-    # reset datamodule with user configurable options
-    datamodule.batch_size = batch_size
-    datamodule.drop_incomplete_batch = drop_incomplete_batch
-    datamodule.cache_samples = cache_samples
-    datamodule.parallel = parallel
-    datamodule.model_transforms = model.model_transforms
-    datamodule.prepare_data()
-    datamodule.setup(stage="fit")
-    # If asked, rebalances the loss criterion based on the relative proportion
-    # of class examples available in the training set. Also affects the
-    # validation loss if a validation set is available on the DataModule.
-    if balance_classes:
-        logger.info("Applying train/valid loss balancing...")
-        model.balance_losses(datamodule)
-    else:
-        logger.info(
-            "Skipping sample class/dataset ownership balancing on user request",
-        )
-    logger.info(f"Training for at most {epochs} epochs.")
-    arguments = {}
-    arguments["max_epoch"] = epochs
-    arguments["epoch"] = 0
-    if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"):
-        # Sets the model normalizer with the unaugmented-train-subset if we are
-        # starting from scratch and/or the model does not contain its own
-        # checkpoint loading strategy (e.g. a pytorch stock checkpoint). This
-        # call may be a NOOP, if the model comes from outside this framework,
-        # and expects different weights for the normalisation layer.
-        if hasattr(model, "set_normalizer"):
-            model.set_normalizer(datamodule.unshuffled_train_dataloader())
-        else:
-            logger.warning(
-                f"Model {model.name} has no `set_normalizer` method. "
-                "Skipping normalization setup (unsupported external model).",
-            )
-    else:
-        # Normalizer will be loaded during model.on_load_checkpoint
-        checkpoint = torch.load(checkpoint_file)
-        start_epoch = checkpoint["epoch"]
-        logger.info(
-            f"Resuming from epoch {start_epoch} "
-            f"(checkpoint file: `{str(checkpoint_file)}`)...",
-        )
-    device_manager = DeviceManager(device)
-    # stores all information we can think of, to reproduce this later
-    json_data: dict[str, typing.Any] = execution_metadata()
-    json_data.update(device_properties(device_manager.device_type))
-    json_data.update(
-        dict(
-            database_name=datamodule.database_name,
-            split_name=datamodule.split_name,
-            epochs=epochs,
-            batch_size=batch_size,
-            accumulate_grad_batches=accumulate_grad_batches,
-            drop_incomplete_batch=drop_incomplete_batch,
-            validation_period=validation_period,
-            cache_samples=cache_samples,
-            seed=seed,
-            parallel=parallel,
-            monitoring_interval=monitoring_interval,
-            balance_classes=balance_classes,
-            model_name=model.name,
-        ),
-    )
-    json_data.update(model_summary(model))
-    json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
-    save_json_with_backup(output_folder / "meta.json", json_data)
-    run(
-        model=model,
-        datamodule=datamodule,
-        validation_period=validation_period,
-        device_manager=device_manager,
-        max_epochs=epochs,
-        output_folder=output_folder,
-        monitoring_interval=monitoring_interval,
-        accumulate_grad_batches=accumulate_grad_batches,
-        checkpoint=checkpoint_file,
-    )
+    train_script(
+        model,
+        output_folder,
+        epochs,
+        batch_size,
+        accumulate_grad_batches,
+        drop_incomplete_batch,
+        datamodule,
+        validation_period,
+        device,
+        cache_samples,
+        seed,
+        parallel,
+        monitoring_interval,
+        balance_classes,
+        **_,
+    )
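The resume logic that used to live here (and presumably moved into the shared train script) is worth spelling out: if the output folder already exists, training resumes from a usable checkpoint when one can be found. A sketch distilled from the removed body above, assuming ``get_checkpoint_to_resume_training`` keeps the signature shown there:

import pathlib

from mednet.libs.common.utils.checkpointer import get_checkpoint_to_resume_training


def find_resume_checkpoint(output_folder: pathlib.Path):
    """Return a checkpoint to resume from, or None to start from scratch."""
    if output_folder.is_dir():
        try:
            return get_checkpoint_to_resume_training(output_folder)
        except FileNotFoundError:
            pass  # folder exists but holds no usable checkpoint
    return None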
@@ -8,7 +8,6 @@ from collections import Counter
import torch
import torch.utils.data
from mednet.libs.common.data.typing import DataLoader
logger = logging.getLogger("mednet")
......
File moved
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Defines most common types used in code."""
import typing
Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
"""Definition of a lightning checkpoint."""
\ No newline at end of file
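As a usage note, the relocated alias keeps checkpoint-handling code readable. A hypothetical model hook, following the lightning convention referenced by the removed train body above (which reads checkpoint["epoch"]), might look like this sketch:

import typing

Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any]


class ExampleModel:  # hypothetical class for illustration
    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
        # "epoch" is one of the keys lightning stores in every checkpoint
        self._start_epoch = checkpoint["epoch"]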