Skip to content
Snippets Groups Projects
Commit 61324bd9 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[libs.segmentation.scripts.dump_annotations] Implement routine to dump...

[libs.segmentation.scripts.dump_annotations] Implement routine to dump annotations from an arbitrary datamodule
parent 692a1be6
No related branches found
No related tags found
1 merge request!46Create common library
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import pathlib
import typing
import h5py
import lightning.pytorch
import torch.utils.data
import tqdm
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)
def run(
    datamodule: lightning.pytorch.LightningDataModule,
    output_folder: pathlib.Path,
) -> (
    dict[str, list[tuple[str, str]]]
    | list[list[tuple[str, str]]]
    | list[tuple[str, str]]
    | None
):
    """Dump annotations from input datamodule.

    Every sample target produced by the datamodule's prediction
    dataloader(s) is binarized and stored as a gzip-compressed HDF5 file
    under ``output_folder``, mirroring the sample's relative path.

    Parameters
    ----------
    datamodule
        The lightning DataModule to extract annotations from.
    output_folder
        Folder where to store HDF5 representations of annotations.

    Returns
    -------
    A JSON-able representation of sample data stored at ``output_folder``.
    For every split (dataloader), a list of samples in the form
    ``[sample-name, hdf5-path]`` is returned.  In the cases where the
    ``predict_dataloader()`` returns a single loader, we then return a
    list.  A dictionary is returned in case ``predict_dataloader()`` also
    returns a dictionary.  ``None`` is returned if the datamodule provides
    no prediction dataloaders.

    Raises
    ------
    TypeError
        If the DataModule's ``predict_dataloader()`` method does not return any
        of the types described above.
    """

    def _write_sample(
        sample: typing.Any, output_folder: pathlib.Path
    ) -> tuple[str, str]:
        """Write a single sample target to an HDF5 file.

        Parameters
        ----------
        sample
            A segmentation sample as output by a dataloader.  Assumed to be
            a 2-sequence whose second entry is a metadata mapping carrying
            single-element (batch size 1) ``name`` and ``target`` entries --
            TODO confirm against the dataloader collation used by callers.
        output_folder
            Path leading to a folder where to store dumped annotations.

        Returns
        -------
        A tuple which contains the sample path and the relative location
        where the HDF5 file was saved.
        """
        name = sample[1]["name"][0]
        target = sample[1]["target"][0]
        stem = pathlib.Path(name).with_suffix(".hdf5")
        dest = output_folder / stem
        # use tqdm.write() so the message does not clobber the progress bar
        tqdm.tqdm.write(f"`{name}` -> `{str(dest)}`")
        dest.parent.mkdir(parents=True, exist_ok=True)
        with h5py.File(dest, "w") as f:
            # binarize the target at 0.5 and compress aggressively: targets
            # are boolean masks, so gzip level 9 keeps files small
            f.create_dataset(
                "target",
                data=(target.cpu().numpy() > 0.5),
                compression="gzip",
                compression_opts=9,
            )
        return (name, str(stem))

    dataloaders = datamodule.predict_dataloader()

    if isinstance(dataloaders, torch.utils.data.DataLoader):
        logger.info("Dump annotations from a single dataloader...")
        return [
            _write_sample(sample, output_folder)
            for sample in tqdm.tqdm(dataloaders)
        ]

    if isinstance(dataloaders, list):
        retval_list = []
        # NOTE: the loop index and the comprehension variable must be kept
        # distinct (the original shadowed `k` for both roles)
        for split_index, dataloader in enumerate(dataloaders):
            logger.info("Dumping annotations from split `%d`...", split_index)
            retval_list.append(
                [
                    _write_sample(sample, output_folder)
                    for sample in tqdm.tqdm(dataloader)
                ]
            )
        return retval_list

    if isinstance(dataloaders, dict):
        retval_dict = {}
        for name, dataloader in dataloaders.items():
            logger.info("Dumping annotations from split `%s`...", name)
            retval_dict[name] = [
                _write_sample(sample, output_folder)
                for sample in tqdm.tqdm(dataloader)
            ]
        return retval_dict

    if dataloaders is None:
        logger.warning("Datamodule did not return any prediction dataloaders!")
        return None

    # if you get to this point, then the user is returning something that is
    # not supported - complain!
    raise TypeError(
        f"Datamodule returned strangely typed prediction "
        f"dataloaders: `{type(dataloaders)}` - Please write code "
        f"to support this use-case.",
    )
@@ -11,6 +11,7 @@ from . import (
     # analyze,
     config,
     database,
+    dump_annotations,
     evaluate,
     predict,
     train,
@@ -40,6 +41,7 @@ segmentation.add_command(
     ).train_analysis,
 )
 segmentation.add_command(view.view)
+segmentation.add_command(dump_annotations.dump_annotations)
 segmentation.add_command(
     importlib.import_module("..experiment", package=__name__).experiment,
 )
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import pathlib
import clapper.click
import clapper.logging
import click
import mednet.libs.common.scripts.click
# Set up CLI logging through clapper, keyed on the top-level package name so
# that the verbosity option below controls the whole package's loggers.
logger = clapper.logging.setup(
    __name__.split(".")[0], format="%(levelname)s: %(message)s"
)
@click.command(
    entry_point_group="mednet.libs.segmentation.config",
    cls=mednet.libs.common.scripts.click.ConfigCommand,
    epilog="""Examples:

1. Dump annotations for a dataset after pre-processing on a particular directory:

   .. code:: sh

      mednet segmentation dump-annotations -vv lwnet drive-2nd --output-folder=path/to/annotations
""",
)
@click.option(
    "--output-folder",
    "-o",
    # fixed: help text used to say "predictions" (copied from the predict
    # script), but this command stores annotations
    help="Directory in which to save annotations (created if does not exist)",
    required=True,
    type=click.Path(
        file_okay=False,
        dir_okay=True,
        writable=True,
        path_type=pathlib.Path,
    ),
    # NOTE(review): default retained from the predict script for backward
    # compatibility; "annotations" may have been intended -- TODO confirm
    default="predictions",
    cls=clapper.click.ResourceOption,
)
@click.option(
    "--model",
    "-m",
    help="""A lightning module instance that will be used to retrieve
    pre-processing transforms.""",
    required=True,
    cls=clapper.click.ResourceOption,
)
@click.option(
    "--datamodule",
    "-d",
    help="""A lightning DataModule that will be asked for prediction data
    loaders. Typically, this includes all configured splits in a DataModule,
    however this is not a requirement. A DataModule that returns a single
    dataloader for prediction (wrapped in a dictionary) is acceptable.""",
    required=True,
    cls=clapper.click.ResourceOption,
)
@clapper.click.verbosity_option(
    logger=logger, cls=clapper.click.ResourceOption, expose_value=False
)
def dump_annotations(
    output_folder, model, datamodule, **_
) -> None:  # numpydoc ignore=PR01
    """Dump annotations in a given folder, after pre-processing.

    Configures the datamodule with the model's pre-processing transforms
    (batch size 1, no parallelism), dumps every prediction-split annotation
    as HDF5 via the engine ``run()`` routine, and records a JSON index of
    the written files at ``<output_folder>/annotations.json``.
    """
    # local imports keep CLI start-up fast and avoid import cycles
    from mednet.libs.common.scripts.predict import (
        setup_datamodule,
    )
    from mednet.libs.common.scripts.utils import save_json_with_backup
    from mednet.libs.segmentation.engine.dumper import run

    setup_datamodule(datamodule, model, batch_size=1, parallel=-1)
    json_data = run(datamodule, output_folder)
    base_file = output_folder / "annotations.json"
    save_json_with_backup(base_file, json_data)
    logger.info(f"Annotations saved to `{str(base_file)}`")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment