diff --git a/doc/api.rst b/doc/api.rst index d0f2f4118f29f0113819ed4b35b9fef66509bb2b..d4e441f00698cd6bbb8bbf42c233f0b3048f8c46 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -100,7 +100,7 @@ Reusable auxiliary functions. mednet.utils.checkpointer mednet.utils.gitlab mednet.utils.rc - mednet.utils.resources + mednet.libs.common.utils.resources mednet.utils.tensorboard diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py index 481bc03fd3cfd7322187792f5020cd5e71d8aefa..03802839900ee4bd3f5d8b368cb8d611cbc6b006 100644 --- a/src/mednet/libs/classification/models/alexnet.py +++ b/src/mednet/libs/classification/models/alexnet.py @@ -12,7 +12,7 @@ import torch.utils.data import torchvision.models as models import torchvision.transforms from mednet.libs.common.data.typing import TransformSequence -from .model import Model +from mednet.libs.common.models.model import Model from .separate import separate from .transforms import RGB, SquareCenterPad diff --git a/src/mednet/models/cnn3d.py b/src/mednet/libs/classification/models/cnn3d.py similarity index 98% rename from src/mednet/models/cnn3d.py rename to src/mednet/libs/classification/models/cnn3d.py index d0be0e62e6c66431e0d14ac1d2ecddcfd91cf1d7..b9abf5e98c3c7fef25e255e5a8b656dee7b272c2 100644 --- a/src/mednet/models/cnn3d.py +++ b/src/mednet/libs/classification/models/cnn3d.py @@ -11,8 +11,8 @@ import torch.nn.functional as F # noqa: N812 import torch.optim.optimizer import torch.utils.data -from ..data.typing import TransformSequence -from .model import Model +from ...common.data.typing import TransformSequence +from ...common.models.model import Model from .separate import separate logger = logging.getLogger(__name__) diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py index bd7b83d8bfdd8d1f05c94adfc1a7a92461f74369..69ac44b72d22b9cd0ca932e00a6dee51fd1f93bd 100644 --- a/src/mednet/libs/classification/models/densenet.py +++ b/src/mednet/libs/classification/models/densenet.py @@ -12,7 +12,7 @@ import torch.utils.data import torchvision.models as models import torchvision.transforms from mednet.libs.common.data.typing import TransformSequence -from .model import Model +from mednet.libs.common.models.model import Model from .separate import separate from .transforms import RGB, SquareCenterPad diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py index 4fddda019a86057753e5e7361e7f03198df12d61..35953ddfd14377ce9a88267a154fbe99a6b8f9b2 100644 --- a/src/mednet/libs/classification/models/pasa.py +++ b/src/mednet/libs/classification/models/pasa.py @@ -12,7 +12,7 @@ import torch.optim.optimizer import torch.utils.data import torchvision.transforms from mednet.libs.common.data.typing import TransformSequence -from .model import Model +from mednet.libs.common.models.model import Model from .separate import separate from .transforms import Grayscale, SquareCenterPad diff --git a/src/mednet/libs/classification/models/typing.py b/src/mednet/libs/classification/models/typing.py index 9d1e7f32c1b08684cf324233694518855841eaa0..bbc0ba02438ffc5efee2f19964a5120caf4e9a66 100644 --- a/src/mednet/libs/classification/models/typing.py +++ b/src/mednet/libs/classification/models/typing.py @@ -5,9 +5,6 @@ import typing -Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any] -"""Definition of a lightning checkpoint.""" - BinaryPrediction: typing.TypeAlias = tuple[str, int, float] """The 
sample name, the target, and the predicted value.""" diff --git a/src/mednet/libs/classification/scripts/cli.py b/src/mednet/libs/classification/scripts/cli.py index 70db90d60caf29a4fc90a01cc0288a7d11dac265..b7839d0dd1bf78eae8eb81e4fdbd8078cc9bec0d 100644 --- a/src/mednet/libs/classification/scripts/cli.py +++ b/src/mednet/libs/classification/scripts/cli.py @@ -23,12 +23,6 @@ classification.add_command( classification.add_command( importlib.import_module("..database", package=__name__).database, ) -classification.add_command( - importlib.import_module("..evaluate", package=__name__).evaluate, -) -classification.add_command( - importlib.import_module("..experiment", package=__name__).experiment, -) classification.add_command( importlib.import_module("..predict", package=__name__).predict ) @@ -37,10 +31,16 @@ classification.add_command( ) classification.add_command( importlib.import_module( - "..train_analysis", + "mednet.libs.common.scripts.train_analysis", package=__name__, ).train_analysis, ) +classification.add_command( + importlib.import_module("..evaluate", package=__name__).evaluate, +) +classification.add_command( + importlib.import_module("..experiment", package=__name__).experiment, +) @click.group( diff --git a/src/mednet/libs/classification/scripts/evaluate.py b/src/mednet/libs/classification/scripts/evaluate.py index bd3a48f7c58627e1ac4fd6c50408cb30dbfa1aea..eb7970c9f8ad8eae45ff647d066f204a54ccf9cd 100644 --- a/src/mednet/libs/classification/scripts/evaluate.py +++ b/src/mednet/libs/classification/scripts/evaluate.py @@ -7,8 +7,7 @@ import pathlib import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup - -from .click import ConfigCommand +from mednet.libs.common.scripts.click import ConfigCommand # avoids X11/graphical desktop requirement when creating plots __import__("matplotlib").use("agg") @@ -116,6 +115,10 @@ def evaluate( import typing from matplotlib.backends.backend_pdf import PdfPages + from mednet.libs.common.scripts.utils import ( + execution_metadata, + save_json_with_backup, + ) from ..engine.evaluator import ( NumpyJSONEncoder, @@ -125,7 +128,6 @@ def evaluate( score_plot, tabulate_results, ) - from .utils import execution_metadata, save_json_with_backup evaluation_filename = "evaluation.json" evaluation_file = pathlib.Path(output_folder) / evaluation_filename diff --git a/src/mednet/libs/classification/scripts/experiment.py b/src/mednet/libs/classification/scripts/experiment.py index b0bedb0d408e384496a8c5a2cbbe267aa5d70eae..3118b14b51332d79faafd31755856e50bac47ae3 100644 --- a/src/mednet/libs/classification/scripts/experiment.py +++ b/src/mednet/libs/classification/scripts/experiment.py @@ -98,7 +98,7 @@ def experiment( logger.info(f"Training runtime: {train_stop_timestamp-train_start_timestamp}") logger.info("Started train analysis") - from .train_analysis import train_analysis + from mednet.libs.common.scripts.train_analysis import train_analysis ctx.invoke( train_analysis, diff --git a/src/mednet/libs/classification/scripts/predict.py b/src/mednet/libs/classification/scripts/predict.py index e67a97476753d73c347a253ff1272e6f7955acea..50d7d03d9fce1674c1a049cf3ae008fb06516715 100644 --- a/src/mednet/libs/classification/scripts/predict.py +++ b/src/mednet/libs/classification/scripts/predict.py @@ -2,13 +2,13 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import pathlib import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup - -from .click import ConfigCommand 
+from mednet.libs.common.scripts.click import ConfigCommand +from mednet.libs.common.scripts.predict import predict as predict_script +from mednet.libs.common.scripts.predict import reusable_options logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -32,92 +32,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") """, ) -@click.option( - "--output", - "-o", - help="""Path to a JSON file in which to save predictions for all samples in the - input DataModule (leading directories are created if they do not - exist).""", - required=True, - default="predictions.json", - cls=ResourceOption, - type=click.Path( - file_okay=True, - dir_okay=False, - writable=True, - path_type=pathlib.Path, - ), -) -@click.option( - "--model", - "-m", - help="""A lightning module instance implementing the network architecture - (not the weights, necessarily) to be used for prediction.""", - required=True, - type=click.UNPROCESSED, - cls=ResourceOption, -) -@click.option( - "--datamodule", - "-d", - help="""A lightning DataModule that will be asked for prediction data - loaders. Typically, this includes all configured splits in a DataModule, - however this is not a requirement. A DataModule that returns a single - dataloader for prediction (wrapped in a dictionary) is acceptable.""", - required=True, - type=click.UNPROCESSED, - cls=ResourceOption, -) -@click.option( - "--batch-size", - "-b", - help="""Number of samples in every batch (this parameter affects memory - requirements for the network).""", - required=True, - show_default=True, - default=1, - type=click.IntRange(min=1), - cls=ResourceOption, -) -@click.option( - "--device", - "-d", - help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', - show_default=True, - required=True, - default="cpu", - cls=ResourceOption, -) -@click.option( - "--weight", - "-w", - help="""Path or URL to pretrained model file (`.ckpt` extension), - corresponding to the architecture set with `--model`. Optionally, you may - also pass a directory containing the result of a training session, in which - case either the best (lowest validation) or latest model will be loaded.""", - required=True, - cls=ResourceOption, - type=click.Path( - exists=True, - file_okay=True, - dir_okay=True, - readable=True, - path_type=pathlib.Path, - ), -) -@click.option( - "--parallel", - "-P", - help="""Use multiprocessing for data loading: if set to -1 (default), - disables multiprocessing data loading. Set to 0 to enable as many data - loading instances as processing cores available in the system. 
Set to - >= 1 to enable that many multiprocessing instances for data loading.""", - type=click.IntRange(min=-1), - show_default=True, - required=True, - default=-1, - cls=ResourceOption, -) +@reusable_options @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def predict( output, @@ -130,63 +45,13 @@ def predict( **_, ) -> None: # numpydoc ignore=PR01 """Run inference (generates scores) on all input images, using a pre-trained model.""" - import json - import shutil - import typing - - from mednet.libs.classification.engine.predictor import run - from mednet.libs.common.engine.device import DeviceManager - from mednet.libs.common.utils.checkpointer import ( - get_checkpoint_to_run_inference, - ) - - from .utils import ( - device_properties, - execution_metadata, - model_summary, - save_json_with_backup, + predict_script( + output, + model, + datamodule, + batch_size, + device, + weight, + parallel, + **_, ) - - datamodule.batch_size = batch_size - datamodule.parallel = parallel - datamodule.model_transforms = model.model_transforms - - datamodule.prepare_data() - datamodule.setup(stage="predict") - - if weight.is_dir(): - weight = get_checkpoint_to_run_inference(weight) - - logger.info(f"Loading checkpoint from `{weight}`...") - model = type(model).load_from_checkpoint(weight, strict=False) - - device_manager = DeviceManager(device) - - # register metadata - json_data: dict[str, typing.Any] = execution_metadata() - json_data.update(device_properties(device_manager.device_type)) - json_data.update( - dict( - database_name=datamodule.database_name, - database_split=datamodule.split_name, - model_name=model.name, - ), - ) - json_data.update(model_summary(model)) - json_data = {k.replace("_", "-"): v for k, v in json_data.items()} - save_json_with_backup(output.with_suffix(".meta.json"), json_data) - - predictions = run(model, datamodule, device_manager) - - output.parent.mkdir(parents=True, exist_ok=True) - if output.exists(): - backup = output.parent / (output.name + "~") - logger.warning( - f"Output predictions file `{str(output)}` exists - " - f"backing it up to `{str(backup)}`...", - ) - shutil.copy(output, backup) - - with output.open("w") as f: - json.dump(predictions, f, indent=2) - logger.info(f"Predictions saved to `{str(output)}`") diff --git a/src/mednet/libs/classification/scripts/saliency/completeness.py b/src/mednet/libs/classification/scripts/saliency/completeness.py index 9ac006d542f057303f179027e6626597fd803ef0..fa2d243d6287df127f6da33cb0a31f7ed2b405f6 100644 --- a/src/mednet/libs/classification/scripts/saliency/completeness.py +++ b/src/mednet/libs/classification/scripts/saliency/completeness.py @@ -8,9 +8,9 @@ import typing import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup +from mednet.libs.common.scripts.click import ConfigCommand from ...models.typing import SaliencyMapAlgorithm -from ..click import ConfigCommand logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/mednet/libs/classification/scripts/saliency/evaluate.py b/src/mednet/libs/classification/scripts/saliency/evaluate.py index 93390e040eb4bc27ae9cbbc2fdac0afe64771cf4..035d9b0e732edb2c70e4e430a596eafe4639ac01 100644 --- a/src/mednet/libs/classification/scripts/saliency/evaluate.py +++ b/src/mednet/libs/classification/scripts/saliency/evaluate.py @@ -8,9 +8,9 @@ import typing import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup +from 
mednet.libs.common.scripts.click import ConfigCommand from ...models.typing import SaliencyMapAlgorithm -from ..click import ConfigCommand # avoids X11/graphical desktop requirement when creating plots __import__("matplotlib").use("agg") diff --git a/src/mednet/libs/classification/scripts/saliency/generate.py b/src/mednet/libs/classification/scripts/saliency/generate.py index d4e8fd92dcd8dae0e6d25e12507b5acc8fbffcf9..3dfd8d84d29c3142f9081e4d6ffe56a38ffa9726 100644 --- a/src/mednet/libs/classification/scripts/saliency/generate.py +++ b/src/mednet/libs/classification/scripts/saliency/generate.py @@ -8,9 +8,9 @@ import typing import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup +from mednet.libs.common.scripts.click import ConfigCommand from ...models.typing import SaliencyMapAlgorithm -from ..click import ConfigCommand logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/mednet/libs/classification/scripts/saliency/interpretability.py b/src/mednet/libs/classification/scripts/saliency/interpretability.py index 65a117d3073eefed8b4690ebaa3ca74d59444b9f..c4b2ba525ed720ac1c8bc4fc2651cc2c7e6fb3d4 100644 --- a/src/mednet/libs/classification/scripts/saliency/interpretability.py +++ b/src/mednet/libs/classification/scripts/saliency/interpretability.py @@ -7,8 +7,7 @@ import pathlib import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup - -from ..click import ConfigCommand +from mednet.libs.common.scripts.click import ConfigCommand logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py index ea09def8d0ce69ad0f42be23d4fe7b5ec6d5b566..94fbcb00bacbb0cadce116e5a944c4abf57cdbb3 100644 --- a/src/mednet/libs/classification/scripts/train.py +++ b/src/mednet/libs/classification/scripts/train.py @@ -1,234 +1,13 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import functools -import pathlib -import typing - import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup +from mednet.libs.common.scripts.click import ConfigCommand +from mednet.libs.common.scripts.train import reusable_options +from mednet.libs.common.scripts.train import train as train_script -from .click import ConfigCommand - -# logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") logger = setup("mednet", format="%(levelname)s: %(message)s") -def reusable_options(f): - """Wrap reusable training script options (for ``experiment``). - - This decorator equips the target function ``f`` with all (reusable) - ``train`` script options. - - Parameters - ---------- - f - The target function to equip with options. This function must have - parameters that accept such options. 
- - Returns - ------- - The decorated version of function ``f`` - """ - - @click.option( - "--output-folder", - "-o", - help="Directory in which to store results (created if does not exist)", - required=True, - type=click.Path( - file_okay=False, - dir_okay=True, - writable=True, - path_type=pathlib.Path, - ), - default="results", - cls=ResourceOption, - ) - @click.option( - "--model", - "-m", - help="A lightning module instance implementing the network to be trained", - required=True, - type=click.UNPROCESSED, - cls=ResourceOption, - ) - @click.option( - "--datamodule", - "-d", - help="A lightning DataModule containing the training and validation sets.", - required=True, - type=click.UNPROCESSED, - cls=ResourceOption, - ) - @click.option( - "--batch-size", - "-b", - help="Number of samples in every batch (this parameter affects " - "memory requirements for the network). If the number of samples in " - "the batch is larger than the total number of samples available for " - "training, this value is truncated. If this number is smaller, then " - "batches of the specified size are created and fed to the network " - "until there are no more new samples to feed (epoch is finished). " - "If the total number of training samples is not a multiple of the " - "batch-size, the last batch will be smaller than the first, unless " - "--drop-incomplete-batch is set, in which case this batch is not used.", - required=True, - show_default=True, - default=1, - type=click.IntRange(min=1), - cls=ResourceOption, - ) - @click.option( - "--accumulate-grad-batches", - "-a", - help="Number of accumulations for backward propagation to accumulate " - "gradients over k batches before stepping the optimizer. This " - "parameter, used in conjunction with the batch-size, may be used to " - "reduce the number of samples loaded in each iteration, to affect memory " - "usage in exchange for processing time (more iterations). This is " - "useful interesting when one is training on GPUs with a limited amount " - "of onboard RAM. The default of 1 forces the whole batch to be " - "processed at once. Otherwise the batch is multiplied by " - "accumulate-grad-batches pieces, and gradients are accumulated " - "to complete each training step.", - required=True, - show_default=True, - default=1, - type=click.IntRange(min=1), - cls=ResourceOption, - ) - @click.option( - "--drop-incomplete-batch/--no-drop-incomplete-batch", - "-D", - help="If set, the last batch in an epoch will be dropped if " - "incomplete. If you set this option, you should also consider " - "increasing the total number of epochs of training, as the total number " - "of training steps may be reduced.", - required=True, - show_default=True, - default=False, - cls=ResourceOption, - ) - @click.option( - "--epochs", - "-e", - help="""Number of epochs (complete training set passes) to train for. - If continuing from a saved checkpoint, ensure to provide a greater - number of epochs than was saved in the checkpoint to be loaded.""", - show_default=True, - required=True, - default=1000, - type=click.IntRange(min=1), - cls=ResourceOption, - ) - @click.option( - "--validation-period", - "-p", - help="""Number of epochs after which validation happens. By default, - we run validation after every training epoch (period=1). You can - change this to make validation more sparse, by increasing the - validation period. Notice that this affects checkpoint saving. 
While - checkpoints are created after every training step (the last training - step always triggers the overriding of latest checkpoint), and - this process is independent of validation runs, evaluation of the - 'best' model obtained so far based on those will be influenced by this - setting.""", - show_default=True, - required=True, - default=1, - type=click.IntRange(min=1), - cls=ResourceOption, - ) - @click.option( - "--device", - "-x", - help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', - show_default=True, - required=True, - default="cpu", - cls=ResourceOption, - ) - @click.option( - "--cache-samples/--no-cache-samples", - help="If set to True, loads the sample into memory, " - "otherwise loads them at runtime.", - required=True, - show_default=True, - default=False, - cls=ResourceOption, - ) - @click.option( - "--seed", - "-s", - help="Seed to use for the random number generator", - show_default=True, - required=False, - default=42, - type=click.IntRange(min=0), - cls=ResourceOption, - ) - @click.option( - "--parallel", - "-P", - help="""Use multiprocessing for data loading: if set to -1 (default), - disables multiprocessing data loading. Set to 0 to enable as many data - loading instances as processing cores available in the system. Set to - >= 1 to enable that many multiprocessing instances for data - loading.""", - type=click.IntRange(min=-1), - show_default=True, - required=True, - default=-1, - cls=ResourceOption, - ) - @click.option( - "--monitoring-interval", - "-I", - help="""Time between checks for the use of resources during each training - epoch, in seconds. An interval of 5 seconds, for example, will lead to - CPU and GPU resources being probed every 5 seconds during each training - epoch. Values registered in the training logs correspond to averages - (or maxima) observed through possibly many probes in each epoch. - Notice that setting a very small value may cause the probing process to - become extremely busy, potentially biasing the overall perception of - resource usage.""", - type=click.FloatRange(min=0.1), - show_default=True, - required=True, - default=5.0, - cls=ResourceOption, - ) - @click.option( - "--balance-classes/--no-balance-classes", - "-B/-N", - help="""If set, balances weights of the random sampler during - training so that samples from all sample classes are picked - equitably.""", - required=True, - show_default=True, - default=True, - cls=ResourceOption, - ) - @click.option( - "--augmentations", - "-A", - help="""Models that can be trained in this package are shipped without - explicit data augmentations. This option allows you to define a list of - data augmentations to use for training the selected model.""", - type=click.UNPROCESSED, - default=[], - cls=ResourceOption, - ) - @functools.wraps(f) - def wrapper_reusable_options(*args, **kwargs): - return f(*args, **kwargs) - - return wrapper_reusable_options - - @click.command( entry_point_group="mednet.libs.classification.config", cls=ConfigCommand, @@ -269,122 +48,20 @@ def train( extension that are used in subsequent tasks or from which training can be resumed. 
""" - - import torch - from lightning.pytorch import seed_everything - from mednet.libs.common.engine.device import DeviceManager - from mednet.libs.common.engine.trainer import run - from mednet.libs.common.utils.checkpointer import ( - get_checkpoint_to_resume_training, - ) - - from .utils import ( - device_properties, - execution_metadata, - model_summary, - save_json_with_backup, - ) - - checkpoint_file = None - if output_folder.is_dir(): - try: - checkpoint_file = get_checkpoint_to_resume_training(output_folder) - except FileNotFoundError: - logger.info( - f"Folder {output_folder} already exists, but I did not" - f" find any usable checkpoint file to resume training" - f" from. Starting from scratch...", - ) - - seed_everything(seed) - - # report model/transforms options - set data augmentations - logger.info(f"Network model: {type(model).__module__}.{type(model).__name__}") - model.augmentation_transforms = augmentations - - # reset datamodule with user configurable options - datamodule.batch_size = batch_size - datamodule.drop_incomplete_batch = drop_incomplete_batch - datamodule.cache_samples = cache_samples - datamodule.parallel = parallel - datamodule.model_transforms = model.model_transforms - - datamodule.prepare_data() - datamodule.setup(stage="fit") - - # If asked, rebalances the loss criterion based on the relative proportion - # of class examples available in the training set. Also affects the - # validation loss if a validation set is available on the DataModule. - if balance_classes: - logger.info("Applying train/valid loss balancing...") - model.balance_losses(datamodule) - else: - logger.info( - "Skipping sample class/dataset ownership balancing on user request", - ) - - logger.info(f"Training for at most {epochs} epochs.") - - arguments = {} - arguments["max_epoch"] = epochs - arguments["epoch"] = 0 - - if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"): - # Sets the model normalizer with the unaugmented-train-subset if we are - # starting from scratch and/or the model does not contain its own - # checkpoint loading strategy (e.g. a pytorch stock checkpoint). This - # call may be a NOOP, if the model comes from outside this framework, - # and expects different weights for the normalisation layer. - if hasattr(model, "set_normalizer"): - model.set_normalizer(datamodule.unshuffled_train_dataloader()) - else: - logger.warning( - f"Model {model.name} has no `set_normalizer` method. 
" - "Skipping normalization setup (unsupported external model).", - ) - else: - # Normalizer will be loaded during model.on_load_checkpoint - checkpoint = torch.load(checkpoint_file) - start_epoch = checkpoint["epoch"] - logger.info( - f"Resuming from epoch {start_epoch} " - f"(checkpoint file: `{str(checkpoint_file)}`)...", - ) - - device_manager = DeviceManager(device) - - # stores all information we can think of, to reproduce this later - json_data: dict[str, typing.Any] = execution_metadata() - json_data.update(device_properties(device_manager.device_type)) - json_data.update( - dict( - database_name=datamodule.database_name, - split_name=datamodule.split_name, - epochs=epochs, - batch_size=batch_size, - accumulate_grad_batches=accumulate_grad_batches, - drop_incomplete_batch=drop_incomplete_batch, - validation_period=validation_period, - cache_samples=cache_samples, - seed=seed, - parallel=parallel, - monitoring_interval=monitoring_interval, - balance_classes=balance_classes, - model_name=model.name, - ), - ) - json_data.update(model_summary(model)) - json_data = {k.replace("_", "-"): v for k, v in json_data.items()} - save_json_with_backup(output_folder / "meta.json", json_data) - - run( - model=model, - datamodule=datamodule, - validation_period=validation_period, - device_manager=device_manager, - max_epochs=epochs, - output_folder=output_folder, - monitoring_interval=monitoring_interval, - accumulate_grad_batches=accumulate_grad_batches, - checkpoint=checkpoint_file, + train_script( + model, + output_folder, + epochs, + batch_size, + accumulate_grad_batches, + drop_incomplete_batch, + datamodule, + validation_period, + device, + cache_samples, + seed, + parallel, + monitoring_interval, + balance_classes, + **_, ) diff --git a/src/mednet/libs/common/models/__init__.py b/src/mednet/libs/common/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/mednet/libs/classification/models/loss_weights.py b/src/mednet/libs/common/models/loss_weights.py similarity index 98% rename from src/mednet/libs/classification/models/loss_weights.py rename to src/mednet/libs/common/models/loss_weights.py index 8111c354f1163ad5febbe4629868da3fa901618a..b3c2edec5b982762521eedbc4778ea21abb4d5fc 100644 --- a/src/mednet/libs/classification/models/loss_weights.py +++ b/src/mednet/libs/common/models/loss_weights.py @@ -8,7 +8,6 @@ from collections import Counter import torch import torch.utils.data -from mednet.libs.common.data.typing import DataLoader logger = logging.getLogger("mednet") diff --git a/src/mednet/models/model.py b/src/mednet/libs/common/models/model.py similarity index 100% rename from src/mednet/models/model.py rename to src/mednet/libs/common/models/model.py diff --git a/src/mednet/libs/classification/models/normalizer.py b/src/mednet/libs/common/models/normalizer.py similarity index 100% rename from src/mednet/libs/classification/models/normalizer.py rename to src/mednet/libs/common/models/normalizer.py diff --git a/src/mednet/libs/common/models/typing.py b/src/mednet/libs/common/models/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..97a04b6616724503e7690f6e8d33a40c315feed6 --- /dev/null +++ b/src/mednet/libs/common/models/typing.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later +"""Defines most common types used in code.""" + +import typing + +Checkpoint: 
typing.TypeAlias = typing.MutableMapping[str, typing.Any] +"""Definition of a lightning checkpoint.""" \ No newline at end of file diff --git a/src/mednet/libs/classification/scripts/click.py b/src/mednet/libs/common/scripts/click.py similarity index 100% rename from src/mednet/libs/classification/scripts/click.py rename to src/mednet/libs/common/scripts/click.py diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..11cd63366803d73c0f4bf0600c1abe6ab2beecc0 --- /dev/null +++ b/src/mednet/libs/common/scripts/predict.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import functools +import pathlib + +import click +from clapper.click import ResourceOption +from clapper.logging import setup + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +def reusable_options(f): + """Wrap reusable predict script options (for ``experiment``). + + This decorator equips the target function ``f`` with all (reusable) + ``predict`` script options. + + Parameters + ---------- + f + The target function to equip with options. This function must have + parameters that accept such options. + + Returns + ------- + The decorated version of function ``f`` + """ + + @click.option( + "--output", + "-o", + help="""Path to a JSON file in which to save predictions for all samples in the + input DataModule (leading directories are created if they do not + exist).""", + required=True, + default="predictions.json", + cls=ResourceOption, + type=click.Path( + file_okay=True, + dir_okay=False, + writable=True, + path_type=pathlib.Path, + ), + ) + @click.option( + "--model", + "-m", + help="""A lightning module instance implementing the network architecture + (not the weights, necessarily) to be used for prediction.""", + required=True, + cls=ResourceOption, + ) + @click.option( + "--datamodule", + "-d", + help="""A lightning DataModule that will be asked for prediction data + loaders. Typically, this includes all configured splits in a DataModule, + however this is not a requirement. A DataModule that returns a single + dataloader for prediction (wrapped in a dictionary) is acceptable.""", + required=True, + cls=ResourceOption, + ) + @click.option( + "--batch-size", + "-b", + help="""Number of samples in every batch (this parameter affects memory + requirements for the network).""", + required=True, + show_default=True, + default=1, + type=click.IntRange(min=1), + cls=ResourceOption, + ) + @click.option( + "--device", + "-d", + help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', + show_default=True, + required=True, + default="cpu", + cls=ResourceOption, + ) + @click.option( + "--weight", + "-w", + help="""Path or URL to pretrained model file (`.ckpt` extension), + corresponding to the architecture set with `--model`. Optionally, you may + also pass a directory containing the result of a training session, in which + case either the best (lowest validation) or latest model will be loaded.""", + required=True, + cls=ResourceOption, + type=click.Path( + exists=True, + file_okay=True, + dir_okay=True, + readable=True, + path_type=pathlib.Path, + ), + ) + @click.option( + "--parallel", + "-P", + help="""Use multiprocessing for data loading: if set to -1 (default), + disables multiprocessing data loading. 
Set to 0 to enable as many data + loading instances as processing cores available in the system. Set to + >= 1 to enable that many multiprocessing instances for data loading.""", + type=click.IntRange(min=-1), + show_default=True, + required=True, + default=-1, + cls=ResourceOption, + ) + @functools.wraps(f) + def wrapper_reusable_options(*args, **kwargs): + return f(*args, **kwargs) + + return wrapper_reusable_options + + +def predict( + output, + model, + datamodule, + batch_size, + device, + weight, + parallel, + **_, +) -> None: # numpydoc ignore=PR01 + """Run inference (generates scores) on all input images, using a pre-trained model.""" + import json + import shutil + import typing + + from mednet.libs.classification.engine.predictor import run + from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.utils.checkpointer import ( + get_checkpoint_to_run_inference, + ) + + from .utils import ( + device_properties, + execution_metadata, + model_summary, + save_json_with_backup, + ) + + datamodule.set_chunk_size(batch_size, 1) + datamodule.parallel = parallel + datamodule.model_transforms = model.model_transforms + + datamodule.prepare_data() + datamodule.setup(stage="predict") + + if weight.is_dir(): + weight = get_checkpoint_to_run_inference(weight) + + logger.info(f"Loading checkpoint from `{weight}`...") + model = type(model).load_from_checkpoint(weight, strict=False) + + device_manager = DeviceManager(device) + + # register metadata + json_data: dict[str, typing.Any] = execution_metadata() + json_data.update(device_properties(device_manager.device_type)) + json_data.update( + dict( + database_name=datamodule.database_name, + database_split=datamodule.split_name, + model_name=model.name, + ), + ) + json_data.update(model_summary(model)) + json_data = {k.replace("_", "-"): v for k, v in json_data.items()} + save_json_with_backup(output.with_suffix(".meta.json"), json_data) + + predictions = run(model, datamodule, device_manager) + + output.parent.mkdir(parents=True, exist_ok=True) + if output.exists(): + backup = output.parent / (output.name + "~") + logger.warning( + f"Output predictions file `{str(output)}` exists - " + f"backing it up to `{str(backup)}`...", + ) + shutil.copy(output, backup) + + with output.open("w") as f: + json.dump(predictions, f, indent=2) + logger.info(f"Predictions saved to `{str(output)}`") diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..2228725b673068f8417fcb1a8b6fc829357dab39 --- /dev/null +++ b/src/mednet/libs/common/scripts/train.py @@ -0,0 +1,355 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import functools +import pathlib +import typing + +import click +from clapper.click import ResourceOption +from clapper.logging import setup + +# logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +logger = setup("mednet", format="%(levelname)s: %(message)s") + + +def reusable_options(f): + """Wrap reusable training script options (for ``experiment``). + + This decorator equips the target function ``f`` with all (reusable) + ``train`` script options. + + Parameters + ---------- + f + The target function to equip with options. This function must have + parameters that accept such options. 
+ + Returns + ------- + The decorated version of function ``f`` + """ + + @click.option( + "--output-folder", + "-o", + help="Directory in which to store results (created if does not exist)", + required=True, + type=click.Path( + file_okay=False, + dir_okay=True, + writable=True, + path_type=pathlib.Path, + ), + default="results", + cls=ResourceOption, + ) + @click.option( + "--model", + "-m", + help="A lightning module instance implementing the network to be trained", + required=True, + cls=ResourceOption, + ) + @click.option( + "--datamodule", + "-d", + help="A lightning DataModule containing the training and validation sets.", + required=True, + cls=ResourceOption, + ) + @click.option( + "--batch-size", + "-b", + help="Number of samples in every batch (this parameter affects " + "memory requirements for the network). If the number of samples in " + "the batch is larger than the total number of samples available for " + "training, this value is truncated. If this number is smaller, then " + "batches of the specified size are created and fed to the network " + "until there are no more new samples to feed (epoch is finished). " + "If the total number of training samples is not a multiple of the " + "batch-size, the last batch will be smaller than the first, unless " + "--drop-incomplete-batch is set, in which case this batch is not used.", + required=True, + show_default=True, + default=1, + type=click.IntRange(min=1), + cls=ResourceOption, + ) + @click.option( + "--accumulate-grad-batches", + "-a", + help="Number of accumulations for backward propagation to accumulate " + "gradients over k batches before stepping the optimizer. This " + "parameter, used in conjunction with the batch-size, may be used to " + "reduce the number of samples loaded in each iteration, to affect memory " + "usage in exchange for processing time (more iterations). This is " + "useful interesting when one is training on GPUs with a limited amount " + "of onboard RAM. The default of 1 forces the whole batch to be " + "processed at once. Otherwise the batch is multiplied by " + "accumulate-grad-batches pieces, and gradients are accumulated " + "to complete each training step.", + required=True, + show_default=True, + default=1, + type=click.IntRange(min=1), + cls=ResourceOption, + ) + @click.option( + "--drop-incomplete-batch/--no-drop-incomplete-batch", + "-D", + help="If set, the last batch in an epoch will be dropped if " + "incomplete. If you set this option, you should also consider " + "increasing the total number of epochs of training, as the total number " + "of training steps may be reduced.", + required=True, + show_default=True, + default=False, + cls=ResourceOption, + ) + @click.option( + "--epochs", + "-e", + help="""Number of epochs (complete training set passes) to train for. + If continuing from a saved checkpoint, ensure to provide a greater + number of epochs than was saved in the checkpoint to be loaded.""", + show_default=True, + required=True, + default=1000, + type=click.IntRange(min=1), + cls=ResourceOption, + ) + @click.option( + "--validation-period", + "-p", + help="""Number of epochs after which validation happens. By default, + we run validation after every training epoch (period=1). You can + change this to make validation more sparse, by increasing the + validation period. Notice that this affects checkpoint saving. 
While + checkpoints are created after every training step (the last training + step always triggers the overriding of latest checkpoint), and + this process is independent of validation runs, evaluation of the + 'best' model obtained so far based on those will be influenced by this + setting.""", + show_default=True, + required=True, + default=1, + type=click.IntRange(min=1), + cls=ResourceOption, + ) + @click.option( + "--device", + "-x", + help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', + show_default=True, + required=True, + default="cpu", + cls=ResourceOption, + ) + @click.option( + "--cache-samples/--no-cache-samples", + help="If set to True, loads the sample into memory, " + "otherwise loads them at runtime.", + required=True, + show_default=True, + default=False, + cls=ResourceOption, + ) + @click.option( + "--seed", + "-s", + help="Seed to use for the random number generator", + show_default=True, + required=False, + default=42, + type=click.IntRange(min=0), + cls=ResourceOption, + ) + @click.option( + "--parallel", + "-P", + help="""Use multiprocessing for data loading: if set to -1 (default), + disables multiprocessing data loading. Set to 0 to enable as many data + loading instances as processing cores available in the system. Set to + >= 1 to enable that many multiprocessing instances for data + loading.""", + type=click.IntRange(min=-1), + show_default=True, + required=True, + default=-1, + cls=ResourceOption, + ) + @click.option( + "--monitoring-interval", + "-I", + help="""Time between checks for the use of resources during each training + epoch, in seconds. An interval of 5 seconds, for example, will lead to + CPU and GPU resources being probed every 5 seconds during each training + epoch. Values registered in the training logs correspond to averages + (or maxima) observed through possibly many probes in each epoch. + Notice that setting a very small value may cause the probing process to + become extremely busy, potentially biasing the overall perception of + resource usage.""", + type=click.FloatRange(min=0.1), + show_default=True, + required=True, + default=5.0, + cls=ResourceOption, + ) + @click.option( + "--balance-classes/--no-balance-classes", + "-B/-N", + help="""If set, balances weights of the random sampler during + training so that samples from all sample classes are picked + equitably.""", + required=True, + show_default=True, + default=True, + cls=ResourceOption, + ) + @functools.wraps(f) + def wrapper_reusable_options(*args, **kwargs): + return f(*args, **kwargs) + + return wrapper_reusable_options + + +def train( + model, + output_folder, + epochs, + batch_size, + accumulate_grad_batches, + drop_incomplete_batch, + datamodule, + validation_period, + device, + cache_samples, + seed, + parallel, + monitoring_interval, + balance_classes, + **_, +) -> None: # numpydoc ignore=PR01 + """Train an CNN to perform image classification. + + Training is performed for a configurable number of epochs, and + generates checkpoints. Checkpoints are model files with a .ckpt + extension that are used in subsequent tasks or from which training + can be resumed. 
+ """ + + import torch + from lightning.pytorch import seed_everything + from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.engine.trainer import run + from mednet.libs.common.utils.checkpointer import ( + get_checkpoint_to_resume_training, + ) + + from .utils import ( + device_properties, + execution_metadata, + model_summary, + save_json_with_backup, + ) + + checkpoint_file = None + if output_folder.is_dir(): + try: + checkpoint_file = get_checkpoint_to_resume_training(output_folder) + except FileNotFoundError: + logger.info( + f"Folder {output_folder} already exists, but I did not" + f" find any usable checkpoint file to resume training" + f" from. Starting from scratch...", + ) + + seed_everything(seed) + + # reset datamodule with user configurable options + datamodule.drop_incomplete_batch = drop_incomplete_batch + datamodule.cache_samples = cache_samples + datamodule.parallel = parallel + datamodule.model_transforms = model.model_transforms + + datamodule.prepare_data() + datamodule.setup(stage="fit") + + # If asked, rebalances the loss criterion based on the relative proportion + # of class examples available in the training set. Also affects the + # validation loss if a validation set is available on the DataModule. + if balance_classes: + logger.info("Applying train/valid loss balancing...") + model.balance_losses(datamodule) + else: + logger.info( + "Skipping sample class/dataset ownership balancing on user request", + ) + + logger.info(f"Training for at most {epochs} epochs.") + + arguments = {} + arguments["max_epoch"] = epochs + arguments["epoch"] = 0 + + if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"): + # Sets the model normalizer with the unaugmented-train-subset if we are + # starting from scratch and/or the model does not contain its own + # checkpoint loading strategy (e.g. a pytorch stock checkpoint). This + # call may be a NOOP, if the model comes from outside this framework, + # and expects different weights for the normalisation layer. + if hasattr(model, "set_normalizer"): + model.set_normalizer(datamodule.unshuffled_train_dataloader()) + else: + logger.warning( + f"Model {model.name} has no `set_normalizer` method. 
" + "Skipping normalization setup (unsupported external model).", + ) + else: + # Normalizer will be loaded during model.on_load_checkpoint + checkpoint = torch.load(checkpoint_file) + start_epoch = checkpoint["epoch"] + logger.info( + f"Resuming from epoch {start_epoch} " + f"(checkpoint file: `{str(checkpoint_file)}`)...", + ) + + device_manager = DeviceManager(device) + + # stores all information we can think of, to reproduce this later + json_data: dict[str, typing.Any] = execution_metadata() + json_data.update(device_properties(device_manager.device_type)) + json_data.update( + dict( + database_name=datamodule.database_name, + split_name=datamodule.split_name, + epochs=epochs, + batch_size=batch_size, + accumulate_grad_batches=accumulate_grad_batches, + drop_incomplete_batch=drop_incomplete_batch, + validation_period=validation_period, + cache_samples=cache_samples, + seed=seed, + parallel=parallel, + monitoring_interval=monitoring_interval, + balance_classes=balance_classes, + model_name=model.name, + ), + ) + json_data.update(model_summary(model)) + json_data = {k.replace("_", "-"): v for k, v in json_data.items()} + save_json_with_backup(output_folder / "meta.json", json_data) + + run( + model=model, + datamodule=datamodule, + validation_period=validation_period, + device_manager=device_manager, + max_epochs=epochs, + output_folder=output_folder, + monitoring_interval=monitoring_interval, + accumulate_grad_batches=accumulate_grad_batches, + checkpoint=checkpoint_file, + ) diff --git a/src/mednet/libs/classification/scripts/train_analysis.py b/src/mednet/libs/common/scripts/train_analysis.py similarity index 100% rename from src/mednet/libs/classification/scripts/train_analysis.py rename to src/mednet/libs/common/scripts/train_analysis.py diff --git a/src/mednet/libs/classification/scripts/utils.py b/src/mednet/libs/common/scripts/utils.py similarity index 100% rename from src/mednet/libs/classification/scripts/utils.py rename to src/mednet/libs/common/scripts/utils.py diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py index 0796c3d77df3d28fc491de60545a056ebde9cf9e..933231c6d6aa3a6a96088d61f42407b041ac1da8 100644 --- a/src/mednet/libs/segmentation/scripts/train.py +++ b/src/mednet/libs/segmentation/scripts/train.py @@ -1,222 +1,11 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import functools -import pathlib -import typing -from pathlib import Path - import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup +from mednet.libs.common.scripts.click import ConfigCommand +from mednet.libs.common.scripts.train import reusable_options +from mednet.libs.common.scripts.train import train as train_script -from .click import ConfigCommand - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -def reusable_options(f): - """The options that can be re-used by top-level scripts (i.e. - ``experiment``). - - This decorator equips the target function ``f`` with all (reusable) - ``train`` script options. - - Parameters - ---------- - f - The target function to equip with options. This function must have - parameters that accept such options. 
- - Returns - ------- - The decorated version of function ``f`` - """ # noqa D401 - - @click.option( - "--output-folder", - "-o", - help="Directory in which to store results (created if does not exist)", - required=True, - type=click.Path( - file_okay=False, - dir_okay=True, - writable=True, - path_type=pathlib.Path, - ), - default="results", - cls=ResourceOption, - ) - @click.option( - "--model", - "-m", - help="A lightning module instance implementing the network to be trained", - required=True, - cls=ResourceOption, - ) - @click.option( - "--datamodule", - "-d", - help="A lightning DataModule containing the training and validation sets.", - required=True, - cls=ResourceOption, - ) - @click.option( - "--batch-size", - "-b", - help="Number of samples in every batch (this parameter affects " - "memory requirements for the network). If the number of samples in " - "the batch is larger than the total number of samples available for " - "training, this value is truncated. If this number is smaller, then " - "batches of the specified size are created and fed to the network " - "until there are no more new samples to feed (epoch is finished). " - "If the total number of training samples is not a multiple of the " - "batch-size, the last batch will be smaller than the first, unless " - "--drop-incomplete-batch is set, in which case this batch is not used.", - required=True, - show_default=True, - default=1, - type=click.IntRange(min=1), - cls=ResourceOption, - ) - @click.option( - "--batch-chunk-count", - "-c", - help="Number of chunks in every batch (this parameter affects " - "memory requirements for the network). The number of samples " - "loaded for every iteration will be batch-size/batch-chunk-count. " - "batch-size needs to be divisible by batch-chunk-count, otherwise an " - "error will be raised. This parameter is used to reduce the number of " - "samples loaded in each iteration, in order to reduce the memory usage " - "in exchange for processing time (more iterations). This is especially " - "interesting when one is training on GPUs with limited RAM. The " - "default of 1 forces the whole batch to be processed at once. Otherwise " - "the batch is broken into batch-chunk-count pieces, and gradients are " - "accumulated to complete each batch.", - required=True, - show_default=True, - default=1, - type=click.IntRange(min=1), - cls=ResourceOption, - ) - @click.option( - "--drop-incomplete-batch/--no-drop-incomplete-batch", - "-D", - help="If set, the last batch in an epoch will be dropped if " - "incomplete. If you set this option, you should also consider " - "increasing the total number of epochs of training, as the total number " - "of training steps may be reduced.", - required=True, - show_default=True, - default=False, - cls=ResourceOption, - ) - @click.option( - "--epochs", - "-e", - help="""Number of epochs (complete training set passes) to train for. - If continuing from a saved checkpoint, ensure to provide a greater - number of epochs than was saved in the checkpoint to be loaded.""", - show_default=True, - required=True, - default=1000, - type=click.IntRange(min=1), - cls=ResourceOption, - ) - @click.option( - "--validation-period", - "-p", - help="""Number of epochs after which validation happens. By default, - we run validation after every training epoch (period=1). You can - change this to make validation more sparse, by increasing the - validation period. Notice that this affects checkpoint saving. 
While - checkpoints are created after every training step (the last training - step always triggers the overriding of latest checkpoint), and - this process is independent of validation runs, evaluation of the - 'best' model obtained so far based on those will be influenced by this - setting.""", - show_default=True, - required=True, - default=1, - type=click.IntRange(min=1), - cls=ResourceOption, - ) - @click.option( - "--device", - "-x", - help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', - show_default=True, - required=True, - default="cpu", - cls=ResourceOption, - ) - @click.option( - "--cache-samples/--no-cache-samples", - help="If set to True, loads the sample into memory, " - "otherwise loads them at runtime.", - required=True, - show_default=True, - default=False, - cls=ResourceOption, - ) - @click.option( - "--seed", - "-s", - help="Seed to use for the random number generator", - show_default=True, - required=False, - default=42, - type=click.IntRange(min=0), - cls=ResourceOption, - ) - @click.option( - "--parallel", - "-P", - help="""Use multiprocessing for data loading: if set to -1 (default), - disables multiprocessing data loading. Set to 0 to enable as many data - loading instances as processing cores available in the system. Set to - >= 1 to enable that many multiprocessing instances for data - loading.""", - type=click.IntRange(min=-1), - show_default=True, - required=True, - default=-1, - cls=ResourceOption, - ) - @click.option( - "--monitoring-interval", - "-I", - help="""Time between checks for the use of resources during each training - epoch, in seconds. An interval of 5 seconds, for example, will lead to - CPU and GPU resources being probed every 5 seconds during each training - epoch. Values registered in the training logs correspond to averages - (or maxima) observed through possibly many probes in each epoch. - Notice that setting a very small value may cause the probing process to - become extremely busy, potentially biasing the overall perception of - resource usage.""", - type=click.FloatRange(min=0.1), - show_default=True, - required=True, - default=5.0, - cls=ResourceOption, - ) - @click.option( - "--balance-classes/--no-balance-classes", - "-B/-N", - help="""If set, balances weights of the random sampler during - training so that samples from all sample classes are picked - equitably.""", - required=True, - show_default=True, - default=True, - cls=ResourceOption, - ) - @functools.wraps(f) - def wrapper_reusable_options(*args, **kwargs): - return f(*args, **kwargs) - - return wrapper_reusable_options +logger = setup("mednet", format="%(levelname)s: %(message)s") @click.command( @@ -228,7 +17,7 @@ def reusable_options(f): .. code:: sh - deepdraw train -vv pasa montgomery --batch-size=4 --device="cuda:0" + mednet train -vv pasa montgomery --batch-size=4 --device="cuda:0" """, ) @reusable_options @@ -250,127 +39,27 @@ def train( balance_classes, **_, ) -> None: # numpydoc ignore=PR01 - """Train an CNN to perform image classification. + """Train a model to perform image segmentation. Training is performed for a configurable number of epochs, and generates checkpoints. Checkpoints are model files with a .ckpt extension that are used in subsequent tasks or from which training can be resumed. 
""" - - import torch - from lightning.pytorch import seed_everything - from mednet.libs.common.engine.device import DeviceManager - from mednet.libs.common.engine.trainer import run - from mednet.libs.common.utils.checkpointer import ( - get_checkpoint_to_resume_training, - ) - - from .utils import ( - device_properties, - execution_metadata, - model_summary, - save_json_with_backup, - ) - - checkpoint_file = None - if Path.is_dir(output_folder): - try: - checkpoint_file = get_checkpoint_to_resume_training(output_folder) - except FileNotFoundError: - logger.info( - f"Folder {output_folder} already exists, but I did not" - f" find any usable checkpoint file to resume training" - f" from. Starting from scratch..." - ) - - seed_everything(seed) - - # reset datamodule with user configurable options - datamodule.set_chunk_size(batch_size, batch_chunk_count) - datamodule.drop_incomplete_batch = drop_incomplete_batch - datamodule.cache_samples = cache_samples - datamodule.parallel = parallel - datamodule.model_transforms = model.model_transforms - - datamodule.prepare_data() - datamodule.setup(stage="fit") - - # If asked, rebalances the loss criterion based on the relative proportion - # of class examples available in the training set. Also affects the - # validation loss if a validation set is available on the DataModule. - if balance_classes: - logger.info("Applying DataModule train sampler balancing...") - datamodule.balance_sampler_by_class = True - # logger.info("Applying train/valid loss balancing...") - # model.balance_losses_by_class(datamodule) - else: - logger.info( - "Skipping sample class/dataset ownership balancing on user request" - ) - - logger.info(f"Training for at most {epochs} epochs.") - - arguments = {} - arguments["max_epoch"] = epochs - arguments["epoch"] = 0 - - if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"): - # Sets the model normalizer with the unaugmented-train-subset if we are - # starting from scratch and/or the model does not contain its own - # checkpoint loading strategy (e.g. a pytorch stock checkpoint). This - # call may be a NOOP, if the model comes from outside this framework, - # and expects different weights for the normalisation layer. - if hasattr(model, "set_normalizer"): - model.set_normalizer(datamodule.unshuffled_train_dataloader()) - else: - logger.warning( - f"Model {model.name} has no `set_normalizer` method. " - "Skipping normalization setup (unsupported external model)." - ) - else: - # Normalizer will be loaded during model.on_load_checkpoint - checkpoint = torch.load(checkpoint_file) - start_epoch = checkpoint["epoch"] - logger.info( - f"Resuming from epoch {start_epoch} " - f"(checkpoint file: `{str(checkpoint_file)}`)..." 
- ) - - device_manager = DeviceManager(device) - - # stores all information we can think of, to reproduce this later - json_data: dict[str, typing.Any] = execution_metadata() - json_data.update(device_properties(device_manager.device_type)) - json_data.update( - dict( - database_name=datamodule.database_name, - split_name=datamodule.split_name, - epochs=epochs, - batch_size=batch_size, - batch_chunk_count=batch_chunk_count, - drop_incomplete_batch=drop_incomplete_batch, - validation_period=validation_period, - cache_samples=cache_samples, - seed=seed, - parallel=parallel, - monitoring_interval=monitoring_interval, - balance_classes=balance_classes, - model_name=model.name, - ), - ) - json_data.update(model_summary(model)) - json_data = {k.replace("_", "-"): v for k, v in json_data.items()} - save_json_with_backup(output_folder / "meta.json", json_data) - - run( - model=model, - datamodule=datamodule, - validation_period=validation_period, - device_manager=device_manager, - max_epochs=epochs, - output_folder=output_folder, - monitoring_interval=monitoring_interval, - batch_chunk_count=batch_chunk_count, - checkpoint=checkpoint_file, + train_script( + model, + output_folder, + epochs, + batch_size, + batch_chunk_count, + drop_incomplete_batch, + datamodule, + validation_period, + device, + cache_samples, + seed, + parallel, + monitoring_interval, + balance_classes, + **_, ) diff --git a/tests/test_resource_monitor.py b/tests/test_resource_monitor.py index 028932cee910d1a71188d4851b2ee84e484a70c8..ed4095902ac057ee1773ca1c5431ac75f90619d1 100644 --- a/tests/test_resource_monitor.py +++ b/tests/test_resource_monitor.py @@ -13,7 +13,7 @@ import pytest def test_cpu_constants(): - from mednet.utils.resources import cpu_constants + from mednet.libs.common.utils.resources import cpu_constants v = cpu_constants() assert "memory-total-GB/cpu" in v @@ -21,8 +21,8 @@ def test_cpu_constants(): def test_combined_monitor_cpu_only(): - from mednet.engine.device import DeviceManager - from mednet.utils.resources import _CombinedMonitor + from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.utils.resources import _CombinedMonitor monitor = _CombinedMonitor( device_type=DeviceManager("cpu").device_type, @@ -48,7 +48,7 @@ def test_combined_monitor_cpu_only(): reason="Requires macOS on Apple silicon to run", ) def test_mps_constants(): - from mednet.utils.resources import mps_constants + from mednet.libs.common.utils.resources import mps_constants v = mps_constants() assert "apple-processor-model" in v @@ -60,8 +60,8 @@ def test_mps_constants(): reason="Requires macOS on Apple silicon to run", ) def test_combined_monitor_macos_gpu(): - from mednet.engine.device import DeviceManager - from mednet.utils.resources import _CombinedMonitor + from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.utils.resources import _CombinedMonitor monitor = _CombinedMonitor( device_type=DeviceManager("mps").device_type, @@ -90,7 +90,7 @@ def test_combined_monitor_macos_gpu(): shutil.which("nvidia-smi") is None, reason="Requires nvidia-smi to run" ) def test_cuda_constants(): - from mednet.utils.resources import cuda_constants + from mednet.libs.common.utils.resources import cuda_constants v = cuda_constants() assert "driver-version/gpu" in v @@ -101,8 +101,8 @@ def test_cuda_constants(): shutil.which("nvidia-smi") is None, reason="Requires nvidia-smi to run" ) def test_combined_monitor_nvidia_gpu(): - from mednet.engine.device import DeviceManager - from 
mednet.utils.resources import _CombinedMonitor + from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.utils.resources import _CombinedMonitor monitor = _CombinedMonitor( device_type=DeviceManager("cuda").device_type, @@ -130,8 +130,8 @@ def test_combined_monitor_nvidia_gpu(): def test_aggregation(): - from mednet.engine.device import DeviceManager - from mednet.utils.resources import _CombinedMonitor, aggregate + from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.utils.resources import _CombinedMonitor, aggregate monitor = _CombinedMonitor( device_type=DeviceManager("cpu").device_type, @@ -168,8 +168,8 @@ def test_mp_cpu_monitoring(): # Checks a "normal" workflow, where the monitoring interval is smaller than # the total work time - from mednet.engine.device import DeviceManager - from mednet.utils.resources import ResourceMonitor, aggregate + from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.utils.resources import ResourceMonitor, aggregate rm = ResourceMonitor( interval=0.2, @@ -189,8 +189,8 @@ def test_mp_cpu_monitoring_short_processing(): # Checks we can get at least 1 monitoring sample even if the processing is # super short - from mednet.engine.device import DeviceManager - from mednet.utils.resources import ResourceMonitor, aggregate + from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.utils.resources import ResourceMonitor, aggregate rm = ResourceMonitor( interval=0.5, @@ -214,8 +214,8 @@ def test_mp_macos_gpu_monitoring(): # Checks a "normal" workflow, where the monitoring interval is smaller than # the total work time - from mednet.engine.device import DeviceManager - from mednet.utils.resources import ResourceMonitor, aggregate + from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.utils.resources import ResourceMonitor, aggregate rm = ResourceMonitor( interval=1.5, @@ -240,8 +240,8 @@ def test_mp_macos_gpu_monitoring_short_processing(): # shorter. In this check we execute an external utility which may # delay obtaining samples. - from mednet.engine.device import DeviceManager - from mednet.utils.resources import ResourceMonitor, aggregate + from mednet.libs.common.engine.device import DeviceManager + from mednet.libs.common.utils.resources import ResourceMonitor, aggregate rm = ResourceMonitor( interval=1.5, # on my mac, this measurements take ~0.9s
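
The recurring pattern in this patch is that each task library (classification, segmentation) keeps only a thin Click command, while the option stack (`reusable_options`) and the actual `train`/`predict` implementations move to `mednet.libs.common.scripts`. The sketch below is illustrative only and is not part of the patch: it reproduces that decorator-plus-thin-wrapper structure with plain `click` instead of clapper's `ResourceOption`/`ConfigCommand`, and the names `shared_train`, `train_command`, and the reduced option set are simplified stand-ins, not the real mednet API.

# sketch: shared module (plays the role of mednet/libs/common/scripts/train.py)
import functools

import click


def reusable_options(f):
    """Equip ``f`` with the reusable ``train`` options (reduced set for illustration)."""

    @click.option("--epochs", "-e", default=1000, show_default=True,
                  type=click.IntRange(min=1))
    @click.option("--batch-size", "-b", default=1, show_default=True,
                  type=click.IntRange(min=1))
    @functools.wraps(f)
    def wrapper_reusable_options(*args, **kwargs):
        # the wrapper only forwards the call; the click options attach
        # themselves to this function and are later collected by @click.command
        return f(*args, **kwargs)

    return wrapper_reusable_options


def shared_train(epochs, batch_size, **_):
    """Shared implementation, called by every library-specific command."""
    print(f"training for {epochs} epochs with batch size {batch_size}")


# sketch: library-specific module (plays the role of
# mednet/libs/classification/scripts/train.py after this patch)
@click.command()
@reusable_options
def train_command(epochs, batch_size, **_):
    # thin wrapper: all behaviour lives in the shared implementation
    shared_train(epochs, batch_size)


if __name__ == "__main__":
    train_command()

Under this layout each library can still customize its entry point, help epilog, and config entry-point group, while option definitions and the training/prediction logic are maintained once under `libs.common`, which is why the per-library `train.py` and `predict.py` files in the diff shrink to a decorator application plus a forwarding call.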