def run(
    datamodule: lightning.pytorch.LightningDataModule,
    output_folder: pathlib.Path,
) -> (
    dict[str, list[tuple[str, str]]]
    | list[list[tuple[str, str]]]
    | list[tuple[str, str]]
    | None
):
    """Dump annotations from input datamodule.

    Every sample produced by the datamodule's ``predict_dataloader()`` has its
    target thresholded at 0.5 and stored as a gzip-compressed boolean dataset
    named ``target`` in an individual HDF5 file under ``output_folder``.

    Parameters
    ----------
    datamodule
        The lightning DataModule to extract annotations from.
    output_folder
        Folder where to store HDF5 representations of annotations.

    Returns
    -------
        A JSON-able representation of sample data stored at ``output_folder``.
        For every split (dataloader), a list of samples in the form
        ``[sample-name, hdf5-path]`` is returned, in which ``hdf5-path`` is
        relative to ``output_folder``.  In the cases where the
        ``predict_dataloader()`` returns a single loader, we then return a
        list.  A list of lists is returned if it returns a list of loaders,
        and a dictionary is returned in case ``predict_dataloader()`` also
        returns a dictionary.  ``None`` is returned if there are no prediction
        dataloaders.

    Raises
    ------
    TypeError
        If the DataModule's ``predict_dataloader()`` method does not return any
        of the types described above.
    """

    def _write_sample(
        sample: typing.Any, output_folder: pathlib.Path
    ) -> tuple[str, str]:
        """Write a single sample target to an HDF5 file.

        Parameters
        ----------
        sample
            A segmentation sample as output by a dataloader.
        output_folder
            Path leading to a folder where to store dumped annotations.

        Returns
        -------
            A tuple which contains the sample name and the path of the HDF5
            file that was saved, relative to ``output_folder``.
        """
        # Only the first element of the batch is read: this assumes loaders
        # were set up with a batch size of 1 - TODO confirm at call sites.
        name = sample[1]["name"][0]
        target = sample[1]["target"][0]

        stem = pathlib.Path(name).with_suffix(".hdf5")
        if stem.is_absolute():
            # Joining an absolute path would silently discard
            # ``output_folder`` and write next to the original data; strip the
            # anchor so the file always lands below ``output_folder``.
            stem = stem.relative_to(stem.anchor)

        dest = output_folder / stem
        tqdm.tqdm.write(f"`{name}` -> `{str(dest)}`")
        dest.parent.mkdir(parents=True, exist_ok=True)
        with h5py.File(dest, "w") as f:
            f.create_dataset(
                "target",
                # stored as a boolean mask
                data=(target.cpu().numpy() > 0.5),
                compression="gzip",
                compression_opts=9,
            )
        return (name, str(stem))

    def _dump_loader(
        dataloader: torch.utils.data.DataLoader,
    ) -> list[tuple[str, str]]:
        """Write all samples of one dataloader, showing a progress bar.

        Parameters
        ----------
        dataloader
            The dataloader to iterate over.

        Returns
        -------
            One ``(sample-name, hdf5-path)`` tuple per sample, in iteration
            order.
        """
        return [
            _write_sample(sample, output_folder)
            for sample in tqdm.tqdm(dataloader)
        ]

    dataloaders = datamodule.predict_dataloader()

    if isinstance(dataloaders, torch.utils.data.DataLoader):
        logger.info("Dump annotations from a single dataloader...")
        return _dump_loader(dataloaders)

    if isinstance(dataloaders, list):
        retval_list = []
        for index, dataloader in enumerate(dataloaders):
            logger.info("Dumping annotations from split `%s`...", index)
            retval_list.append(_dump_loader(dataloader))
        return retval_list

    if isinstance(dataloaders, dict):
        retval_dict = {}
        for split_name, dataloader in dataloaders.items():
            logger.info("Dumping annotations from split `%s`...", split_name)
            retval_dict[split_name] = _dump_loader(dataloader)
        return retval_dict

    if dataloaders is None:
        logger.warning("Datamodule did not return any prediction dataloaders!")
        return None

    # if you get to this point, then the user is returning something that is
    # not supported - complain!
    raise TypeError(
        f"Datamodule returned strangely typed prediction "
        f"dataloaders: `{type(dataloaders)}` - Please write code "
        f"to support this use-case.",
    )
@click.command(
    entry_point_group="mednet.libs.segmentation.config",
    cls=mednet.libs.common.scripts.click.ConfigCommand,
    epilog="""Examples:

1. Dump annotations for a dataset after pre-processing on a particular directory:

   .. code:: sh

      mednet segmentation dump-annotations -vv lwnet drive-2nd --output-folder=path/to/annotations

""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Directory in which to save annotations (created if does not exist)",
    required=True,
    type=click.Path(
        file_okay=False,
        dir_okay=True,
        writable=True,
        path_type=pathlib.Path,
    ),
    default="predictions",
    cls=clapper.click.ResourceOption,
)
@click.option(
    "--model",
    "-m",
    help="""A lightning module instance that will be used to retrieve
    pre-processing transforms.""",
    required=True,
    cls=clapper.click.ResourceOption,
)
@click.option(
    "--datamodule",
    "-d",
    help="""A lightning DataModule that will be asked for prediction data
    loaders. Typically, this includes all configured splits in a DataModule,
    however this is not a requirement.  A DataModule that returns a single
    dataloader for prediction (wrapped in a dictionary) is acceptable.""",
    required=True,
    cls=clapper.click.ResourceOption,
)
@clapper.click.verbosity_option(
    logger=logger, cls=clapper.click.ResourceOption, expose_value=False
)
def dump_annotations(
    output_folder, model, datamodule, **_
) -> None:  # numpydoc ignore=PR01
    """Dump annotations in a given folder, after pre-processing.

    The datamodule is configured with the given model, each sample target is
    written by the dumper engine, and an ``annotations.json`` index of what
    the engine returned is saved inside the output folder.
    """

    # Imports are kept local to the command, as in the original code.
    from mednet.libs.common.scripts.predict import (
        setup_datamodule,
    )
    from mednet.libs.common.scripts.utils import save_json_with_backup
    from mednet.libs.segmentation.engine.dumper import run

    # A batch size of 1 is requested here; presumably the dumper engine
    # handles one sample per batch - verify against ``engine.dumper.run``.
    setup_datamodule(datamodule, model, batch_size=1, parallel=-1)

    json_data = run(datamodule, output_folder)

    # Make sure the folder exists even if no sample was written, as the
    # ``--output-folder`` help text promises it is created when missing.
    output_folder.mkdir(parents=True, exist_ok=True)

    base_file = output_folder / "annotations.json"
    save_json_with_backup(base_file, json_data)
    logger.info("Annotations saved to `%s`", base_file)