Skip to content
Snippets Groups Projects
datamodule.py 4.74 KiB
Newer Older
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import importlib.resources
import os

import PIL.Image

from torchvision.transforms.functional import to_tensor

from ...utils.rc import load_rc
from ..datamodule import CachingDataModule
from ..split import JSONDatabaseSplit
from ..typing import DatabaseSplit
from ..typing import RawDataLoader as _BaseRawDataLoader
from ..typing import Sample


class RawDataLoader(_BaseRawDataLoader):
    """A specialized raw-data-loader for the NIH CXR-14 dataset.

    Attributes
    ----------

    datadir
        This variable contains the base directory where the database raw data
        is stored.  It is read from the configuration key
        ``datadir.nih_cxr14`` and defaults to the current working directory.

    idiap_file_organisation
        This variable will be ``True``, if the user has set the configuration
        parameter ``nih_cxr14.idiap_folder_structure`` in the global
        configuration file.  It will cause internal loader to search for files
        in a slightly different folder structure, that was adapted to Idiap's
        requirements (number of files per folder to be less than 10k).
    """

    datadir: str
    idiap_file_organisation: bool

    def __init__(self):
        rc = load_rc()
        self.datadir = rc.get("datadir.nih_cxr14", os.path.realpath(os.curdir))
        # N.B.: the configuration key is ``idiap_folder_structure`` while the
        # attribute is named ``idiap_file_organisation``.
        self.idiap_file_organisation = rc.get(
            "nih_cxr14.idiap_folder_structure", False
        )

    def sample(self, sample: tuple[str, list[int]]) -> Sample:
        """Loads a single image sample from the disk.

        Parameters
        ----------

        sample:
            A tuple containing the path suffix, within the dataset root folder,
            where to find the image to be loaded, and a list of integers,
            representing the sample labels.


        Returns
        -------

        sample
            The sample representation: the image tensor and a dictionary with
            the keys ``label`` (the labels from the input tuple) and ``name``
            (the original, unmodified path suffix).
        """
        file_path = sample[0]  # default
        if self.idiap_file_organisation:
            # for folder lookup efficiency, data is split into subfolders
            # each original file is on the subfolder `f[:5]/f`, where f
            # is the original file basename
            basename = os.path.basename(sample[0])
            file_path = os.path.join(
                os.path.dirname(sample[0]),
                basename[:5],
                basename,
            )

        # N.B.: NIH CXR-14 images are encoded as color PNGs
        # PIL opens files lazily and keeps the handle open until the image is
        # closed; the context manager releases it once the pixel data has been
        # converted into a tensor, so iterating the dataset does not leak
        # file descriptors.
        with PIL.Image.open(os.path.join(self.datadir, file_path)) as image:
            tensor = to_tensor(image)

        # use the code below to view generated images
        # from torchvision.transforms.functional import to_pil_image
        # to_pil_image(tensor).show()
        # __import__("pdb").set_trace()

        return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]

    def label(self, sample: tuple[str, list[int]]) -> list[int]:
        """Loads a single image sample label from the disk.

        Parameters
        ----------

        sample:
            A tuple containing the path suffix, within the dataset root folder,
            where to find the image to be loaded, and a list of integers,
            representing the sample labels.


        Returns
        -------

        labels
            The integer labels associated with the sample (the second element
            of ``sample``, returned as-is).
        """
        return sample[1]


def make_split(basename: str) -> DatabaseSplit:
    """Returns a database split for the NIH CXR-14 database.

    Parameters
    ----------

    basename
        Name of the JSON split file, looked up as a resource of the package
        that contains this module.


    Returns
    -------

    split
        A :py:class:`JSONDatabaseSplit` built from that resource.
    """

    # strip the last dotted component (this module) to get its package
    package_name = __name__.rsplit(".", maxsplit=1)[0]
    split_resource = importlib.resources.files(package_name).joinpath(basename)
    return JSONDatabaseSplit(split_resource)


class DataModule(CachingDataModule):
    """NIH CXR14 (relabeled) datamodule for computer-aided diagnosis.

    This dataset was extracted from the clinical PACS database at the National
    Institutes of Health Clinical Center (USA) and represents 60% of all their
    radiographs. It contains labels for 14 common radiological signs in this
    order: cardiomegaly, emphysema, effusion, hernia, infiltration, mass,
    nodule, atelectasis, pneumothorax, pleural thickening, pneumonia, fibrosis,
    edema and consolidation. This is the relabeled version created in the
    CheXNeXt study.

    * Reference: [NIH-CXR14-2017]_
    * Original resolution (height x width): 1024 x 1024
    * Labels: [CHEXNEXT-2018]_
    * Split reference: [CHEXNEXT-2018]_
    * Protocol ``default``:

      * Training samples: 98637
      * Validation samples: 6350
      * Test samples: 4355

    * Output image:

      * Transforms:

        * Load raw PNG with :py:mod:`PIL`

      * Final specifications

        * RGB, encoded as a 3-plane image, 8 bits
        * Square (1024x1024 px)

    Parameters
    ----------

    split_filename
        Name of the JSON file, shipped as a resource of this package, that
        defines the database split to load.
    """

    def __init__(self, split_filename: str):
        # build the split first, then the loader (same order as the keyword
        # arguments were evaluated before)
        split = make_split(split_filename)
        loader = RawDataLoader()
        super().__init__(database_split=split, raw_data_loader=loader)