Skip to content
Snippets Groups Projects
Commit 442c8777 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmetnation] Add hrf database

parent 5438fb26
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -446,6 +446,9 @@ drishtigs1-disc-any = "mednet.libs.segmentation.config.data.drishtigs1.optic_dis ...@@ -446,6 +446,9 @@ drishtigs1-disc-any = "mednet.libs.segmentation.config.data.drishtigs1.optic_dis
drishtigs1-cup-all = "mednet.libs.segmentation.config.data.drishtigs1.optic_cup_all" drishtigs1-cup-all = "mednet.libs.segmentation.config.data.drishtigs1.optic_cup_all"
drishtigs1-cup-any = "mednet.libs.segmentation.config.data.drishtigs1.optic_cup_any" drishtigs1-cup-any = "mednet.libs.segmentation.config.data.drishtigs1.optic_cup_any"
# hrf - retinography
hrf = "mednet.libs.segmentation.config.data.hrf.default"
# iostar - retinography # iostar - retinography
iostar-vessel = "mednet.libs.segmentation.config.data.iostar.vessel" iostar-vessel = "mednet.libs.segmentation.config.data.iostar.vessel"
iostar-disc = "mednet.libs.segmentation.config.data.iostar.optic_disc" iostar-disc = "mednet.libs.segmentation.config.data.iostar.optic_disc"
......
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""HRF dataset for Vessel Segmentation."""
import os
import pathlib
import PIL.Image
from mednet.libs.common.data.datamodule import CachingDataModule
from mednet.libs.common.data.split import make_split
from mednet.libs.common.data.typing import Sample
from mednet.libs.common.models.transforms import crop_image_to_mask
from mednet.libs.segmentation.data.typing import (
SegmentationRawDataLoader as _SegmentationRawDataLoader,
)
from torchvision import tv_tensors
from torchvision.transforms.functional import to_tensor
from ....utils.rc import load_rc
CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
"""Key to search for in the configuration file for the root directory of this
database."""
class SegmentationRawDataLoader(_SegmentationRawDataLoader):
"""A specialized raw-data-loader for the drishtigs1hrf dataset."""
datadir: pathlib.Path
"""This variable contains the base directory where the database raw data is
stored."""
def __init__(self):
self.datadir = pathlib.Path(
load_rc().get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir))
)
def sample(self, sample: tuple[str, str, str]) -> Sample:
"""Load a single image sample from the disk.
Parameters
----------
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing the
sample label.
Returns
-------
The sample representation.
"""
image = to_tensor(PIL.Image.open(self.datadir / sample[0]).convert(mode="RGB"))
target = to_tensor(
PIL.Image.open(self.datadir / sample[1]).convert(mode="1", dither=None)
)
mask = to_tensor(
PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None)
)
tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
target = tv_tensors.Image(crop_image_to_mask(target, mask))
mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
return tensor, dict(target=target, mask=mask, name=sample[0]) # type: ignore[arg-type]
class DataModule(CachingDataModule):
"""HRF dataset for Vessel Segmentation.
The database includes 15 images of each healthy, diabetic retinopathy (DR), and
glaucomatous eyes. It contains a total of 45 eye fundus images with a
resolution of 3304 x 2336. One set of ground-truth vessel annotations is
available.
* Reference: [HRF-2013]_
* Original resolution (height x width): 2336 x 3504
* Configuration resolution: 1168 x 1648 (after specific cropping and rescaling)
* Split reference: [ORLANDO-2017]_
* Protocol ``default``:
* Training samples: 15 (including labels)
* Test samples: 30 (including labels)
Parameters
----------
split_filename
Name of the .json file containing the split to load.
"""
def __init__(self, split_filename: str):
assert __package__ is not None
super().__init__(
database_split=make_split(__package__, split_filename),
raw_data_loader=SegmentationRawDataLoader(),
database_name=__package__.rsplit(".", 1)[1],
split_name=pathlib.Path(split_filename).stem,
)
{
"train": [
[
"images/01_dr.JPG",
"manual1/01_dr.tif",
"mask/01_dr_mask.tif"
],
[
"images/02_dr.JPG",
"manual1/02_dr.tif",
"mask/02_dr_mask.tif"
],
[
"images/03_dr.JPG",
"manual1/03_dr.tif",
"mask/03_dr_mask.tif"
],
[
"images/04_dr.JPG",
"manual1/04_dr.tif",
"mask/04_dr_mask.tif"
],
[
"images/05_dr.JPG",
"manual1/05_dr.tif",
"mask/05_dr_mask.tif"
],
[
"images/01_g.jpg",
"manual1/01_g.tif",
"mask/01_g_mask.tif"
],
[
"images/02_g.jpg",
"manual1/02_g.tif",
"mask/02_g_mask.tif"
],
[
"images/03_g.jpg",
"manual1/03_g.tif",
"mask/03_g_mask.tif"
],
[
"images/04_g.jpg",
"manual1/04_g.tif",
"mask/04_g_mask.tif"
],
[
"images/05_g.jpg",
"manual1/05_g.tif",
"mask/05_g_mask.tif"
],
[
"images/01_h.jpg",
"manual1/01_h.tif",
"mask/01_h_mask.tif"
],
[
"images/02_h.jpg",
"manual1/02_h.tif",
"mask/02_h_mask.tif"
],
[
"images/03_h.jpg",
"manual1/03_h.tif",
"mask/03_h_mask.tif"
],
[
"images/04_h.jpg",
"manual1/04_h.tif",
"mask/04_h_mask.tif"
],
[
"images/05_h.jpg",
"manual1/05_h.tif",
"mask/05_h_mask.tif"
]
],
"test": [
[
"images/06_dr.JPG",
"manual1/06_dr.tif",
"mask/06_dr_mask.tif"
],
[
"images/07_dr.JPG",
"manual1/07_dr.tif",
"mask/07_dr_mask.tif"
],
[
"images/08_dr.JPG",
"manual1/08_dr.tif",
"mask/08_dr_mask.tif"
],
[
"images/09_dr.JPG",
"manual1/09_dr.tif",
"mask/09_dr_mask.tif"
],
[
"images/10_dr.JPG",
"manual1/10_dr.tif",
"mask/10_dr_mask.tif"
],
[
"images/11_dr.JPG",
"manual1/11_dr.tif",
"mask/11_dr_mask.tif"
],
[
"images/12_dr.JPG",
"manual1/12_dr.tif",
"mask/12_dr_mask.tif"
],
[
"images/13_dr.JPG",
"manual1/13_dr.tif",
"mask/13_dr_mask.tif"
],
[
"images/14_dr.JPG",
"manual1/14_dr.tif",
"mask/14_dr_mask.tif"
],
[
"images/15_dr.JPG",
"manual1/15_dr.tif",
"mask/15_dr_mask.tif"
],
[
"images/06_g.jpg",
"manual1/06_g.tif",
"mask/06_g_mask.tif"
],
[
"images/07_g.jpg",
"manual1/07_g.tif",
"mask/07_g_mask.tif"
],
[
"images/08_g.jpg",
"manual1/08_g.tif",
"mask/08_g_mask.tif"
],
[
"images/09_g.jpg",
"manual1/09_g.tif",
"mask/09_g_mask.tif"
],
[
"images/10_g.jpg",
"manual1/10_g.tif",
"mask/10_g_mask.tif"
],
[
"images/11_g.jpg",
"manual1/11_g.tif",
"mask/11_g_mask.tif"
],
[
"images/12_g.jpg",
"manual1/12_g.tif",
"mask/12_g_mask.tif"
],
[
"images/13_g.jpg",
"manual1/13_g.tif",
"mask/13_g_mask.tif"
],
[
"images/14_g.jpg",
"manual1/14_g.tif",
"mask/14_g_mask.tif"
],
[
"images/15_g.jpg",
"manual1/15_g.tif",
"mask/15_g_mask.tif"
],
[
"images/06_h.jpg",
"manual1/06_h.tif",
"mask/06_h_mask.tif"
],
[
"images/07_h.jpg",
"manual1/07_h.tif",
"mask/07_h_mask.tif"
],
[
"images/08_h.jpg",
"manual1/08_h.tif",
"mask/08_h_mask.tif"
],
[
"images/09_h.jpg",
"manual1/09_h.tif",
"mask/09_h_mask.tif"
],
[
"images/10_h.jpg",
"manual1/10_h.tif",
"mask/10_h_mask.tif"
],
[
"images/11_h.jpg",
"manual1/11_h.tif",
"mask/11_h_mask.tif"
],
[
"images/12_h.jpg",
"manual1/12_h.tif",
"mask/12_h_mask.tif"
],
[
"images/13_h.jpg",
"manual1/13_h.tif",
"mask/13_h_mask.tif"
],
[
"images/14_h.jpg",
"manual1/14_h.tif",
"mask/14_h_mask.tif"
],
[
"images/15_h.jpg",
"manual1/15_h.tif",
"mask/15_h_mask.tif"
]
]
}
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""HRF dataset for Vessel Segmentation (default protocol).
* Split reference: [ORLANDO-2017]_
* Configuration resolution: 1168 x 1648 (about half full HRF resolution)
* See :py:mod:`deepdraw.data.hrf` for dataset details
"""
from mednet.libs.segmentation.config.data.hrf.datamodule import (
DataModule,
)
datamodule = DataModule("default.json")
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for hrf dataset."""
import importlib
import pytest
from click.testing import CliRunner
def id_function(val):
if isinstance(val, dict):
return str(val)
return repr(val)
@pytest.mark.parametrize(
"split,lengths",
[
("default", dict(train=15, test=30)),
],
ids=id_function, # just changes how pytest prints it
)
def test_protocol_consistency(
database_checkers,
split: str,
lengths: dict[str, int],
):
from mednet.libs.common.data.split import make_split
database_checkers.check_split(
make_split("mednet.libs.segmentation.config.data.hrf", f"{split}.json"),
lengths=lengths,
)
@pytest.mark.skip_if_rc_var_not_set("datadir.hrf")
def test_database_check():
from mednet.libs.segmentation.scripts.database import check
runner = CliRunner()
result = runner.invoke(check, ["hrf"])
assert (
result.exit_code == 0
), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
@pytest.mark.skip_if_rc_var_not_set("datadir.hrf")
@pytest.mark.parametrize(
"dataset",
[
"train",
"test",
],
)
@pytest.mark.parametrize(
"name",
[
"default",
],
)
def test_loading(database_checkers, name: str, dataset: str):
datamodule = importlib.import_module(
f".{name}",
"mednet.libs.segmentation.config.data.hrf",
).datamodule
datamodule.model_transforms = [] # should be done before setup()
datamodule.setup("predict") # sets up all datasets
loader = datamodule.predict_dataloader()[dataset]
limit = 3 # limit load checking
for batch in loader:
if limit == 0:
break
database_checkers.check_loaded_batch(
batch,
batch_size=1,
color_planes=3,
expected_num_targets=1,
)
limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.hrf")
def test_raw_transforms_image_quality(database_checkers, datadir):
reference_histogram_file = str(
datadir / "histograms/raw_data/histograms_hrf_default.json",
)
datamodule = importlib.import_module(
".default",
"mednet.libs.segmentation.config.data.hrf",
).datamodule
datamodule.model_transforms = []
datamodule.setup("predict")
database_checkers.check_image_quality(datamodule, reference_histogram_file)
@pytest.mark.skip_if_rc_var_not_set("datadir.hrf")
@pytest.mark.parametrize(
"model_name",
["lwnet"],
)
def test_model_transforms_image_quality(database_checkers, datadir, model_name):
reference_histogram_file = str(
datadir / f"histograms/models/histograms_{model_name}_hrf_default.json",
)
datamodule = importlib.import_module(
".default",
"mednet.libs.segmentation.config.data.hrf",
).datamodule
model = importlib.import_module(
f".{model_name}",
"mednet.libs.segmentation.config.models",
).model
datamodule.model_transforms = model.model_transforms
datamodule.setup("predict")
database_checkers.check_image_quality(
datamodule,
reference_histogram_file,
compare_type="statistical",
pearson_coeff_threshold=0.005,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment