From d354514cfb59f811534ce92a56b8392a737fc135 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 14 May 2024 11:14:27 +0200
Subject: [PATCH] [predict] Split predict script into functions, fix
 segmentation

---
 .../libs/classification/scripts/predict.py   |  40 +++--
 src/mednet/libs/common/scripts/predict.py    |  83 +++++-----
 src/mednet/libs/segmentation/models/lwnet.py |   4 +-
 .../libs/segmentation/models/separate.py     |   2 +-
 src/mednet/libs/segmentation/scripts/cli.py  |   7 +-
 .../libs/segmentation/scripts/experiment.py  | 155 ++++++++++++++++++
 .../libs/segmentation/scripts/predict.py     |  86 ++++++++++
 7 files changed, 320 insertions(+), 57 deletions(-)
 create mode 100644 src/mednet/libs/segmentation/scripts/experiment.py
 create mode 100644 src/mednet/libs/segmentation/scripts/predict.py

diff --git a/src/mednet/libs/classification/scripts/predict.py b/src/mednet/libs/classification/scripts/predict.py
index 50d7d03d..04e9908d 100644
--- a/src/mednet/libs/classification/scripts/predict.py
+++ b/src/mednet/libs/classification/scripts/predict.py
@@ -7,7 +7,6 @@ import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 from mednet.libs.common.scripts.click import ConfigCommand
-from mednet.libs.common.scripts.predict import predict as predict_script
 from mednet.libs.common.scripts.predict import reusable_options
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -45,13 +44,34 @@ def predict(
     **_,
 ) -> None:  # numpydoc ignore=PR01
     """Run inference (generates scores) on all input images, using a pre-trained model."""
-    predict_script(
-        output,
-        model,
-        datamodule,
-        batch_size,
-        device,
-        weight,
-        parallel,
-        **_,
+
+    import json
+    import shutil
+
+    from mednet.libs.classification.engine.predictor import run
+    from mednet.libs.common.engine.device import DeviceManager
+    from mednet.libs.common.scripts.predict import (
+        load_checkpoint,
+        save_json_data,
+        setup_datamodule,
     )
+
+    setup_datamodule(datamodule, model, batch_size, parallel)
+    model = load_checkpoint(model, weight)
+    device_manager = DeviceManager(device)
+    save_json_data(datamodule, model, output, device_manager)
+
+    predictions = run(model, datamodule, device_manager)
+
+    output.parent.mkdir(parents=True, exist_ok=True)
+    if output.exists():
+        backup = output.parent / (output.name + "~")
+        logger.warning(
+            f"Output predictions file `{str(output)}` exists - "
+            f"backing it up to `{str(backup)}`...",
+        )
+        shutil.copy(output, backup)
+
+    with output.open("w") as f:
+        json.dump(predictions, f, indent=2)
+    logger.info(f"Predictions saved to `{str(output)}`")
diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py
index 11cd6336..2cde0d8d 100644
--- a/src/mednet/libs/common/scripts/predict.py
+++ b/src/mednet/libs/common/scripts/predict.py
@@ -4,10 +4,12 @@
 
 import functools
 import pathlib
+import typing
 
 import click
 from clapper.click import ResourceOption
 from clapper.logging import setup
+from mednet.libs.common.models.model import Model
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -120,34 +122,13 @@ def reusable_options(f):
     return wrapper_reusable_options
 
 
-def predict(
-    output,
-    model,
+def setup_datamodule(
     datamodule,
+    model,
     batch_size,
-    device,
-    weight,
     parallel,
-    **_,
 ) -> None:  # numpydoc ignore=PR01
-    """Run inference (generates scores) on all input images, using a pre-trained model."""
-    import json
-    import shutil
-    import typing
-
-    from mednet.libs.classification.engine.predictor import run
-    from mednet.libs.common.engine.device import DeviceManager
-    from mednet.libs.common.utils.checkpointer import (
-        get_checkpoint_to_run_inference,
-    )
-
-    from .utils import (
-        device_properties,
-        execution_metadata,
-        model_summary,
-        save_json_with_backup,
-    )
-
+    """Configure and set up the datamodule."""
     datamodule.set_chunk_size(batch_size, 1)
     datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
@@ -155,15 +136,48 @@ def predict(
     datamodule.prepare_data()
     datamodule.setup(stage="predict")
 
+
+def load_checkpoint(model: Model, weight: pathlib.Path) -> Model:
+    """Load a model checkpoint for prediction.
+
+    Parameters
+    ----------
+    model
+        Instance of a model.
+    weight
+        A path to a checkpoint file, or a base directory containing either
+        the "best", "last" or "periodic" checkpoint to run inference with.
+
+    Returns
+    -------
+    An instance of the model loaded from the checkpoint.
+    """
+    from mednet.libs.common.utils.checkpointer import (
+        get_checkpoint_to_run_inference,
+    )
+
     if weight.is_dir():
         weight = get_checkpoint_to_run_inference(weight)
 
     logger.info(f"Loading checkpoint from `{weight}`...")
-    model = type(model).load_from_checkpoint(weight, strict=False)
+    return type(model).load_from_checkpoint(weight, strict=False)
 
-    device_manager = DeviceManager(device)
 
-    # register metadata
+def save_json_data(
+    datamodule,
+    model,
+    output,
+    device_manager,
+) -> None:  # numpydoc ignore=PR01
+    """Save prediction-run metadata into a .json file."""
+
+    from .utils import (
+        device_properties,
+        execution_metadata,
+        model_summary,
+        save_json_with_backup,
+    )
+
     json_data: dict[str, typing.Any] = execution_metadata()
     json_data.update(device_properties(device_manager.device_type))
     json_data.update(
@@ -176,18 +190,3 @@ def predict(
     json_data.update(model_summary(model))
     json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
     save_json_with_backup(output.with_suffix(".meta.json"), json_data)
-
-    predictions = run(model, datamodule, device_manager)
-
-    output.parent.mkdir(parents=True, exist_ok=True)
-    if output.exists():
-        backup = output.parent / (output.name + "~")
-        logger.warning(
-            f"Output predictions file `{str(output)}` exists - "
-            f"backing it up to `{str(backup)}`...",
-        )
-        shutil.copy(output, backup)
-
-    with output.open("w") as f:
-        json.dump(predictions, f, indent=2)
-    logger.info(f"Predictions saved to `{str(output)}`")
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index c3190f0b..593c2f96 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -319,8 +319,8 @@ class LittleWNet(Model):
         return self._validation_loss(outputs, ground_truths, masks)
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        probabilities = torch.sigmoid(outputs)
+        output = self(batch[0])[1]
+        probabilities = torch.sigmoid(output)
         return separate((probabilities, batch[1]))
 
     def configure_optimizers(self):
diff --git a/src/mednet/libs/segmentation/models/separate.py b/src/mednet/libs/segmentation/models/separate.py
index 8abd2329..7f520a4a 100644
--- a/src/mednet/libs/segmentation/models/separate.py
+++ b/src/mednet/libs/segmentation/models/separate.py
@@ -27,7 +27,7 @@ def _as_predictions(
     list[BinaryPrediction | MultiClassPrediction]
         A list of typed predictions that can be saved to disk.
""" - return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples] + return [(v[1]["name"], v[1]["target"], v[0].item()) for v in samples] def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]: diff --git a/src/mednet/libs/segmentation/scripts/cli.py b/src/mednet/libs/segmentation/scripts/cli.py index 52ce017f..49ba9821 100644 --- a/src/mednet/libs/segmentation/scripts/cli.py +++ b/src/mednet/libs/segmentation/scripts/cli.py @@ -12,10 +12,10 @@ from . import ( # compare, config, database, + predict, # evaluate, # experiment, # mkmask, - # predict, # significance, train, ) @@ -37,12 +37,15 @@ segmentation.add_command(database.database) # segmentation.add_command(evaluate.evaluate) # segmentation.add_command(experiment.experiment) # segmentation.add_command(mkmask.mkmask) -# segmentation.add_command(predict.predict) # segmentation.add_command(significance.significance) segmentation.add_command(train.train) +segmentation.add_command(predict.predict) segmentation.add_command( importlib.import_module( "mednet.libs.common.scripts.train_analysis", package=__name__, ).train_analysis, ) +segmentation.add_command( + importlib.import_module("..experiment", package=__name__).experiment, +) diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py new file mode 100644 index 00000000..0760adde --- /dev/null +++ b/src/mednet/libs/segmentation/scripts/experiment.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from datetime import datetime + +import click +from clapper.click import ConfigCommand, ResourceOption, verbosity_option +from clapper.logging import setup + +from .train import reusable_options as training_options + +# avoids X11/graphical desktop requirement when creating plots +__import__("matplotlib").use("agg") + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +@click.command( + entry_point_group="mednet.libs.segmentation.config", + cls=ConfigCommand, + epilog="""Examples: + +\b + 1. Train a pasa model with montgomery dataset, on the CPU, for only two + epochs, then runs inference and evaluation on stock datasets, report + performance as a table and figures: + + .. code:: sh + + $ mednet experiment -vv pasa montgomery --epochs=2 +""", +) +@training_options +@verbosity_option(logger=logger, cls=ResourceOption) +@click.pass_context +def experiment( + ctx, + model, + output_folder, + epochs, + batch_size, + batch_chunk_count, + drop_incomplete_batch, + datamodule, + validation_period, + device, + cache_samples, + seed, + parallel, + monitoring_interval, + **_, +): # numpydoc ignore=PR01 + r"""Run a complete experiment, from training, to prediction and evaluation. + + This script is just a wrapper around the individual scripts for training, + running prediction, and evaluating. 
+    It organises the output in a preset way::
+
+    \b
+    └─ <output-folder>/
+       ├── command.sh
+       ├── model/  # the generated model will be here
+       ├── predictions.json  # the prediction outputs for the sets
+       └── evaluation/  # the outputs of the evaluations for the sets
+    """
+
+    experiment_start_timestamp = datetime.now()
+
+    train_start_timestamp = datetime.now()
+    logger.info(f"Started training at {train_start_timestamp}")
+
+    from .train import train
+
+    train_output_folder = output_folder / "model"
+    ctx.invoke(
+        train,
+        model=model,
+        output_folder=train_output_folder,
+        epochs=epochs,
+        batch_size=batch_size,
+        batch_chunk_count=batch_chunk_count,
+        drop_incomplete_batch=drop_incomplete_batch,
+        datamodule=datamodule,
+        validation_period=validation_period,
+        device=device,
+        cache_samples=cache_samples,
+        seed=seed,
+        parallel=parallel,
+        monitoring_interval=monitoring_interval,
+    )
+    train_stop_timestamp = datetime.now()
+
+    logger.info(f"Ended training at {train_stop_timestamp}")
+    logger.info(
+        f"Training runtime: {train_stop_timestamp-train_start_timestamp}"
+    )
+
+    logger.info("Started train analysis")
+    from mednet.libs.common.scripts.train_analysis import train_analysis
+
+    logdir = train_output_folder / "logs"
+    ctx.invoke(
+        train_analysis,
+        logdir=logdir,
+        output_folder=train_output_folder,
+    )
+
+    logger.info("Ended train analysis")
+
+    predict_start_timestamp = datetime.now()
+    logger.info(f"Started prediction at {predict_start_timestamp}")
+
+    from .predict import predict
+
+    predictions_output = output_folder / "predictions.json"
+
+    ctx.invoke(
+        predict,
+        output=predictions_output,
+        model=model,
+        datamodule=datamodule,
+        device=device,
+        weight=train_output_folder,
+        batch_size=batch_size,
+        parallel=parallel,
+    )
+
+    predict_stop_timestamp = datetime.now()
+    logger.info(f"Ended prediction at {predict_stop_timestamp}")
+    logger.info(
+        f"Prediction runtime: {predict_stop_timestamp-predict_start_timestamp}"
+    )
+
+    """evaluation_start_timestamp = datetime.now()
+    logger.info(f"Started evaluation at {evaluation_start_timestamp}")
+
+    from .evaluate import evaluate
+
+    ctx.invoke(
+        evaluate,
+        predictions=predictions_output,
+        output_folder=output_folder,
+        threshold="validation",
+    )
+
+    evaluation_stop_timestamp = datetime.now()
+    logger.info(f"Ended evaluation at {evaluation_stop_timestamp}")
+    logger.info(
+        f"Evaluation runtime: {evaluation_stop_timestamp-evaluation_start_timestamp}"
+    )"""
+
+    experiment_stop_timestamp = datetime.now()
+    logger.info(
+        f"Total experiment runtime: {experiment_stop_timestamp-experiment_start_timestamp}"
+    )
diff --git a/src/mednet/libs/segmentation/scripts/predict.py b/src/mednet/libs/segmentation/scripts/predict.py
new file mode 100644
index 00000000..9e0b44b0
--- /dev/null
+++ b/src/mednet/libs/segmentation/scripts/predict.py
@@ -0,0 +1,86 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+
+import click
+from clapper.click import ResourceOption, verbosity_option
+from clapper.logging import setup
+from mednet.libs.common.scripts.click import ConfigCommand
+from mednet.libs.common.scripts.predict import reusable_options
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+@click.command(
+    entry_point_group="mednet.libs.segmentation.config",
+    cls=ConfigCommand,
+    epilog="""Examples:
+
+1. Run prediction on an existing DataModule configuration:
+
+   .. code:: sh
+
+      mednet predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json
+
+2. Enable multi-processing data loading with 6 processes:
+
+   .. code:: sh
+
+      mednet predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json
+
+""",
+)
+@reusable_options
+@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
+def predict(
+    output,
+    model,
+    datamodule,
+    batch_size,
+    device,
+    weight,
+    parallel,
+    **_,
+) -> None:  # numpydoc ignore=PR01
+    """Run inference (generates scores) on all input images, using a pre-trained model."""
+
+    import json
+    import shutil
+
+    from mednet.libs.classification.engine.predictor import run
+    from mednet.libs.common.engine.device import DeviceManager
+    from mednet.libs.common.scripts.predict import (
+        load_checkpoint,
+        save_json_data,
+        setup_datamodule,
+    )
+
+    setup_datamodule(datamodule, model, batch_size, parallel)
+    model = load_checkpoint(model, weight)
+    device_manager = DeviceManager(device)
+    save_json_data(datamodule, model, output, device_manager)
+
+    predictions = run(model, datamodule, device_manager)
+
+    output.parent.mkdir(parents=True, exist_ok=True)
+    if output.exists():
+        backup = output.parent / (output.name + "~")
+        logger.warning(
+            f"Output predictions file `{str(output)}` exists - "
+            f"backing it up to `{str(backup)}`...",
+        )
+        shutil.copy(output, backup)
+
+    # Strip targets from the predictions: segmentation targets are images,
+    # which are not JSON-serializable.  TODO: find a better solution.
+    serializable_predictions = {}
+    for split_name, samples in predictions.items():
+        split_predictions = []
+        for s in samples:
+            split_predictions.append((s[0], s[2]))
+        serializable_predictions[split_name] = split_predictions
+
+    with output.open("w") as f:
+        json.dump(serializable_predictions, f, indent=2)
+    logger.info(f"Predictions saved to `{str(output)}`")
-- 
GitLab
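
P.S. A minimal sketch of how the helpers split out by this patch compose outside the click commands, following the classification variant of the flow. This is a sketch only: it assumes `model` and `datamodule` are already instantiated from mednet configuration entry points, and the `results/` paths and option values below are illustrative, not part of the patch:

.. code:: python

   import json
   import pathlib

   from mednet.libs.classification.engine.predictor import run
   from mednet.libs.common.engine.device import DeviceManager
   from mednet.libs.common.scripts.predict import (
       load_checkpoint,
       save_json_data,
       setup_datamodule,
   )

   output = pathlib.Path("results/predictions.json")  # illustrative path

   # mirrors the click options of the same names on the predict command
   setup_datamodule(datamodule, model, batch_size=1, parallel=-1)

   # `weight` may point at a checkpoint file, or at a directory in which
   # the best/last/periodic checkpoint is searched for
   model = load_checkpoint(model, pathlib.Path("results/model"))

   device_manager = DeviceManager("cpu")
   save_json_data(datamodule, model, output, device_manager)  # writes predictions.meta.json

   predictions = run(model, datamodule, device_manager)
   output.parent.mkdir(parents=True, exist_ok=True)
   with output.open("w") as f:
       json.dump(predictions, f, indent=2)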