Skip to content
Snippets Groups Projects
Commit 3026dc35 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[segmentation] Add jsrt database

parent 19394b11
No related branches found
No related tags found
1 merge request!46Create common library
......@@ -453,6 +453,9 @@ hrf = "mednet.libs.segmentation.config.data.hrf.default"
iostar-vessel = "mednet.libs.segmentation.config.data.iostar.vessel"
iostar-disc = "mednet.libs.segmentation.config.data.iostar.optic_disc"
# jsrt - cxr
jsrt = "mednet.libs.segmentation.config.data.jsrt.default"
# montgomery county - cxr
montgomery = "mednet.libs.segmentation.config.data.montgomery.default"
......
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Japanese Society of Radiological Technology dataset for Lung Segmentation."""
import os
import pathlib
import numpy as np
import PIL.Image
import skimage.exposure
import torch
from mednet.libs.common.data.datamodule import CachingDataModule
from mednet.libs.common.data.split import make_split
from mednet.libs.common.data.typing import Sample
from mednet.libs.segmentation.data.typing import (
SegmentationRawDataLoader as _SegmentationRawDataLoader,
)
from torchvision import tv_tensors
from torchvision.transforms.functional import to_tensor
from ....utils.rc import load_rc
# The key is "datadir." followed by the second-to-last dotted component of
# this module's ``__name__`` (i.e. the name of the package that directly
# contains this module).
CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
"""Key to search for in the configuration file for the root directory of this
database."""
class SegmentationRawDataLoader(_SegmentationRawDataLoader):
    """A specialized raw-data-loader for the JSRT dataset."""

    datadir: pathlib.Path
    """This variable contains the base directory where the database raw data is
    stored."""

    def __init__(self):
        # Fall back to the current working directory when the configuration
        # file does not define a data directory for this database.
        self.datadir = pathlib.Path(
            load_rc().get(CONFIGURATION_KEY_DATADIR, os.path.realpath(os.curdir))
        )

    def load_pil_raw_12bit_jsrt(self, path: pathlib.Path) -> PIL.Image.Image:
        """Load a raw 12-bit sample stored as big-endian 16-bit integers.

        This method was designed to handle the raw images from the JSRT_
        dataset. It reads the data file, clips and inverts the 12-bit values,
        reduces them to 8 bits, and applies a simple histogram equalization to
        obtain something along the lines of the PNG (unofficial) version
        distributed at `JSRT-Kaggle`_.

        Parameters
        ----------
        path
            The full path leading to the image to be loaded.

        Returns
        -------
            A PIL image in RGB mode, with 2048 x 2048 pixels.
        """

        raw_image = np.fromfile(path, np.dtype(">u2")).reshape(2048, 2048)
        # Anything above the 12-bit range is saturated to the maximum value.
        raw_image[raw_image > 4095] = 4095
        raw_image = 4095 - raw_image  # invert colors
        raw_image = (raw_image >> 4).astype(np.uint8)  # 8-bit uint
        # equalize_hist() returns floats, which are rescaled to 8 bits below.
        raw_image = skimage.exposure.equalize_hist(raw_image)
        return PIL.Image.fromarray((raw_image * 255).astype(np.uint8)).convert("RGB")

    @staticmethod
    def _load_binary_mask(path: pathlib.Path) -> np.ndarray:
        """Load an annotation image as a 2D binary array.

        Parameters
        ----------
        path
            The full path leading to the mask image to be loaded.

        Returns
        -------
            The image converted to PIL mode ``"1"`` (no dithering), as a
            numpy array.
        """

        # The context manager guarantees the file handle is released even if
        # the conversion fails.
        with PIL.Image.open(path) as image:
            return np.asarray(image.convert(mode="1", dither=None))

    def sample(self, sample: tuple[str, str, str]) -> Sample:
        """Load a single image sample from the disk.

        Parameters
        ----------
        sample
            A tuple containing three path suffixes, within the dataset root
            folder: the raw image to be loaded, followed by the two lung
            masks that are merged into a single target.

        Returns
        -------
            The sample representation: the image tensor and a dictionary
            with the keys ``target``, ``mask`` and ``name``.
        """

        image = to_tensor(self.load_pil_raw_12bit_jsrt(self.datadir / sample[0]))

        # Combine left and right lung masks into a single tensor.
        # ``np.logical_or`` is used instead of ``np.ma.mask_or`` because the
        # latter shrinks an all-False result to the scalar ``nomask``, which
        # ``to_tensor`` cannot convert.
        lungs = np.logical_or(
            self._load_binary_mask(self.datadir / sample[1]),
            self._load_binary_mask(self.datadir / sample[2]),
        )

        tensor = tv_tensors.Image(image)
        target = tv_tensors.Image(to_tensor(lungs).float())
        # Every pixel is marked as valid in the mask.
        mask = tv_tensors.Mask(torch.ones_like(target))

        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
class DataModule(CachingDataModule):
    """Japanese Society of Radiological Technology dataset for Lung Segmentation.

    The database includes 154 nodule and 93 non-nodule images, for a total of
    247 images with a resolution of 2048 x 2048. One set of ground-truth lung
    annotations is available.

    * Reference: [JSRT-2000]_
    * Original resolution (height x width): 2048 x 2048
    * Configuration resolution: 1024 x 1024 (after rescaling)
    * Split reference: [GAAL-2020]_
    * Protocol ``default``:

      * Training samples: 172 (including labels)
      * Validation samples: 25 (including labels)
      * Test samples: 50 (including labels)

    Parameters
    ----------
    split_filename
        Name of the .json file containing the split to load.
    """

    def __init__(self, split_filename: str):
        assert __package__ is not None

        # Resolve every constructor argument first, in the same order the
        # keyword arguments are listed below.
        split = make_split(__package__, split_filename)
        loader = SegmentationRawDataLoader()
        # The database is named after the last component of this package.
        name = __package__.rsplit(".", 1)[1]
        stem = pathlib.Path(split_filename).stem

        super().__init__(
            database_split=split,
            raw_data_loader=loader,
            database_name=name,
            split_name=stem,
        )
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Japanese Society of Radiological Technology dataset for Lung Segmentation
(default protocol).
* Split reference: [GAAL-2020]_
* Configuration resolution: 256 x 256
* See :py:mod:`mednet.libs.segmentation.config.data.jsrt.datamodule` for dataset details
"""
from mednet.libs.segmentation.config.data.jsrt.datamodule import (
DataModule,
)
# Datamodule instance exposed by this configuration module, built from the
# ``default.json`` split file.
datamodule = DataModule("default.json")
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for jsrt dataset."""
import importlib
import pytest
from click.testing import CliRunner
def id_function(val):
    """Return the label pytest should display for a parametrized value.

    Dictionaries are rendered with ``str()``; everything else with ``repr()``.
    """
    return str(val) if isinstance(val, dict) else repr(val)
@pytest.mark.parametrize(
    "split,lengths",
    [
        ("default", {"train": 172, "validation": 25, "test": 50}),
    ],
    ids=id_function,  # only affects how pytest labels each case
)
def test_protocol_consistency(
    database_checkers,
    split: str,
    lengths: dict[str, int],
):
    """Check that each split file holds the expected number of samples."""
    from mednet.libs.common.data.split import make_split

    database_split = make_split(
        "mednet.libs.segmentation.config.data.jsrt",
        f"{split}.json",
    )
    database_checkers.check_split(database_split, lengths=lengths)
@pytest.mark.skip_if_rc_var_not_set("datadir.jsrt")
def test_database_check():
    """Run the ``check`` command on the jsrt database and expect success."""
    from mednet.libs.segmentation.scripts.database import check

    outcome = CliRunner().invoke(check, ["jsrt"])
    assert (
        outcome.exit_code == 0
    ), f"Exit code {outcome.exit_code} != 0 -- Output:\n{outcome.output}"
@pytest.mark.skip_if_rc_var_not_set("datadir.jsrt")
@pytest.mark.parametrize(
    "dataset",
    [
        "train",
        "validation",
        "test",
    ],
)
@pytest.mark.parametrize(
    "name",
    [
        "default",
    ],
)
def test_loading(database_checkers, name: str, dataset: str):
    """Load the first few batches of each dataset and validate their layout."""
    config = importlib.import_module(
        f".{name}",
        "mednet.libs.segmentation.config.data.jsrt",
    )
    datamodule = config.datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    max_batches = 3  # limit load checking
    batches = datamodule.predict_dataloader()[dataset]
    for position, batch in enumerate(batches):
        if position == max_batches:
            break
        database_checkers.check_loaded_batch(
            batch,
            batch_size=1,
            color_planes=3,
            expected_num_targets=1,
        )
@pytest.mark.skip_if_rc_var_not_set("datadir.jsrt")
def test_raw_transforms_image_quality(database_checkers, datadir):
    """Compare raw-data images (no model transforms) to reference histograms."""
    config = importlib.import_module(
        ".default",
        "mednet.libs.segmentation.config.data.jsrt",
    )
    datamodule = config.datamodule
    datamodule.model_transforms = []
    datamodule.setup("predict")

    reference_histogram_file = str(
        datadir / "histograms/raw_data/histograms_jsrt_default.json",
    )
    database_checkers.check_image_quality(datamodule, reference_histogram_file)
@pytest.mark.skip_if_rc_var_not_set("datadir.jsrt")
@pytest.mark.parametrize(
    "model_name",
    ["lwnet"],
)
def test_model_transforms_image_quality(database_checkers, datadir, model_name):
    """Compare model-transformed images to the model's reference histograms."""
    reference_histogram_file = str(
        datadir / f"histograms/models/histograms_{model_name}_jsrt_default.json",
    )

    data_config = importlib.import_module(
        ".default",
        "mednet.libs.segmentation.config.data.jsrt",
    )
    datamodule = data_config.datamodule

    model_config = importlib.import_module(
        f".{model_name}",
        "mednet.libs.segmentation.config.models",
    )
    model = model_config.model

    # Apply the transforms of the model under test before setting up.
    datamodule.model_transforms = model.model_transforms
    datamodule.setup("predict")

    database_checkers.check_image_quality(
        datamodule,
        reference_histogram_file,
        compare_type="statistical",
        pearson_coeff_threshold=0.005,
    )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment