Newer
Older

André Anjos
committed
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import importlib.resources
import os
import PIL.Image
from torchvision.transforms.functional import to_tensor
from ...utils.rc import load_rc
from ..datamodule import CachingDataModule
from ..split import JSONDatabaseSplit
from ..typing import DatabaseSplit
from ..typing import RawDataLoader as _BaseRawDataLoader
from ..typing import Sample
class RawDataLoader(_BaseRawDataLoader):
    """A specialized raw-data-loader for the NIH CXR-14 dataset.

    Attributes
    ----------
    datadir
        This variable contains the base directory where the database raw data
        is stored.  It is read from the ``datadir.nih_cxr14`` entry of the
        configuration returned by ``load_rc()``, and defaults to the resolved
        current directory when that entry is not set.
    idiap_file_organisation
        This variable will be ``True``, if the user has set the configuration
        parameter ``nih_cxr14.idiap_folder_structure`` in the global
        configuration file (it defaults to ``False``).  It will cause the
        internal loader to search for files in a slightly different folder
        structure, that was adapted to Idiap's requirements (number of files
        per folder to be less than 10k).
    """

    datadir: str
    idiap_file_organisation: bool

    def __init__(self):
        rc = load_rc()
        self.datadir = rc.get("datadir.nih_cxr14", os.path.realpath(os.curdir))
        self.idiap_file_organisation = rc.get(
            "nih_cxr14.idiap_folder_structure", False
        )

    def sample(self, sample: tuple[str, list[int]]) -> Sample:
        """Load a single image sample from the disk.

        Parameters
        ----------
        sample
            A tuple containing the path suffix, within the dataset root folder,
            where to find the image to be loaded, and a list of integers,
            representing the sample labels.

        Returns
        -------
        sample
            A tuple with the image converted by ``to_tensor``, and a dictionary
            with the keys ``label`` (the input labels, untouched) and ``name``
            (the input path suffix, untouched).
        """
        file_path = sample[0]  # default

        if self.idiap_file_organisation:
            # for folder lookup efficiency, data is split into subfolders
            # each original file is on the subfolder `f[:5]/f`, where f
            # is the original file basename
            basename = os.path.basename(sample[0])
            file_path = os.path.join(
                os.path.dirname(sample[0]),
                basename[:5],
                basename,
            )

        # N.B.: NIH CXR-14 images are encoded as color PNGs
        #
        # ``PIL.Image.open`` is lazy and keeps the underlying file open; the
        # context manager releases the file handle deterministically as soon
        # as ``to_tensor`` has consumed the pixel data.
        with PIL.Image.open(os.path.join(self.datadir, file_path)) as image:
            tensor = to_tensor(image)

        # use the code below to view generated images
        # from torchvision.transforms.functional import to_pil_image
        # to_pil_image(tensor).show()
        # __import__("pdb").set_trace()

        # ``name`` deliberately carries the original path suffix, not the
        # Idiap-specific path that may have been computed above
        return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]

    def label(self, sample: tuple[str, list[int]]) -> list[int]:
        """Return the labels of a single sample, without loading its image.

        Parameters
        ----------
        sample
            A tuple containing the path suffix, within the dataset root folder,
            where to find the image to be loaded, and a list of integers,
            representing the sample labels.

        Returns
        -------
        labels
            The integer labels associated with the sample
        """
        return sample[1]
def make_split(basename: str) -> DatabaseSplit:
    """Build a database split for the NIH CXR-14 database.

    Parameters
    ----------
    basename
        Name of the split file, resolved against the resources of the package
        that contains this module.

    Returns
    -------
    split
        A ``JSONDatabaseSplit`` constructed from the located resource.
    """
    # stripping the last dotted component of this module's name yields the
    # name of its parent package (for a dotted module name)
    package_name = __name__.rsplit(".", 1)[0]
    split_resource = importlib.resources.files(package_name).joinpath(basename)
    return JSONDatabaseSplit(split_resource)
class DataModule(CachingDataModule):
    """NIH CXR14 (relabeled) datamodule for computer-aided diagnosis.

    This dataset was extracted from the clinical PACS database at the National
    Institutes of Health Clinical Center (USA) and represents 60% of all their
    radiographs. It contains labels for 14 common radiological signs in this
    order: cardiomegaly, emphysema, effusion, hernia, infiltration, mass,
    nodule, atelectasis, pneumothorax, pleural thickening, pneumonia, fibrosis,
    edema and consolidation. This is the relabeled version created in the
    CheXNeXt study.

    * Reference: [NIH-CXR14-2017]_
    * Original resolution (height x width): 1024 x 1024
    * Labels: [CHEXNEXT-2018]_
    * Split reference: [CHEXNEXT-2018]_
    * Protocol ``default``:

      * Training samples: 98637
      * Validation samples: 6350
      * Test samples: 4355

    * Output image:

      * Transforms:

        * Load raw PNG with :py:mod:`PIL`

      * Final specifications

        * RGB, encoded as a 3-plane image, 8 bits
        * Square (1024x1024 px)

    Parameters
    ----------
    split_filename
        Name of the split file, handed over to ``make_split`` to build the
        database split used by this datamodule.
    """

    def __init__(self, split_filename: str):
        # the split is built before the loader, then both are forwarded to
        # the parent class by keyword
        split = make_split(split_filename)
        loader = RawDataLoader()
        super().__init__(database_split=split, raw_data_loader=loader)