Skip to content
Snippets Groups Projects
Commit 07080b2e authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmentation] Add shenzhen database

parent 3666d6e7
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -465,6 +465,9 @@ rimoner3-disc = "mednet.libs.segmentation.config.data.rimoner3.disc_exp1" ...@@ -465,6 +465,9 @@ rimoner3-disc = "mednet.libs.segmentation.config.data.rimoner3.disc_exp1"
rimoner3-cup-2nd = "mednet.libs.segmentation.config.data.rimoner3.cup_exp2" rimoner3-cup-2nd = "mednet.libs.segmentation.config.data.rimoner3.cup_exp2"
rimoner3-disc-2nd = "mednet.libs.segmentation.config.data.rimoner3.disc_exp2" rimoner3-disc-2nd = "mednet.libs.segmentation.config.data.rimoner3.disc_exp2"
# shenzhen - cxr
shenzhen = "mednet.libs.segmentation.config.data.shenzhen.default"
# stare dataset - retinography # stare dataset - retinography
stare = "mednet.libs.segmentation.config.data.stare.ah" stare = "mednet.libs.segmentation.config.data.stare.ah"
stare-2nd = "mednet.libs.segmentation.config.data.stare.vk" stare-2nd = "mednet.libs.segmentation.config.data.stare.vk"
......
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen No.3 People’s Hospital dataset for Lung Segmentation."""
import os
import pathlib
import PIL.Image
import torch
from mednet.libs.common.data.datamodule import CachingDataModule
from mednet.libs.common.data.split import make_split
from mednet.libs.common.data.typing import Sample
from mednet.libs.segmentation.data.typing import (
SegmentationRawDataLoader as _SegmentationRawDataLoader,
)
from torchvision import tv_tensors
from torchvision.transforms.functional import to_tensor
from ....utils.rc import load_rc
# ``__name__.rsplit(".", 2)[-2]`` is the second-to-last dotted component of
# this module's name, i.e. the name of the package that contains this module.
# The resulting key is therefore ``datadir.<containing-package-name>``.
CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
"""Key to search for in the configuration file for the root directory of this
database."""
class SegmentationRawDataLoader(_SegmentationRawDataLoader):
    """A specialized raw-data-loader for the shenzhen dataset."""

    datadir: pathlib.Path
    """This variable contains the base directory where the database raw data is
    stored."""

    def __init__(self):
        # Fall back to the current working directory when the configuration
        # file does not define a root directory for this database.
        self.datadir = pathlib.Path(
            load_rc().get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir))
        )

    def sample(self, sample: tuple[str, str, str]) -> Sample:
        """Load a single image sample from the disk.

        Parameters
        ----------
        sample
            A tuple containing path suffixes to the sample image, target, and mask
            to be loaded, within the dataset root folder.  Only the first two
            entries (image and target) are read by this loader; the returned
            mask is filled with ones instead of being loaded from disk.

        Returns
        -------
            The sample representation: the image tensor, and a dictionary with
            the keys ``target``, ``mask`` and ``name`` (the image path suffix).
        """

        # ``convert()`` returns a new, fully-loaded image, so the file handle
        # can be released as soon as the conversion is done.
        with PIL.Image.open(self.datadir / sample[0]) as image_file:
            image = to_tensor(image_file.convert(mode="RGB"))

        # NOTE: ``dither=None`` makes Pillow fall back to its default
        # (Floyd-Steinberg) dithering.  ``Dither.NONE`` must be passed
        # explicitly to get a plain threshold when binarizing the annotation.
        with PIL.Image.open(self.datadir / sample[1]) as target_file:
            target = to_tensor(
                target_file.convert(mode="1", dither=PIL.Image.Dither.NONE)
            )

        tensor = tv_tensors.Image(image)
        target = tv_tensors.Image(target)

        # No mask file is read: every pixel is marked as valid.
        mask = tv_tensors.Mask(torch.ones_like(target))

        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
class DataModule(CachingDataModule):
    """Shenzhen No.3 People’s Hospital dataset for Lung Segmentation.

    The database includes 336 cases with manifestation of tuberculosis, and 326
    normal cases, for a total of 662 images.  Image size varies for each X-ray
    and is approximately 3K x 3K.  One set of ground-truth lung annotations is
    available for 566 of the 662 images.

    * Reference: [SHENZHEN-2014]_
    * Original resolution (height x width): Approximately 3K x 3K (varies)
    * Configuration resolution: 512 x 512 (after rescaling)
    * Split reference: [GAAL-2020]_
    * Protocol ``default``:

      * Training samples: 396 (including labels)
      * Validation samples: 56 (including labels)
      * Test samples: 114 (including labels)

    Parameters
    ----------
    split_filename
        Name of the .json file containing the split to load.
    """

    def __init__(self, split_filename: str):
        assert __package__ is not None

        # Build every constructor argument up-front, in the same order the
        # keyword arguments are passed below.
        database_split = make_split(__package__, split_filename)
        raw_data_loader = SegmentationRawDataLoader()
        # The database is named after the last component of this package.
        database_name = __package__.rsplit(".", 1)[1]
        # The split is named after the file name, without its extension.
        split_name = pathlib.Path(split_filename).stem

        super().__init__(
            database_split=database_split,
            raw_data_loader=raw_data_loader,
            database_name=database_name,
            split_name=split_name,
        )
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for Lung Segmentation (default protocol).
* Split reference: [GAAL-2020]_
* Configuration resolution: 256 x 256
* See :py:mod:`mednet.libs.segmentation.config.data.shenzhen.datamodule` for dataset details
"""
from mednet.libs.segmentation.config.data.shenzhen.datamodule import DataModule
# Module-level datamodule instance, built from the ``default.json`` split file.
datamodule = DataModule("default.json")
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for shenzhen dataset."""
import importlib
import pytest
from click.testing import CliRunner
def id_function(val):
    """Return the label used to display a parametrized value.

    Parameters
    ----------
    val
        The parameter value to label.

    Returns
    -------
        ``str(val)`` for dictionaries, ``repr(val)`` for anything else.
    """
    return str(val) if isinstance(val, dict) else repr(val)
@pytest.mark.parametrize(
    "split,lengths",
    [
        ("default", {"train": 396, "validation": 56, "test": 114}),
    ],
    ids=id_function,  # just changes how pytest prints it
)
def test_protocol_consistency(
    database_checkers,
    split: str,
    lengths: dict[str, int],
):
    """Run ``database_checkers.check_split`` on a split of the shenzhen database.

    Parameters
    ----------
    database_checkers
        Fixture providing the ``check_split`` helper.
    split
        Name of the split; the file loaded is ``<split>.json``.
    lengths
        Mapping passed to ``check_split`` as ``lengths``.
    """
    from mednet.libs.common.data.split import make_split

    package = "mednet.libs.segmentation.config.data.shenzhen"
    database_split = make_split(package, f"{split}.json")
    database_checkers.check_split(database_split, lengths=lengths)
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
def test_database_check():
    """Invoke the ``check`` command on shenzhen and require a zero exit code."""
    from mednet.libs.segmentation.scripts.database import check

    arguments = ["--limit=20", "shenzhen"]
    result = CliRunner().invoke(check, arguments)

    # The command output is part of the failure message to ease debugging.
    assert (
        result.exit_code == 0
    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
@pytest.mark.parametrize("dataset", ["train", "validation", "test"])
@pytest.mark.parametrize("name", ["default"])
def test_loading(database_checkers, name: str, dataset: str):
    """Check the first few batches produced for one subset of a configuration.

    Parameters
    ----------
    database_checkers
        Fixture providing the ``check_loaded_batch`` helper.
    name
        Name of the configuration module to import from the shenzhen package.
    dataset
        Key selecting one of the prediction data loaders.
    """
    configuration = importlib.import_module(
        f".{name}",
        package="mednet.libs.segmentation.config.data.shenzhen",
    )
    datamodule = configuration.datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    max_batches = 3  # limit load checking
    for seen, batch in enumerate(datamodule.predict_dataloader()[dataset]):
        if seen == max_batches:
            break

        database_checkers.check_loaded_batch(
            batch,
            batch_size=1,
            color_planes=3,
            expected_num_targets=1,
        )
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
def test_raw_transforms_image_quality(database_checkers, datadir):
    """Run the image-quality check with no model transforms applied.

    Parameters
    ----------
    database_checkers
        Fixture providing the ``check_image_quality`` helper.
    datadir
        Fixture pointing to the directory holding the reference histograms.
    """
    histogram_path = datadir / "histograms/raw_data/histograms_shenzhen_default.json"
    reference_histogram_file = str(histogram_path)

    configuration = importlib.import_module(
        ".default",
        package="mednet.libs.segmentation.config.data.shenzhen",
    )
    datamodule = configuration.datamodule

    # Transforms are reset before the datasets are set up.
    datamodule.model_transforms = []
    datamodule.setup("predict")

    database_checkers.check_image_quality(datamodule, reference_histogram_file)
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
@pytest.mark.parametrize("model_name", ["lwnet"])
def test_model_transforms_image_quality(database_checkers, datadir, model_name):
    """Run the image-quality check with a model's transforms applied.

    Parameters
    ----------
    database_checkers
        Fixture providing the ``check_image_quality`` helper.
    datadir
        Fixture pointing to the directory holding the reference histograms.
    model_name
        Name of the model configuration module whose ``model_transforms`` are
        assigned to the datamodule.
    """
    histogram_path = (
        datadir / f"histograms/models/histograms_{model_name}_shenzhen_default.json"
    )
    reference_histogram_file = str(histogram_path)

    data_configuration = importlib.import_module(
        ".default",
        package="mednet.libs.segmentation.config.data.shenzhen",
    )
    datamodule = data_configuration.datamodule

    model_configuration = importlib.import_module(
        f".{model_name}",
        package="mednet.libs.segmentation.config.models",
    )
    model = model_configuration.model

    # Transforms are assigned before the datasets are set up.
    datamodule.model_transforms = model.model_transforms
    datamodule.setup("predict")

    database_checkers.check_image_quality(
        datamodule,
        reference_histogram_file,
        compare_type="statistical",
        pearson_coeff_threshold=0.005,
    )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment