Skip to content
Snippets Groups Projects
Commit c2e0c01c authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[predict] Split predict into functions and implement in segmentation

parent bfe67849
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -111,11 +111,11 @@ def experiment( ...@@ -111,11 +111,11 @@ def experiment(
from .predict import predict from .predict import predict
predictions_output = output_folder / "predictions.json" predictions_output = output_folder / "predictions"
ctx.invoke( ctx.invoke(
predict, predict,
output=predictions_output, output_folder=predictions_output,
model=model, model=model,
datamodule=datamodule, datamodule=datamodule,
device=device, device=device,
...@@ -135,7 +135,7 @@ def experiment( ...@@ -135,7 +135,7 @@ def experiment(
ctx.invoke( ctx.invoke(
evaluate, evaluate,
predictions=predictions_output, predictions=predictions_output / "predictions.json",
output_folder=output_folder, output_folder=output_folder,
threshold="validation", threshold="validation",
) )
......
...@@ -34,7 +34,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -34,7 +34,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@reusable_options @reusable_options
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def predict( def predict(
output, output_folder,
model, model,
datamodule, datamodule,
batch_size, batch_size,
...@@ -56,22 +56,24 @@ def predict( ...@@ -56,22 +56,24 @@ def predict(
setup_datamodule, setup_datamodule,
) )
predictions_file = output_folder / "predictions.json"
predictions_file.parent.mkdir(parents=True, exist_ok=True)
setup_datamodule(datamodule, model, batch_size, parallel) setup_datamodule(datamodule, model, batch_size, parallel)
model = load_checkpoint(model, weight) model = load_checkpoint(model, weight)
device_manager = DeviceManager(device) device_manager = DeviceManager(device)
save_json_data(datamodule, model, output, device_manager) save_json_data(datamodule, model, predictions_file, device_manager)
predictions = run(model, datamodule, device_manager) predictions = run(model, datamodule, device_manager)
output.parent.mkdir(parents=True, exist_ok=True) if predictions_file.exists():
if output.exists(): backup = predictions_file.parent / (predictions_file.name + "~")
backup = output.parent / (output.name + "~")
logger.warning( logger.warning(
f"Output predictions file `{str(output)}` exists - " f"Output predictions file `{str(predictions_file)}` exists - "
f"backing it up to `{str(backup)}`...", f"backing it up to `{str(backup)}`...",
) )
shutil.copy(output, backup) shutil.copy(predictions_file, backup)
with output.open("w") as f: with predictions_file.open("w") as f:
json.dump(predictions, f, indent=2) json.dump(predictions, f, indent=2)
logger.info(f"Predictions saved to `{str(output)}`") logger.info(f"Predictions saved to `{str(predictions_file)}`")
...@@ -329,7 +329,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir): ...@@ -329,7 +329,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir):
runner = CliRunner() runner = CliRunner()
with stdout_logging() as buf: with stdout_logging() as buf:
output = temporary_basedir / "predictions.json" output = temporary_basedir / "predictions"
last = _get_checkpoint_from_alias( last = _get_checkpoint_from_alias(
temporary_basedir / "results", temporary_basedir / "results",
"periodic", "periodic",
...@@ -343,7 +343,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir): ...@@ -343,7 +343,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir):
"-vv", "-vv",
"--batch-size=1", "--batch-size=1",
f"--weight={str(last)}", f"--weight={str(last)}",
f"--output={str(output)}", f"--output-folder={str(output)}",
], ],
) )
_assert_exit_0(result) _assert_exit_0(result)
...@@ -379,7 +379,8 @@ def test_evaluate_pasa_montgomery(temporary_basedir): ...@@ -379,7 +379,8 @@ def test_evaluate_pasa_montgomery(temporary_basedir):
runner = CliRunner() runner = CliRunner()
with stdout_logging() as buf: with stdout_logging() as buf:
prediction_path = temporary_basedir / "predictions.json" prediction_path = temporary_basedir / "predictions"
predictions_file = prediction_path / "predictions.json"
evaluation_filename = "evaluation.json" evaluation_filename = "evaluation.json"
evaluation_file = temporary_basedir / evaluation_filename evaluation_file = temporary_basedir / evaluation_filename
result = runner.invoke( result = runner.invoke(
...@@ -387,7 +388,7 @@ def test_evaluate_pasa_montgomery(temporary_basedir): ...@@ -387,7 +388,7 @@ def test_evaluate_pasa_montgomery(temporary_basedir):
[ [
"-vv", "-vv",
"montgomery", "montgomery",
f"--predictions={str(prediction_path)}", f"--predictions={predictions_file}",
f"--output-folder={str(temporary_basedir)}", f"--output-folder={str(temporary_basedir)}",
"--threshold=test", "--threshold=test",
], ],
...@@ -440,9 +441,11 @@ def test_experiment(temporary_basedir): ...@@ -440,9 +441,11 @@ def test_experiment(temporary_basedir):
_assert_exit_0(result) _assert_exit_0(result)
assert (output_folder / "model" / "meta.json").exists() assert (output_folder / "model" / "meta.json").exists()
assert (output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt").exists() assert (
assert (output_folder / "predictions.json").exists() output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt"
assert (output_folder / "predictions.meta.json").exists() ).exists()
assert (output_folder / "predictions" / "predictions.json").exists()
assert (output_folder / "predictions" / "predictions.meta.json").exists()
# Need to glob because we cannot be sure of the checkpoint with lowest validation loss # Need to glob because we cannot be sure of the checkpoint with lowest validation loss
assert ( assert (
......
...@@ -32,20 +32,18 @@ def reusable_options(f): ...@@ -32,20 +32,18 @@ def reusable_options(f):
""" """
@click.option( @click.option(
"--output", "--output-folder",
"-o", "-o",
help="""Path to a JSON file in which to save predictions for all samples in the help="Directory in which to save predictions (created if does not exist)",
input DataModule (leading directories are created if they do not
exist).""",
required=True, required=True,
default="predictions.json",
cls=ResourceOption,
type=click.Path( type=click.Path(
file_okay=True, file_okay=False,
dir_okay=False, dir_okay=True,
writable=True, writable=True,
path_type=pathlib.Path, path_type=pathlib.Path,
), ),
default="predictions",
cls=ResourceOption,
) )
@click.option( @click.option(
"--model", "--model",
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
import typing import typing
import torch
from mednet.libs.common.data.typing import Sample from mednet.libs.common.data.typing import Sample
from .typing import BinaryPrediction, MultiClassPrediction from .typing import BinaryPrediction, MultiClassPrediction
...@@ -27,7 +26,7 @@ def _as_predictions( ...@@ -27,7 +26,7 @@ def _as_predictions(
list[BinaryPrediction | MultiClassPrediction] list[BinaryPrediction | MultiClassPrediction]
A list of typed predictions that can be saved to disk. A list of typed predictions that can be saved to disk.
""" """
return [(v[1]["name"], v[1]["target"], v[0].item()) for v in samples] return [(v[1]["name"], v[1]["target"], v[0]) for v in samples]
def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]: def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
...@@ -58,4 +57,5 @@ def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]: ...@@ -58,4 +57,5 @@ def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
{key: value[i] for key, value in batch[1].items()} {key: value[i] for key, value in batch[1].items()}
for i in range(len(batch[0])) for i in range(len(batch[0]))
] ]
return _as_predictions(zip(torch.flatten(batch[0]), metadata))
return _as_predictions(zip(batch[0], metadata))
...@@ -112,11 +112,11 @@ def experiment( ...@@ -112,11 +112,11 @@ def experiment(
from .predict import predict from .predict import predict
predictions_output = output_folder / "predictions.json" predictions_output = output_folder / "predictions"
ctx.invoke( ctx.invoke(
predict, predict,
output=predictions_output, output_folder=predictions_output,
model=model, model=model,
datamodule=datamodule, datamodule=datamodule,
device=device, device=device,
...@@ -131,7 +131,7 @@ def experiment( ...@@ -131,7 +131,7 @@ def experiment(
f"Prediction runtime: {predict_stop_timestamp-predict_start_timestamp}" f"Prediction runtime: {predict_stop_timestamp-predict_start_timestamp}"
) )
"""evaluation_start_timestamp = datetime.now() evaluation_start_timestamp = datetime.now()
logger.info(f"Started evaluation at {evaluation_start_timestamp}") logger.info(f"Started evaluation at {evaluation_start_timestamp}")
from .evaluate import evaluate from .evaluate import evaluate
...@@ -140,14 +140,15 @@ def experiment( ...@@ -140,14 +140,15 @@ def experiment(
evaluate, evaluate,
predictions=predictions_output, predictions=predictions_output,
output_folder=output_folder, output_folder=output_folder,
threshold="validation", # threshold="validation",
threshold=0.5,
) )
evaluation_stop_timestamp = datetime.now() evaluation_stop_timestamp = datetime.now()
logger.info(f"Ended prediction in {evaluation_stop_timestamp}") logger.info(f"Ended prediction in {evaluation_stop_timestamp}")
logger.info( logger.info(
f"Prediction runtime: {evaluation_stop_timestamp-evaluation_start_timestamp}" f"Prediction runtime: {evaluation_stop_timestamp-evaluation_start_timestamp}"
)""" )
experiment_stop_timestamp = datetime.now() experiment_stop_timestamp = datetime.now()
logger.info( logger.info(
......
...@@ -3,15 +3,47 @@ ...@@ -3,15 +3,47 @@
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
import pathlib
import click import click
import h5py
import PIL
from clapper.click import ResourceOption, verbosity_option from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup from clapper.logging import setup
from mednet.libs.common.scripts.click import ConfigCommand from mednet.libs.common.scripts.click import ConfigCommand
from mednet.libs.common.scripts.predict import reusable_options from mednet.libs.common.scripts.predict import reusable_options
from tqdm import tqdm
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def _save_hdf5(
stem: pathlib.Path, prob: PIL.Image.Image, output_folder: pathlib.Path
):
"""Save prediction maps as image in the same format as the test image.
Parameters
----------
stem
Name of the file without extension on the original dataset.
prob
Monochrome Image with prediction maps.
output_folder
Directory in which to store predictions.
"""
fullpath = output_folder / f"{stem}.hdf5"
tqdm.write(f"Saving {fullpath}...")
fullpath.parent.mkdir(parents=True, exist_ok=True)
with h5py.File(fullpath, "w") as f:
data = prob.squeeze(0).numpy()
f.create_dataset(
"array", data=data, compression="gzip", compression_opts=9
)
@click.command( @click.command(
entry_point_group="mednet.libs.segmentation.config", entry_point_group="mednet.libs.segmentation.config",
cls=ConfigCommand, cls=ConfigCommand,
...@@ -34,7 +66,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -34,7 +66,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@reusable_options @reusable_options
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def predict( def predict(
output, output_folder,
model, model,
datamodule, datamodule,
batch_size, batch_size,
...@@ -45,9 +77,6 @@ def predict( ...@@ -45,9 +77,6 @@ def predict(
) -> None: # numpydoc ignore=PR01 ) -> None: # numpydoc ignore=PR01
"""Run inference (generates scores) on all input images, using a pre-trained model.""" """Run inference (generates scores) on all input images, using a pre-trained model."""
import json
import shutil
from mednet.libs.classification.engine.predictor import run from mednet.libs.classification.engine.predictor import run
from mednet.libs.common.engine.device import DeviceManager from mednet.libs.common.engine.device import DeviceManager
from mednet.libs.common.scripts.predict import ( from mednet.libs.common.scripts.predict import (
...@@ -56,31 +85,15 @@ def predict( ...@@ -56,31 +85,15 @@ def predict(
setup_datamodule, setup_datamodule,
) )
predictions_file = output_folder / "predictions.json"
setup_datamodule(datamodule, model, batch_size, parallel) setup_datamodule(datamodule, model, batch_size, parallel)
model = load_checkpoint(model, weight) model = load_checkpoint(model, weight)
device_manager = DeviceManager(device) device_manager = DeviceManager(device)
save_json_data(datamodule, model, output, device_manager) save_json_data(datamodule, model, predictions_file, device_manager)
predictions = run(model, datamodule, device_manager) predictions = run(model, datamodule, device_manager)
output.parent.mkdir(parents=True, exist_ok=True) for split_name, split in predictions.items():
if output.exists(): for sample in split:
backup = output.parent / (output.name + "~") _save_hdf5(sample[0], sample[2], output_folder)
logger.warning(
f"Output predictions file `{str(output)}` exists - "
f"backing it up to `{str(backup)}`...",
)
shutil.copy(output, backup)
# Remove targets from predictions, as they are images and not serializable
# A better solution should be found
serializable_predictions = {}
for split_name, sample in predictions.items():
split_predictions = []
for s in sample:
split_predictions.append((s[0], s[2]))
serializable_predictions[split_name] = split_predictions
with output.open("w") as f:
json.dump(serializable_predictions, f, indent=2)
logger.info(f"Predictions saved to `{str(output)}`")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment