Skip to content
Snippets Groups Projects
Commit a0465b26 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmentation] Add cxr8 database

parent 9eacc051
No related branches found
No related tags found
1 merge request: !46 Create common library
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""ChestX-ray8: Hospital-scale Chest X-ray Database."""
import os
import pathlib
import numpy as np
import PIL.Image
import torch
from mednet.libs.common.data.datamodule import CachingDataModule
from mednet.libs.common.data.split import make_split
from mednet.libs.common.data.typing import Sample
from mednet.libs.segmentation.data.typing import (
SegmentationRawDataLoader as _SegmentationRawDataLoader,
)
from torchvision import tv_tensors
from torchvision.transforms.functional import to_tensor
from ....utils.rc import load_rc
# Second-to-last component of this module's dotted name, i.e. the name of the
# package that contains this module.  It is shared by both configuration keys.
_DATABASE_NAME = __name__.rsplit(".", 2)[-2]

CONFIGURATION_KEY_DATADIR = f"datadir.{_DATABASE_NAME}"
"""Key to search for in the configuration file for the root directory of this
database."""

CONFIGURATION_KEY_IDIAP_FILESTRUCTURE = f"{_DATABASE_NAME}.idiap_folder_structure"
"""Key to search for in the configuration file indicating if the loader should
use standard or idiap-based file organisation structure.

It causes the internal loader to search for files in a slightly
different folder structure, that was adapted to Idiap's requirements
(number of files per folder to be less than 10k).
"""
class SegmentationRawDataLoader(_SegmentationRawDataLoader):
    """Raw-data loader specialised for the cxr8 dataset.

    Images are read from ``datadir``; segmentation targets are read from a
    sibling location whose path is ``datadir`` with ``-segmentations``
    appended.
    """

    datadir: pathlib.Path
    """Base directory where the raw data of the database is stored."""

    def __init__(self):
        # Without a configured data directory, fall back to the resolved
        # current working directory.
        fallback_dir = os.path.realpath(os.curdir)
        self.datadir = pathlib.Path(
            load_rc().get(CONFIGURATION_KEY_DATADIR, fallback_dir),
        )
        self.idiap_file_organisation = load_rc().get(
            CONFIGURATION_KEY_IDIAP_FILESTRUCTURE,
            False,
        )

    @staticmethod
    def _idiap_relpath(suffix: str) -> pathlib.Path:
        """Rewrite a path suffix for the idiap folder organisation.

        Parameters
        ----------
        suffix
            Path suffix of the form ``<folder>/<rest>``.

        Returns
        -------
            The path ``<folder>/<first five characters of rest>/<rest>``.
        """
        parts = suffix.split("/", 1)
        return pathlib.Path("/".join((parts[0], parts[1][:5], parts[1])))

    def sample(self, sample: tuple[str, str, str]) -> Sample:
        """Load a single image sample from the disk.

        Parameters
        ----------
        sample
            A tuple containing path suffixes to the sample image, target, and
            mask to be loaded, within the dataset root folder.  Only the first
            two entries (image and target) are read here.

        Returns
        -------
            The sample representation: the image tensor, and a dictionary
            holding the ``target``, an all-ones ``mask`` shaped like the
            target, and the ``name`` of the sample (its image path suffix).
        """
        image_suffix, target_suffix = sample[0], sample[1]

        if self.idiap_file_organisation:
            image_relpath = self._idiap_relpath(image_suffix)
            target_relpath = self._idiap_relpath(target_suffix)
        else:
            image_relpath = pathlib.Path(image_suffix)
            target_relpath = pathlib.Path(target_suffix)

        image = PIL.Image.open(self.datadir / image_relpath).convert(mode="RGB")
        tensor = tv_tensors.Image(to_tensor(image))

        segmentations_dir = pathlib.Path(str(self.datadir) + "-segmentations")
        annotation = np.array(PIL.Image.open(segmentations_dir / target_relpath))
        # Pixels valued 255 are reset to 0; every remaining non-zero pixel is
        # then treated as foreground.
        annotation = np.where(annotation == 255, 0, annotation)
        foreground = PIL.Image.fromarray(annotation > 0)
        target = tv_tensors.Image(to_tensor(foreground))

        mask = tv_tensors.Mask(torch.ones_like(target))
        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
class DataModule(CachingDataModule):
    """Data module for the ChestX-ray8 (cxr8) database.

    The database contains a total of 112120 images. Image size for each X-ray
    is 1024 x 1024. One set of mask annotations is available for all images.

    * Reference: [CXR8-2017]_
    * Original resolution (height x width): 1024 x 1024
    * Configuration resolution: 256 x 256 (after rescaling)
    * Split reference: [GAAL-2020]_
    * Protocol ``default``:

      * Training samples: 78484 (including labels)
      * Validation samples: 11212 (including labels)
      * Test samples: 22424 (including labels)

    Parameters
    ----------
    split_filename
        Name of the .json file containing the split to load.
    """

    def __init__(self, split_filename: str):
        assert __package__ is not None

        database_split = make_split(__package__, split_filename)
        raw_data_loader = SegmentationRawDataLoader()
        # The database is named after the last component of this package.
        database_name = __package__.rsplit(".", 1)[1]
        # The split is named after the split file, without its extension.
        split_name = pathlib.Path(split_filename).stem

        super().__init__(
            database_split=database_split,
            raw_data_loader=raw_data_loader,
            database_name=database_name,
            split_name=split_name,
        )
Source diff could not be displayed: it is stored in LFS. Options to address this: view the blob.
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""CXR8 (ChestX-ray8) dataset for segmentation (default protocol).

* Split reference: [GAAL-2020]_
* Configuration resolution: 256 x 256 (after rescaling)
* See :py:mod:`mednet.libs.segmentation.config.data.cxr8.datamodule` for
  dataset details
* One set of mask annotations is available for all images
"""
# NOTE(review): the previous docstring (vessel segmentation, [CXR8-2004]_,
# 544 x 544 center-crop, ``deepdraw.data.cxr8``, second annotator) looked
# copied from another dataset configuration.  The values above mirror the
# docstring of the imported ``DataModule`` -- confirm them, in particular the
# configuration resolution.

