From c2e0c01c3bca3bde40b78c0ca18e840ad8865077 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Thu, 16 May 2024 15:27:56 +0200 Subject: [PATCH] [predict] Split predict into functions and implement in segmentation --- .../libs/classification/scripts/experiment.py | 6 +- .../libs/classification/scripts/predict.py | 20 +++--- .../tests/test_cli_classification.py | 17 +++-- src/mednet/libs/common/scripts/predict.py | 14 ++-- .../libs/segmentation/engine/__init__.py | 0 .../libs/segmentation/models/separate.py | 6 +- .../libs/segmentation/scripts/experiment.py | 11 ++-- .../libs/segmentation/scripts/predict.py | 65 +++++++++++-------- 8 files changed, 78 insertions(+), 61 deletions(-) create mode 100644 src/mednet/libs/segmentation/engine/__init__.py diff --git a/src/mednet/libs/classification/scripts/experiment.py b/src/mednet/libs/classification/scripts/experiment.py index 814f3b76..6f11690a 100644 --- a/src/mednet/libs/classification/scripts/experiment.py +++ b/src/mednet/libs/classification/scripts/experiment.py @@ -111,11 +111,11 @@ def experiment( from .predict import predict - predictions_output = output_folder / "predictions.json" + predictions_output = output_folder / "predictions" ctx.invoke( predict, - output=predictions_output, + output_folder=predictions_output, model=model, datamodule=datamodule, device=device, @@ -135,7 +135,7 @@ def experiment( ctx.invoke( evaluate, - predictions=predictions_output, + predictions=predictions_output / "predictions.json", output_folder=output_folder, threshold="validation", ) diff --git a/src/mednet/libs/classification/scripts/predict.py b/src/mednet/libs/classification/scripts/predict.py index 04e9908d..2691727d 100644 --- a/src/mednet/libs/classification/scripts/predict.py +++ b/src/mednet/libs/classification/scripts/predict.py @@ -34,7 +34,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @reusable_options @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def predict( - output, + output_folder, model, datamodule, batch_size, @@ -56,22 +56,24 @@ def predict( setup_datamodule, ) + predictions_file = output_folder / "predictions.json" + predictions_file.parent.mkdir(parents=True, exist_ok=True) + setup_datamodule(datamodule, model, batch_size, parallel) model = load_checkpoint(model, weight) device_manager = DeviceManager(device) - save_json_data(datamodule, model, output, device_manager) + save_json_data(datamodule, model, predictions_file, device_manager) predictions = run(model, datamodule, device_manager) - output.parent.mkdir(parents=True, exist_ok=True) - if output.exists(): - backup = output.parent / (output.name + "~") + if predictions_file.exists(): + backup = predictions_file.parent / (predictions_file.name + "~") logger.warning( - f"Output predictions file `{str(output)}` exists - " + f"Output predictions file `{str(predictions_file)}` exists - " f"backing it up to `{str(backup)}`...", ) - shutil.copy(output, backup) + shutil.copy(predictions_file, backup) - with output.open("w") as f: + with predictions_file.open("w") as f: json.dump(predictions, f, indent=2) - logger.info(f"Predictions saved to `{str(output)}`") + logger.info(f"Predictions saved to `{str(predictions_file)}`") diff --git a/src/mednet/libs/classification/tests/test_cli_classification.py b/src/mednet/libs/classification/tests/test_cli_classification.py index f735ae17..99e70d8f 100644 --- a/src/mednet/libs/classification/tests/test_cli_classification.py +++ b/src/mednet/libs/classification/tests/test_cli_classification.py @@ -329,7 +329,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir): runner = CliRunner() with stdout_logging() as buf: - output = temporary_basedir / "predictions.json" + output = temporary_basedir / "predictions" last = _get_checkpoint_from_alias( temporary_basedir / "results", "periodic", @@ -343,7 +343,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir): "-vv", "--batch-size=1", f"--weight={str(last)}", - f"--output={str(output)}", + f"--output-folder={str(output)}", ], ) _assert_exit_0(result) @@ -379,7 +379,8 @@ def test_evaluate_pasa_montgomery(temporary_basedir): runner = CliRunner() with stdout_logging() as buf: - prediction_path = temporary_basedir / "predictions.json" + prediction_path = temporary_basedir / "predictions" + predictions_file = prediction_path / "predictions.json" evaluation_filename = "evaluation.json" evaluation_file = temporary_basedir / evaluation_filename result = runner.invoke( @@ -387,7 +388,7 @@ def test_evaluate_pasa_montgomery(temporary_basedir): [ "-vv", "montgomery", - f"--predictions={str(prediction_path)}", + f"--predictions={predictions_file}", f"--output-folder={str(temporary_basedir)}", "--threshold=test", ], @@ -440,9 +441,11 @@ def test_experiment(temporary_basedir): _assert_exit_0(result) assert (output_folder / "model" / "meta.json").exists() - assert (output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt").exists() - assert (output_folder / "predictions.json").exists() - assert (output_folder / "predictions.meta.json").exists() + assert ( + output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt" + ).exists() + assert (output_folder / "predictions" / "predictions.json").exists() + assert (output_folder / "predictions" / "predictions.meta.json").exists() # Need to glob because we cannot be sure of the checkpoint with lowest validation loss assert ( diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py index 2cde0d8d..48fbf4d1 100644 --- a/src/mednet/libs/common/scripts/predict.py +++ b/src/mednet/libs/common/scripts/predict.py @@ -32,20 +32,18 @@ def reusable_options(f): """ @click.option( - "--output", + "--output-folder", "-o", - help="""Path to a JSON file in which to save predictions for all samples in the - input DataModule (leading directories are created if they do not - exist).""", + help="Directory in which to save predictions (created if does not exist)", required=True, - default="predictions.json", - cls=ResourceOption, type=click.Path( - file_okay=True, - dir_okay=False, + file_okay=False, + dir_okay=True, writable=True, path_type=pathlib.Path, ), + default="predictions", + cls=ResourceOption, ) @click.option( "--model", diff --git a/src/mednet/libs/segmentation/engine/__init__.py b/src/mednet/libs/segmentation/engine/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/mednet/libs/segmentation/models/separate.py b/src/mednet/libs/segmentation/models/separate.py index 7f520a4a..6b9bb5de 100644 --- a/src/mednet/libs/segmentation/models/separate.py +++ b/src/mednet/libs/segmentation/models/separate.py @@ -5,7 +5,6 @@ import typing -import torch from mednet.libs.common.data.typing import Sample from .typing import BinaryPrediction, MultiClassPrediction @@ -27,7 +26,7 @@ def _as_predictions( list[BinaryPrediction | MultiClassPrediction] A list of typed predictions that can be saved to disk. """ - return [(v[1]["name"], v[1]["target"], v[0].item()) for v in samples] + return [(v[1]["name"], v[1]["target"], v[0]) for v in samples] def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]: @@ -58,4 +57,5 @@ def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]: {key: value[i] for key, value in batch[1].items()} for i in range(len(batch[0])) ] - return _as_predictions(zip(torch.flatten(batch[0]), metadata)) + + return _as_predictions(zip(batch[0], metadata)) diff --git a/src/mednet/libs/segmentation/scripts/experiment.py b/src/mednet/libs/segmentation/scripts/experiment.py index 0760adde..ccae0eae 100644 --- a/src/mednet/libs/segmentation/scripts/experiment.py +++ b/src/mednet/libs/segmentation/scripts/experiment.py @@ -112,11 +112,11 @@ def experiment( from .predict import predict - predictions_output = output_folder / "predictions.json" + predictions_output = output_folder / "predictions" ctx.invoke( predict, - output=predictions_output, + output_folder=predictions_output, model=model, datamodule=datamodule, device=device, @@ -131,7 +131,7 @@ def experiment( f"Prediction runtime: {predict_stop_timestamp-predict_start_timestamp}" ) - """evaluation_start_timestamp = datetime.now() + evaluation_start_timestamp = datetime.now() logger.info(f"Started evaluation at {evaluation_start_timestamp}") from .evaluate import evaluate @@ -140,14 +140,15 @@ def experiment( evaluate, predictions=predictions_output, output_folder=output_folder, - threshold="validation", + # threshold="validation", + threshold=0.5, ) evaluation_stop_timestamp = datetime.now() logger.info(f"Ended prediction in {evaluation_stop_timestamp}") logger.info( f"Prediction runtime: {evaluation_stop_timestamp-evaluation_start_timestamp}" - )""" + ) experiment_stop_timestamp = datetime.now() logger.info( diff --git a/src/mednet/libs/segmentation/scripts/predict.py b/src/mednet/libs/segmentation/scripts/predict.py index 9e0b44b0..b7492f53 100644 --- a/src/mednet/libs/segmentation/scripts/predict.py +++ b/src/mednet/libs/segmentation/scripts/predict.py @@ -3,15 +3,47 @@ # SPDX-License-Identifier: GPL-3.0-or-later +import pathlib + import click +import h5py +import PIL from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup from mednet.libs.common.scripts.click import ConfigCommand from mednet.libs.common.scripts.predict import reusable_options +from tqdm import tqdm logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +def _save_hdf5( + stem: pathlib.Path, prob: PIL.Image.Image, output_folder: pathlib.Path +): + """Save prediction maps as image in the same format as the test image. + + Parameters + ---------- + stem + Name of the file without extension on the original dataset. + + prob + Monochrome Image with prediction maps. + + output_folder + Directory in which to store predictions. + """ + + fullpath = output_folder / f"{stem}.hdf5" + tqdm.write(f"Saving {fullpath}...") + fullpath.parent.mkdir(parents=True, exist_ok=True) + with h5py.File(fullpath, "w") as f: + data = prob.squeeze(0).numpy() + f.create_dataset( + "array", data=data, compression="gzip", compression_opts=9 + ) + + @click.command( entry_point_group="mednet.libs.segmentation.config", cls=ConfigCommand, @@ -34,7 +66,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @reusable_options @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def predict( - output, + output_folder, model, datamodule, batch_size, @@ -45,9 +77,6 @@ def predict( ) -> None: # numpydoc ignore=PR01 """Run inference (generates scores) on all input images, using a pre-trained model.""" - import json - import shutil - from mednet.libs.classification.engine.predictor import run from mednet.libs.common.engine.device import DeviceManager from mednet.libs.common.scripts.predict import ( @@ -56,31 +85,15 @@ def predict( setup_datamodule, ) + predictions_file = output_folder / "predictions.json" + setup_datamodule(datamodule, model, batch_size, parallel) model = load_checkpoint(model, weight) device_manager = DeviceManager(device) - save_json_data(datamodule, model, output, device_manager) + save_json_data(datamodule, model, predictions_file, device_manager) predictions = run(model, datamodule, device_manager) - output.parent.mkdir(parents=True, exist_ok=True) - if output.exists(): - backup = output.parent / (output.name + "~") - logger.warning( - f"Output predictions file `{str(output)}` exists - " - f"backing it up to `{str(backup)}`...", - ) - shutil.copy(output, backup) - - # Remove targets from predictions, as they are images and not serializable - # A better solution should be found - serializable_predictions = {} - for split_name, sample in predictions.items(): - split_predictions = [] - for s in sample: - split_predictions.append((s[0], s[2])) - serializable_predictions[split_name] = split_predictions - - with output.open("w") as f: - json.dump(serializable_predictions, f, indent=2) - logger.info(f"Predictions saved to `{str(output)}`") + for split_name, split in predictions.items(): + for sample in split: + _save_hdf5(sample[0], sample[2], output_folder) -- GitLab