diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index 2cbc4b848d080de2d4500bd84111efb8905efcfa..9b5eab61b988e65d5b7d199d1cfcfc854cd784be 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -106,7 +106,7 @@ class _DelayedLoadingDataset(Dataset): sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0) logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb") - def labels(self) -> list[int]: + def labels(self) -> list[int | list[int]]: """Returns the integer labels for all samples in the dataset.""" return [self.loader.label(k) for k in self.raw_dataset] @@ -223,7 +223,7 @@ class _CachedDataset(Dataset): f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb" ) - def labels(self) -> list[int]: + def labels(self) -> list[int | list[int]]: """Returns the integer labels for all samples in the dataset.""" return [k[1]["label"] for k in self.data] @@ -256,7 +256,7 @@ class _ConcatDataset(Dataset): for j in range(len(datasets[i])) ] - def labels(self) -> list[int]: + def labels(self) -> list[int | list[int]]: """Returns the integer labels for all samples in the dataset.""" return list(itertools.chain(*[k.labels() for k in self._datasets])) @@ -379,11 +379,11 @@ def _make_balanced_random_sampler( for ds in dataset.datasets for k in typing.cast(Dataset, ds).labels() ] - weights = _calculate_weights(targets) + weights = _calculate_weights(targets) # type: ignore else: logger.warning( f"Balancing samples **and** concatenated-datasets " - f"WITHOUT metadata targets (`{target}` not available)" + f"by using dataset totals as `{target}: int` is not true" ) weights = [ k @@ -403,10 +403,11 @@ def _make_balanced_random_sampler( f"Balancing samples from dataset using metadata " f"targets `{target}`" ) - weights = _calculate_weights(dataset.labels()) + weights = _calculate_weights(dataset.labels()) # type: ignore else: raise RuntimeError( - f"Cannot balance samples without metadata targets `{target}`" + f"Cannot balance samples with multiple class labels " + f"({target}: list[int]) or without metadata targets `{target}`" ) return torch.utils.data.WeightedRandomSampler( diff --git a/src/ptbench/data/nih_cxr14_re/__init__.py b/src/ptbench/data/nih_cxr14_re/__init__.py index 27d1903c5a25a1ccc99520867ad407b3186a3694..b9954cf126eae1670c87296ad86f0ca6f4f9e758 100644 --- a/src/ptbench/data/nih_cxr14_re/__init__.py +++ b/src/ptbench/data/nih_cxr14_re/__init__.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later - """NIH CXR14 (relabeled) dataset for computer-aided diagnosis. This dataset was extracted from the clinical PACS database at the National diff --git a/src/ptbench/data/nih_cxr14_re/cardiomegaly.json b/src/ptbench/data/nih_cxr14_re/cardiomegaly.json new file mode 100644 index 0000000000000000000000000000000000000000..b9af6ad7b85245631f3ed5825f74a8e1ce2654d5 --- /dev/null +++ b/src/ptbench/data/nih_cxr14_re/cardiomegaly.json @@ -0,0 +1,86 @@ +{ + "train": [ + ["images/00000001_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000001_001.png", [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000001_002.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000007_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000010_000.png", [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000011_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000011_001.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000011_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000011_003.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000013_011.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_014.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_018.png", [1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_022.png", [1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000013_024.png", [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000013_025.png", [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000013_026.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_027.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]], + ["images/00000013_028.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]], + ["images/00000013_029.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]], + ["images/00000013_030.png", [1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0]], + ["images/00000013_031.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_032.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_034.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_037.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_038.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_040.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_041.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_043.png", [1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0]], + ["images/00000013_044.png", [1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0]], + ["images/00000013_045.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]], + ["images/00000013_046.png", [1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000031_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000033_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]], + ["images/00000044_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000045_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000046_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]], + ["images/00000054_003.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000059_000.png", [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000066_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]], + ["images/00000069_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + ], + "validation": [ + ["images/00000001_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000001_001.png", [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000001_002.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000007_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000010_000.png", [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000011_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000011_001.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000011_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000011_003.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]], + ["images/00000013_011.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_014.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_018.png", [1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_022.png", [1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000013_024.png", [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000013_025.png", [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000013_026.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_027.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]], + ["images/00000013_028.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]], + ["images/00000013_029.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]], + ["images/00000013_030.png", [1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0]], + ["images/00000013_031.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_032.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_034.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_037.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_038.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]], + ["images/00000013_040.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_041.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]], + ["images/00000013_043.png", [1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0]], + ["images/00000013_044.png", [1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0]], + ["images/00000013_045.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]], + ["images/00000013_046.png", [1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000031_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000033_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]], + ["images/00000044_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000045_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000046_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]], + ["images/00000054_003.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000059_000.png", [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ["images/00000066_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]], + ["images/00000069_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + ] +} diff --git a/src/ptbench/data/nih_cxr14_re/cardiomegaly.json.bz2 b/src/ptbench/data/nih_cxr14_re/cardiomegaly.json.bz2 deleted file mode 100644 index 13b6d810cd7bb430b332476e363970a5a728fa3e..0000000000000000000000000000000000000000 Binary files a/src/ptbench/data/nih_cxr14_re/cardiomegaly.json.bz2 and /dev/null differ diff --git a/src/ptbench/data/nih_cxr14_re/cardiomegaly.py b/src/ptbench/data/nih_cxr14_re/cardiomegaly.py index 1904ebfa60dade4ff59f770da7f1310a099c798b..0715650d7ec5b9394249b885172f2f5af646dd2d 100644 --- a/src/ptbench/data/nih_cxr14_re/cardiomegaly.py +++ b/src/ptbench/data/nih_cxr14_re/cardiomegaly.py @@ -2,47 +2,6 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""NIH CXR14 dataset for computer-aided diagnosis. +from .datamodule import DataModule -First 40 images with cardiomegaly. - -* See :py:mod:`ptbench.data.nih_cxr14_re` for split details -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.nih_cxr14_re` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class Fold0Module(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("cardiomegaly") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = Fold0Module +datamodule = DataModule("cardiomegaly.json") diff --git a/src/ptbench/data/nih_cxr14_re/datamodule.py b/src/ptbench/data/nih_cxr14_re/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..66a4379cf0880817c8035ff62cd3fea1d28af193 --- /dev/null +++ b/src/ptbench/data/nih_cxr14_re/datamodule.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import importlib.resources +import os + +import PIL.Image + +from torchvision.transforms.functional import to_tensor + +from ...utils.rc import load_rc +from ..datamodule import CachingDataModule +from ..split import JSONDatabaseSplit +from ..typing import DatabaseSplit +from ..typing import RawDataLoader as _BaseRawDataLoader +from ..typing import Sample + + +class RawDataLoader(_BaseRawDataLoader): + """A specialized raw-data-loader for the Montgomery dataset. + + Attributes + ---------- + + datadir + This variable contains the base directory where the database raw data + is stored. + + idiap_file_organisation + This variable will be ``True``, if the user has set the configuration + parameter ``nih_cxr14_re.idiap_file_organisation`` in the global + configuration file. It will cause internal loader to search for files + in a slightly different folder structure, that was adapted to Idiap's + requirements (number of files per folder to be less than 10k). + """ + + datadir: str + idiap_file_organisation: bool + + def __init__(self): + rc = load_rc() + self.datadir = rc.get( + "datadir.nih_cxr14_re", os.path.realpath(os.curdir) + ) + self.idiap_file_organisation = rc.get( + "nih_cxr14_re.idiap_folder_structure", False + ) + + def sample(self, sample: tuple[str, list[int]]) -> Sample: + """Loads a single image sample from the disk. + + Parameters + ---------- + + sample: + A tuple containing the path suffix, within the dataset root folder, + where to find the image to be loaded, and an integer, representing the + sample label. + + + Returns + ------- + + sample + The sample representation + """ + file_path = sample[0] # default + if self.idiap_file_organisation: + # for folder lookup efficiency, data is split into subfolders + # each original file is on the subfolder `f[:5]/f`, where f + # is the original file basename + basename = os.path.basename(sample[0]) + file_path = os.path.join( + os.path.dirname(sample[0]), + basename[:5], + basename, + ) + + # N.B.: NIH CXR-14 images are encoded as color PNGs + image = PIL.Image.open(os.path.join(self.datadir, file_path)) + tensor = to_tensor(image) + + # use the code below to view generated images + # from torchvision.transforms.functional import to_pil_image + # to_pil_image(tensor).show() + # __import__("pdb").set_trace() + + return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] + + def label(self, sample: tuple[str, list[int]]) -> list[int]: + """Loads a single image sample label from the disk. + + Parameters + ---------- + + sample: + A tuple containing the path suffix, within the dataset root folder, + where to find the image to be loaded, and an integer, representing the + sample label. + + + Returns + ------- + + labels + The integer labels associated with the sample + """ + return sample[1] + + +def make_split(basename: str) -> DatabaseSplit: + """Returns a database split for the Montgomery database.""" + + return JSONDatabaseSplit( + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) + ) + + +class DataModule(CachingDataModule): + """NIH CXR14 (relabeled) datamodule for computer-aided diagnosis. + + This dataset was extracted from the clinical PACS database at the National + Institutes of Health Clinical Center (USA) and represents 60% of all their + radiographs. It contains labels for 14 common radiological signs in this + order: cardiomegaly, emphysema, effusion, hernia, infiltration, mass, + nodule, atelectasis, pneumothorax, pleural thickening, pneumonia, fibrosis, + edema and consolidation. This is the relabeled version created in the + CheXNeXt study. + + * Reference: [NIH-CXR14-2017]_ + * Original resolution (height x width): 1024 x 1024 + * Labels: [CHEXNEXT-2018]_ + * Split reference: [CHEXNEXT-2018]_ + * Protocol ``default``: + + * Training samples: 98637 + * Validation samples: 6350 + * Test samples: 4355 + + * Output image: + + * Transforms: + + * Load raw PNG with :py:mod:`PIL` + + * Final specifications + + * RGB, encoded as a 3-plane image, 8 bits + * Square (1024x1024 px) + """ + + def __init__(self, split_filename: str): + super().__init__( + database_split=make_split(split_filename), + raw_data_loader=RawDataLoader(), + ) diff --git a/src/ptbench/data/nih_cxr14_re/default.py b/src/ptbench/data/nih_cxr14_re/default.py index 0ea6ef5acc55560ae8db115f3585b40da3cf58b8..7fe993a981c86c0161327d1ddb4498e08a90313c 100644 --- a/src/ptbench/data/nih_cxr14_re/default.py +++ b/src/ptbench/data/nih_cxr14_re/default.py @@ -2,46 +2,6 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""NIH CXR14 (relabeled) dataset for computer-aided diagnosis (default -protocol) +from .datamodule import DataModule -* See :py:mod:`ptbench.data.nih_cxr14_re` for split details -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.nih_cxr14_re` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("default") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule +datamodule = DataModule("default.json.bz2") diff --git a/src/ptbench/data/typing.py b/src/ptbench/data/typing.py index bf821068eee6d6a724150a5c56e9b9ad374374da..6f41b39eb33d2a91c51008623388bc2900032665 100644 --- a/src/ptbench/data/typing.py +++ b/src/ptbench/data/typing.py @@ -28,7 +28,7 @@ class RawDataLoader: """Loads whole samples from media.""" raise NotImplementedError("You must implement the `sample()` method") - def label(self, k: typing.Any) -> int: + def label(self, k: typing.Any) -> int | list[int]: """Loads only sample label from media. If you do not override this implementation, then, by default, @@ -79,7 +79,7 @@ class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized): provide a dunder len method. """ - def labels(self) -> list[int]: + def labels(self) -> list[int | list[int]]: """Returns the integer labels for all samples in the dataset.""" raise NotImplementedError("You must implement the `labels()` method")