from mednet.libs.segmentation.config.data.cxr8.datamodule import DataModule

# Instantiates the data module with the split file "default.json".
datamodule = DataModule("default.json")
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for cxr8 dataset."""
import importlib
import pytest
from click.testing import CliRunner
def id_function(val):
    """Build the identifier pytest shows for a parametrized value.

    Parameters
    ----------
    val
        The parameter value to be labelled.

    Returns
    -------
        ``str(val)`` when ``val`` is a dictionary, ``repr(val)`` otherwise.
    """
    return str(val) if isinstance(val, dict) else repr(val)
@pytest.mark.parametrize(
    "split,lengths",
    [
        ("default", {"train": 78484, "validation": 11212, "test": 22424}),
    ],
    ids=id_function,  # only changes how pytest labels each case
)
def test_protocol_consistency(
    database_checkers,
    split: str,
    lengths: dict[str, int],
):
    """Run the split checker on each cxr8 split with its expected lengths.

    Parameters
    ----------
    database_checkers
        Fixture providing the ``check_split`` helper.
    split
        Name of the split file to load, without the ``.json`` extension.
    lengths
        Expected number of samples for each subset of the split.
    """
    from mednet.libs.common.data.split import make_split

    database_split = make_split(
        "mednet.libs.segmentation.config.data.cxr8",
        f"{split}.json",
    )
    database_checkers.check_split(database_split, lengths=lengths)
@pytest.mark.skip_if_rc_var_not_set("datadir.cxr8")
def test_database_check():
    """Invoke the database ``check`` command on cxr8 and expect exit code 0.

    The command is called with ``--limit=10``; on failure, the command output
    is included in the assertion message.
    """
    from mednet.libs.segmentation.scripts.database import check

    result = CliRunner().invoke(check, ["--limit=10", "cxr8"])

    assert (
        result.exit_code == 0
    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
@pytest.mark.skip_if_rc_var_not_set("datadir.cxr8")
@pytest.mark.parametrize("dataset", ["train", "test"])
@pytest.mark.parametrize("name", ["default"])
def test_loading(database_checkers, name: str, dataset: str):
    """Load a few batches from a cxr8 configuration and check each of them.

    Parameters
    ----------
    database_checkers
        Fixture providing the ``check_loaded_batch`` helper.
    name
        Name of the configuration module to import from the cxr8 package.
    dataset
        Key of the prediction dataloader to iterate over.
    """
    config_module = importlib.import_module(
        f".{name}",
        "mednet.libs.segmentation.config.data.cxr8",
    )
    datamodule = config_module.datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    max_batches = 3  # limit load checking
    for index, batch in enumerate(datamodule.predict_dataloader()[dataset]):
        if index == max_batches:
            break
        database_checkers.check_loaded_batch(
            batch,
            batch_size=1,
            color_planes=3,
            expected_num_targets=1,
        )
@pytest.mark.skip_if_rc_var_not_set("datadir.cxr8")
def test_raw_transforms_image_quality(database_checkers, datadir):
    """Run the image-quality checker on cxr8 without model transforms.

    Parameters
    ----------
    database_checkers
        Fixture providing the ``check_image_quality`` helper.
    datadir
        Fixture pointing to the directory that holds the reference histograms.
    """
    histogram_path = datadir / "histograms/raw_data/histograms_cxr8_default.json"
    reference_histogram_file = str(histogram_path)

    config_module = importlib.import_module(
        ".default",
        "mednet.libs.segmentation.config.data.cxr8",
    )
    datamodule = config_module.datamodule

    # No model transforms: the raw data is what gets checked.
    datamodule.model_transforms = []
    datamodule.setup("predict")

    database_checkers.check_image_quality(datamodule, reference_histogram_file)
@pytest.mark.skip_if_rc_var_not_set("datadir.cxr8")
@pytest.mark.parametrize("model_name", ["lwnet"])
def test_model_transforms_image_quality(database_checkers, datadir, model_name):
    """Run the image-quality checker on cxr8 with a model's transforms applied.

    Parameters
    ----------
    database_checkers
        Fixture providing the ``check_image_quality`` helper.
    datadir
        Fixture pointing to the directory that holds the reference histograms.
    model_name
        Name of the model configuration module whose transforms are applied.
    """
    histogram_path = (
        datadir / f"histograms/models/histograms_{model_name}_cxr8_default.json"
    )
    reference_histogram_file = str(histogram_path)

    datamodule = importlib.import_module(
        ".default",
        "mednet.libs.segmentation.config.data.cxr8",
    ).datamodule
    model = importlib.import_module(
        f".{model_name}",
        "mednet.libs.segmentation.config.models",
    ).model

    # The transforms come from the model configuration and are assigned
    # before setup().
    datamodule.model_transforms = model.model_transforms
    datamodule.setup("predict")

    database_checkers.check_image_quality(
        datamodule,
        reference_histogram_file,
        compare_type="statistical",
        pearson_coeff_threshold=0.005,
    )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment