Skip to content
Snippets Groups Projects
Commit d354514c authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[predict] Split predict script into functions, fix segmentation

parent 4d303d1e
No related branches found
No related tags found
1 merge request: !46 "Create common library"
......@@ -7,7 +7,6 @@ import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
from mednet.libs.common.scripts.click import ConfigCommand
from mednet.libs.common.scripts.predict import predict as predict_script
from mednet.libs.common.scripts.predict import reusable_options
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
......@@ -45,13 +44,34 @@ def predict(
**_,
) -> None: # numpydoc ignore=PR01
"""Run inference (generates scores) on all input images, using a pre-trained model."""
predict_script(
output,
model,
datamodule,
batch_size,
device,
weight,
parallel,
**_,
import json
import shutil
from mednet.libs.classification.engine.predictor import run
from mednet.libs.common.engine.device import DeviceManager
from mednet.libs.common.scripts.predict import (
load_checkpoint,
save_json_data,
setup_datamodule,
)
setup_datamodule(datamodule, model, batch_size, parallel)
model = load_checkpoint(model, weight)
device_manager = DeviceManager(device)
save_json_data(datamodule, model, output, device_manager)
predictions = run(model, datamodule, device_manager)
output.parent.mkdir(parents=True, exist_ok=True)
if output.exists():
backup = output.parent / (output.name + "~")
logger.warning(
f"Output predictions file `{str(output)}` exists - "
f"backing it up to `{str(backup)}`...",
)
shutil.copy(output, backup)
with output.open("w") as f:
json.dump(predictions, f, indent=2)
logger.info(f"Predictions saved to `{str(output)}`")
......@@ -4,10 +4,12 @@
import functools
import pathlib
import typing
import click
from clapper.click import ResourceOption
from clapper.logging import setup
from mednet.libs.common.models.model import Model
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
......@@ -120,34 +122,13 @@ def reusable_options(f):
return wrapper_reusable_options
def predict(
output,
model,
def setup_datamodule(
datamodule,
model,
batch_size,
device,
weight,
parallel,
**_,
) -> None: # numpydoc ignore=PR01
"""Run inference (generates scores) on all input images, using a pre-trained model."""
import json
import shutil
import typing
from mednet.libs.classification.engine.predictor import run
from mednet.libs.common.engine.device import DeviceManager
from mednet.libs.common.utils.checkpointer import (
get_checkpoint_to_run_inference,
)
from .utils import (
device_properties,
execution_metadata,
model_summary,
save_json_with_backup,
)
"""Configure and set up the datamodule."""
datamodule.set_chunk_size(batch_size, 1)
datamodule.parallel = parallel
datamodule.model_transforms = model.model_transforms
......@@ -155,15 +136,48 @@ def predict(
datamodule.prepare_data()
datamodule.setup(stage="predict")
def load_checkpoint(model: Model, weight: pathlib.Path) -> Model:
    """Load a model checkpoint for prediction.

    Parameters
    ----------
    model
        Instance of a model.
    weight
        Path to a checkpoint file, or the base directory containing either
        the "best", "last" or "periodic" checkpoint to run inference from.

    Returns
    -------
    An instance of the model loaded from the checkpoint.
    """
    from mednet.libs.common.utils.checkpointer import (
        get_checkpoint_to_run_inference,
    )

    # A directory means the caller pointed at a training output folder:
    # resolve it to the concrete checkpoint file to run inference from.
    if weight.is_dir():
        weight = get_checkpoint_to_run_inference(weight)

    logger.info(f"Loading checkpoint from `{weight}`...")
    # NOTE: previously the checkpoint was loaded twice (an assignment
    # immediately followed by an equivalent return); load exactly once.
    return type(model).load_from_checkpoint(weight, strict=False)
device_manager = DeviceManager(device)
# register metadata
def save_json_data(
datamodule,
model,
output,
device_manager,
) -> None: # numpydoc ignore=PR01
"""Save prediction hyperparameters into a .json file."""
from .utils import (
device_properties,
execution_metadata,
model_summary,
save_json_with_backup,
)
json_data: dict[str, typing.Any] = execution_metadata()
json_data.update(device_properties(device_manager.device_type))
json_data.update(
......@@ -176,18 +190,3 @@ def predict(
json_data.update(model_summary(model))
json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
save_json_with_backup(output.with_suffix(".meta.json"), json_data)
predictions = run(model, datamodule, device_manager)
output.parent.mkdir(parents=True, exist_ok=True)
if output.exists():
backup = output.parent / (output.name + "~")
logger.warning(
f"Output predictions file `{str(output)}` exists - "
f"backing it up to `{str(backup)}`...",
)
shutil.copy(output, backup)
with output.open("w") as f:
json.dump(predictions, f, indent=2)
logger.info(f"Predictions saved to `{str(output)}`")
......@@ -319,8 +319,8 @@ class LittleWNet(Model):
return self._validation_loss(outputs, ground_truths, masks)
def predict_step(self, batch, batch_idx, dataloader_idx=0):
outputs = self(batch[0])
probabilities = torch.sigmoid(outputs)
output = self(batch[0])[1]
probabilities = torch.sigmoid(output)
return separate((probabilities, batch[1]))
def configure_optimizers(self):
......
......@@ -27,7 +27,7 @@ def _as_predictions(
list[BinaryPrediction | MultiClassPrediction]
A list of typed predictions that can be saved to disk.
"""
return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples]
return [(v[1]["name"], v[1]["target"], v[0].item()) for v in samples]
def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
......
......@@ -12,10 +12,10 @@ from . import (
# compare,
config,
database,
predict,
# evaluate,
# experiment,
# mkmask,
# predict,
# significance,
train,
)
......@@ -37,12 +37,15 @@ segmentation.add_command(database.database)
# segmentation.add_command(evaluate.evaluate)
# segmentation.add_command(experiment.experiment)
# segmentation.add_command(mkmask.mkmask)
# segmentation.add_command(predict.predict)
# segmentation.add_command(significance.significance)
segmentation.add_command(train.train)
segmentation.add_command(predict.predict)
segmentation.add_command(
importlib.import_module(
"mednet.libs.common.scripts.train_analysis",
package=__name__,
).train_analysis,
)
segmentation.add_command(
importlib.import_module("..experiment", package=__name__).experiment,
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
from datetime import datetime
import click
from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup
from .train import reusable_options as training_options
# avoids X11/graphical desktop requirement when creating plots
__import__("matplotlib").use("agg")
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@click.command(
    entry_point_group="mednet.libs.segmentation.config",
    cls=ConfigCommand,
    epilog="""Examples:
\b
1. Train a pasa model with montgomery dataset, on the CPU, for only two
epochs, then runs inference and evaluation on stock datasets, report
performance as a table and figures:
.. code:: sh
$ mednet experiment -vv pasa montgomery --epochs=2
""",
)
@training_options
@verbosity_option(logger=logger, cls=ResourceOption)
@click.pass_context
def experiment(
    ctx,
    model,
    output_folder,
    epochs,
    batch_size,
    batch_chunk_count,
    drop_incomplete_batch,
    datamodule,
    validation_period,
    device,
    cache_samples,
    seed,
    parallel,
    monitoring_interval,
    **_,
):  # numpydoc ignore=PR01
    r"""Run a complete experiment, from training, to prediction and evaluation.

    This script is just a wrapper around the individual scripts for training,
    running prediction, and evaluating.  It organises the output in a preset way::

    \b
       └─ <output-folder>/
          ├── command.sh
          ├── model/  # the generated model will be here
          ├── predictions.json  # the prediction outputs for the sets
          └── evaluation/  # the outputs of the evaluations for the sets
    """
    experiment_start_timestamp = datetime.now()

    # --- Stage 1: training ---------------------------------------------
    train_start_timestamp = datetime.now()
    logger.info(f"Started training at {train_start_timestamp}")

    # Imported lazily so that the heavy training stack is only loaded when
    # this command actually runs.
    from .train import train

    train_output_folder = output_folder / "model"
    # Forward all training-related options verbatim to the `train` command.
    ctx.invoke(
        train,
        model=model,
        output_folder=train_output_folder,
        epochs=epochs,
        batch_size=batch_size,
        batch_chunk_count=batch_chunk_count,
        drop_incomplete_batch=drop_incomplete_batch,
        datamodule=datamodule,
        validation_period=validation_period,
        device=device,
        cache_samples=cache_samples,
        seed=seed,
        parallel=parallel,
        monitoring_interval=monitoring_interval,
    )
    train_stop_timestamp = datetime.now()
    logger.info(f"Ended training in {train_stop_timestamp}")
    logger.info(
        f"Training runtime: {train_stop_timestamp-train_start_timestamp}"
    )

    # --- Stage 2: training-log analysis --------------------------------
    logger.info("Started train analysis")
    from mednet.libs.common.scripts.train_analysis import train_analysis

    logdir = train_output_folder / "logs"
    ctx.invoke(
        train_analysis,
        logdir=logdir,
        output_folder=train_output_folder,
    )
    logger.info("Ended train analysis")

    # --- Stage 3: prediction -------------------------------------------
    predict_start_timestamp = datetime.now()
    logger.info(f"Started prediction at {predict_start_timestamp}")

    from .predict import predict

    predictions_output = output_folder / "predictions.json"
    # `weight` points at the training output folder; the predict command
    # resolves it to a concrete checkpoint internally.
    ctx.invoke(
        predict,
        output=predictions_output,
        model=model,
        datamodule=datamodule,
        device=device,
        weight=train_output_folder,
        batch_size=batch_size,
        parallel=parallel,
    )
    predict_stop_timestamp = datetime.now()
    logger.info(f"Ended prediction in {predict_stop_timestamp}")
    logger.info(
        f"Prediction runtime: {predict_stop_timestamp-predict_start_timestamp}"
    )

    # --- Stage 4: evaluation (currently disabled) ----------------------
    # Kept as an inert string literal until the segmentation `evaluate`
    # command is available; do not delete.
    """evaluation_start_timestamp = datetime.now()
    logger.info(f"Started evaluation at {evaluation_start_timestamp}")
    from .evaluate import evaluate
    ctx.invoke(
        evaluate,
        predictions=predictions_output,
        output_folder=output_folder,
        threshold="validation",
    )
    evaluation_stop_timestamp = datetime.now()
    logger.info(f"Ended prediction in {evaluation_stop_timestamp}")
    logger.info(
        f"Prediction runtime: {evaluation_stop_timestamp-evaluation_start_timestamp}"
    )"""

    experiment_stop_timestamp = datetime.now()
    logger.info(
        f"Total experiment runtime: {experiment_stop_timestamp-experiment_start_timestamp}"
    )
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
from mednet.libs.common.scripts.click import ConfigCommand
from mednet.libs.common.scripts.predict import reusable_options
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@click.command(
    entry_point_group="mednet.libs.segmentation.config",
    cls=ConfigCommand,
    epilog="""Examples:
1. Run prediction on an existing DataModule configuration:
.. code:: sh
mednet predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json
2. Enable multi-processing data loading with 6 processes:
.. code:: sh
mednet predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json
""",
)
@reusable_options
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def predict(
    output,
    model,
    datamodule,
    batch_size,
    device,
    weight,
    parallel,
    **_,
) -> None:  # numpydoc ignore=PR01
    """Run inference (generates scores) on all input images, using a pre-trained model."""
    import json
    import shutil

    from mednet.libs.classification.engine.predictor import run
    from mednet.libs.common.engine.device import DeviceManager
    from mednet.libs.common.scripts.predict import (
        load_checkpoint,
        save_json_data,
        setup_datamodule,
    )

    # Prepare datamodule, checkpointed model and execution device, then
    # record the run's hyperparameters/metadata next to the output file.
    setup_datamodule(datamodule, model, batch_size, parallel)
    model = load_checkpoint(model, weight)
    device_manager = DeviceManager(device)
    save_json_data(datamodule, model, output, device_manager)

    predictions = run(model, datamodule, device_manager)

    output.parent.mkdir(parents=True, exist_ok=True)
    # Never silently clobber previous predictions: keep a `~` backup copy.
    if output.exists():
        backup = output.parent / (output.name + "~")
        logger.warning(
            f"Output predictions file `{str(output)}` exists - "
            f"backing it up to `{str(backup)}`...",
        )
        shutil.copy(output, backup)

    # Remove targets from predictions, as they are images and not serializable
    # A better solution should be found
    # (keep only each sample's name `s[0]` and score `s[2]`).
    serializable_predictions = {
        split_name: [(s[0], s[2]) for s in samples]
        for split_name, samples in predictions.items()
    }

    with output.open("w") as f:
        json.dump(serializable_predictions, f, indent=2)
    logger.info(f"Predictions saved to `{str(output)}`")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment