Skip to content
Snippets Groups Projects
Commit 7be6a4ee authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[data.nih_cxr14_re] Update datamodule; Prepare framework for multi-class classification

parent fb3ccf06
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -106,7 +106,7 @@ class _DelayedLoadingDataset(Dataset): ...@@ -106,7 +106,7 @@ class _DelayedLoadingDataset(Dataset):
sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0) sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb") logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
def labels(self) -> list[int]: def labels(self) -> list[int | list[int]]:
"""Returns the integer labels for all samples in the dataset.""" """Returns the integer labels for all samples in the dataset."""
return [self.loader.label(k) for k in self.raw_dataset] return [self.loader.label(k) for k in self.raw_dataset]
...@@ -223,7 +223,7 @@ class _CachedDataset(Dataset): ...@@ -223,7 +223,7 @@ class _CachedDataset(Dataset):
f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb" f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb"
) )
def labels(self) -> list[int]: def labels(self) -> list[int | list[int]]:
"""Returns the integer labels for all samples in the dataset.""" """Returns the integer labels for all samples in the dataset."""
return [k[1]["label"] for k in self.data] return [k[1]["label"] for k in self.data]
...@@ -256,7 +256,7 @@ class _ConcatDataset(Dataset): ...@@ -256,7 +256,7 @@ class _ConcatDataset(Dataset):
for j in range(len(datasets[i])) for j in range(len(datasets[i]))
] ]
def labels(self) -> list[int]: def labels(self) -> list[int | list[int]]:
"""Returns the integer labels for all samples in the dataset.""" """Returns the integer labels for all samples in the dataset."""
return list(itertools.chain(*[k.labels() for k in self._datasets])) return list(itertools.chain(*[k.labels() for k in self._datasets]))
...@@ -379,11 +379,11 @@ def _make_balanced_random_sampler( ...@@ -379,11 +379,11 @@ def _make_balanced_random_sampler(
for ds in dataset.datasets for ds in dataset.datasets
for k in typing.cast(Dataset, ds).labels() for k in typing.cast(Dataset, ds).labels()
] ]
weights = _calculate_weights(targets) weights = _calculate_weights(targets) # type: ignore
else: else:
logger.warning( logger.warning(
f"Balancing samples **and** concatenated-datasets " f"Balancing samples **and** concatenated-datasets "
f"WITHOUT metadata targets (`{target}` not available)" f"by using dataset totals as `{target}: int` is not true"
) )
weights = [ weights = [
k k
...@@ -403,10 +403,11 @@ def _make_balanced_random_sampler( ...@@ -403,10 +403,11 @@ def _make_balanced_random_sampler(
f"Balancing samples from dataset using metadata " f"Balancing samples from dataset using metadata "
f"targets `{target}`" f"targets `{target}`"
) )
weights = _calculate_weights(dataset.labels()) weights = _calculate_weights(dataset.labels()) # type: ignore
else: else:
raise RuntimeError( raise RuntimeError(
f"Cannot balance samples without metadata targets `{target}`" f"Cannot balance samples with multiple class labels "
f"({target}: list[int]) or without metadata targets `{target}`"
) )
return torch.utils.data.WeightedRandomSampler( return torch.utils.data.WeightedRandomSampler(
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""NIH CXR14 (relabeled) dataset for computer-aided diagnosis. """NIH CXR14 (relabeled) dataset for computer-aided diagnosis.
This dataset was extracted from the clinical PACS database at the National This dataset was extracted from the clinical PACS database at the National
......
{
"train": [
["images/00000001_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000001_001.png", [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000001_002.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
["images/00000007_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000010_000.png", [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
["images/00000011_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000011_001.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
["images/00000011_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
["images/00000011_003.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
["images/00000013_011.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
["images/00000013_014.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
["images/00000013_018.png", [1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0]],
["images/00000013_022.png", [1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],
["images/00000013_024.png", [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000013_025.png", [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],
["images/00000013_026.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
["images/00000013_027.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
["images/00000013_028.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
["images/00000013_029.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
["images/00000013_030.png", [1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0]],
["images/00000013_031.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]],
["images/00000013_032.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]],
["images/00000013_034.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
["images/00000013_037.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
["images/00000013_038.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
["images/00000013_040.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
["images/00000013_041.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
["images/00000013_043.png", [1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0]],
["images/00000013_044.png", [1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0]],
["images/00000013_045.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]],
["images/00000013_046.png", [1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000031_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000033_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]],
["images/00000044_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000045_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000046_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]],
["images/00000054_003.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000059_000.png", [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000066_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]],
["images/00000069_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
],
"validation": [
["images/00000001_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000001_001.png", [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000001_002.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
["images/00000007_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000010_000.png", [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
["images/00000011_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000011_001.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
["images/00000011_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
["images/00000011_003.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
["images/00000013_011.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
["images/00000013_014.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
["images/00000013_018.png", [1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0]],
["images/00000013_022.png", [1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],
["images/00000013_024.png", [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000013_025.png", [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],
["images/00000013_026.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
["images/00000013_027.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
["images/00000013_028.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
["images/00000013_029.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
["images/00000013_030.png", [1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0]],
["images/00000013_031.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]],
["images/00000013_032.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]],
["images/00000013_034.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
["images/00000013_037.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
["images/00000013_038.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
["images/00000013_040.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
["images/00000013_041.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
["images/00000013_043.png", [1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0]],
["images/00000013_044.png", [1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0]],
["images/00000013_045.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]],
["images/00000013_046.png", [1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000031_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000033_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]],
["images/00000044_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000045_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000046_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]],
["images/00000054_003.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000059_000.png", [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
["images/00000066_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]],
["images/00000069_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
]
}
File deleted
...@@ -2,47 +2,6 @@ ...@@ -2,47 +2,6 @@
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""NIH CXR14 dataset for computer-aided diagnosis. from .datamodule import DataModule
First 40 images with cardiomegaly. datamodule = DataModule("cardiomegaly.json")
* See :py:mod:`ptbench.data.nih_cxr14_re` for split details
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.nih_cxr14_re` for dataset details
"""
from clapper.logging import setup
from .. import return_subsets
from ..base_datamodule import BaseDataModule
from . import _maker
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class Fold0Module(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str):
self.dataset = _maker("cardiomegaly")
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = Fold0Module
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import importlib.resources
import os
import PIL.Image
from torchvision.transforms.functional import to_tensor
from ...utils.rc import load_rc
from ..datamodule import CachingDataModule
from ..split import JSONDatabaseSplit
from ..typing import DatabaseSplit
from ..typing import RawDataLoader as _BaseRawDataLoader
from ..typing import Sample
class RawDataLoader(_BaseRawDataLoader):
"""A specialized raw-data-loader for the Montgomery dataset.
Attributes
----------
datadir
This variable contains the base directory where the database raw data
is stored.
idiap_file_organisation
This variable will be ``True``, if the user has set the configuration
parameter ``nih_cxr14_re.idiap_file_organisation`` in the global
configuration file. It will cause internal loader to search for files
in a slightly different folder structure, that was adapted to Idiap's
requirements (number of files per folder to be less than 10k).
"""
datadir: str
idiap_file_organisation: bool
def __init__(self):
rc = load_rc()
self.datadir = rc.get(
"datadir.nih_cxr14_re", os.path.realpath(os.curdir)
)
self.idiap_file_organisation = rc.get(
"nih_cxr14_re.idiap_folder_structure", False
)
def sample(self, sample: tuple[str, list[int]]) -> Sample:
"""Loads a single image sample from the disk.
Parameters
----------
sample:
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing the
sample label.
Returns
-------
sample
The sample representation
"""
file_path = sample[0] # default
if self.idiap_file_organisation:
# for folder lookup efficiency, data is split into subfolders
# each original file is on the subfolder `f[:5]/f`, where f
# is the original file basename
basename = os.path.basename(sample[0])
file_path = os.path.join(
os.path.dirname(sample[0]),
basename[:5],
basename,
)
# N.B.: NIH CXR-14 images are encoded as color PNGs
image = PIL.Image.open(os.path.join(self.datadir, file_path))
tensor = to_tensor(image)
# use the code below to view generated images
# from torchvision.transforms.functional import to_pil_image
# to_pil_image(tensor).show()
# __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, list[int]]) -> list[int]:
"""Loads a single image sample label from the disk.
Parameters
----------
sample:
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing the
sample label.
Returns
-------
labels
The integer labels associated with the sample
"""
return sample[1]
def make_split(basename: str) -> DatabaseSplit:
"""Returns a database split for the Montgomery database."""
return JSONDatabaseSplit(
importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
)
class DataModule(CachingDataModule):
"""NIH CXR14 (relabeled) datamodule for computer-aided diagnosis.
This dataset was extracted from the clinical PACS database at the National
Institutes of Health Clinical Center (USA) and represents 60% of all their
radiographs. It contains labels for 14 common radiological signs in this
order: cardiomegaly, emphysema, effusion, hernia, infiltration, mass,
nodule, atelectasis, pneumothorax, pleural thickening, pneumonia, fibrosis,
edema and consolidation. This is the relabeled version created in the
CheXNeXt study.
* Reference: [NIH-CXR14-2017]_
* Original resolution (height x width): 1024 x 1024
* Labels: [CHEXNEXT-2018]_
* Split reference: [CHEXNEXT-2018]_
* Protocol ``default``:
* Training samples: 98637
* Validation samples: 6350
* Test samples: 4355
* Output image:
* Transforms:
* Load raw PNG with :py:mod:`PIL`
* Final specifications
* RGB, encoded as a 3-plane image, 8 bits
* Square (1024x1024 px)
"""
def __init__(self, split_filename: str):
super().__init__(
database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(),
)
...@@ -2,46 +2,6 @@ ...@@ -2,46 +2,6 @@
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""NIH CXR14 (relabeled) dataset for computer-aided diagnosis (default from .datamodule import DataModule
protocol)
* See :py:mod:`ptbench.data.nih_cxr14_re` for split details datamodule = DataModule("default.json.bz2")
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.nih_cxr14_re` for dataset details
"""
from clapper.logging import setup
from .. import return_subsets
from ..base_datamodule import BaseDataModule
from . import _maker
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str):
self.dataset = _maker("default")
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DefaultModule
...@@ -28,7 +28,7 @@ class RawDataLoader: ...@@ -28,7 +28,7 @@ class RawDataLoader:
"""Loads whole samples from media.""" """Loads whole samples from media."""
raise NotImplementedError("You must implement the `sample()` method") raise NotImplementedError("You must implement the `sample()` method")
def label(self, k: typing.Any) -> int: def label(self, k: typing.Any) -> int | list[int]:
"""Loads only sample label from media. """Loads only sample label from media.
If you do not override this implementation, then, by default, If you do not override this implementation, then, by default,
...@@ -79,7 +79,7 @@ class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized): ...@@ -79,7 +79,7 @@ class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized):
provide a dunder len method. provide a dunder len method.
""" """
def labels(self) -> list[int]: def labels(self) -> list[int | list[int]]:
"""Returns the integer labels for all samples in the dataset.""" """Returns the integer labels for all samples in the dataset."""
raise NotImplementedError("You must implement the `labels()` method") raise NotImplementedError("You must implement the `labels()` method")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment