diff --git a/src/ptbench/data/montgomery/loader.py b/src/ptbench/data/montgomery/loader.py
index 45733dadde59eac18c43be953f54ed15bc89db47..ad856d5fcc45603015cf75c1c87885751f25bcd8 100644
--- a/src/ptbench/data/montgomery/loader.py
+++ b/src/ptbench/data/montgomery/loader.py
@@ -2,37 +2,16 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Montgomery datamodule for TB detection (default protocol)
-
-* See :py:mod:`ptbench.data.montgomery` for more database details.
-
-This configuration:
-
-* Raw data input (on disk):
-
-  * PNG images 8 bit grayscale
-  * resolution: 4020 x 4892 px or 4892 x 4020 px
-
-* Output image:
-
-  * Transforms:
-
-    * Load raw PNG with :py:mod:`PIL`
-    * Remove black borders
-    * Torch center cropping to get square image
-
-  * Final specifications
-
-    * Grayscale (single channel), 8 bits
-    * Varying resolutions
-"""
+"""Specialized raw-data loaders for the Montgomery dataset."""
 
 import os
 
+import PIL.Image
+
 from torchvision.transforms.functional import center_crop, to_tensor
 
 from ...utils.rc import load_rc
-from ..image_utils import load_pil, remove_black_borders
+from ..image_utils import remove_black_borders
 from ..typing import RawDataLoader as _BaseRawDataLoader
 from ..typing import Sample
 
@@ -73,10 +52,17 @@ class RawDataLoader(_BaseRawDataLoader):
         sample
             The sample representation
         """
-        tensor = load_pil(os.path.join(self.datadir, sample[0]))
-        tensor = remove_black_borders(tensor)
-        tensor = center_crop(tensor, min(tensor.size[0], tensor.size[1]))
-        tensor = to_tensor(tensor)
+        # N.B.: Montgomery images are encoded as grayscale PNGs, so no need to
+        # convert them again with Image.convert("L").
+        image = PIL.Image.open(os.path.join(self.datadir, sample[0]))
+        image = remove_black_borders(image)
+        tensor = to_tensor(image)
+        tensor = center_crop(tensor, min(*tensor.shape[1:]))
+
+        # use the code below to view generated images
+        # from torchvision.transforms.functional import to_pil_image
+        # to_pil_image(tensor).show()
+        # __import__("pdb").set_trace()
 
         return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]
diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/loader.py
index df4d4afbb0e4afa22522b7ebc403dc91372fefcd..3409fed2e1a552c44135888df6d6bc4a874b427c 100644
--- a/src/ptbench/data/shenzhen/loader.py
+++ b/src/ptbench/data/shenzhen/loader.py
@@ -2,30 +2,16 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Shenzhen dataset for computer-aided diagnosis.
-
-The standard digital image database for Tuberculosis is created by the National
-Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s
-Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from
-out-patient clinics, and were captured as part of the daily routine using
-Philips DR Digital Diagnose systems.
-
-* Reference: [MONTGOMERY-SHENZHEN-2014]_
-* Original resolution (height x width or width x height): 3000 x 3000 or less
-* Split reference: none
-* Protocol ``default``:
-
-  * Training samples: 64% of TB and healthy CXR (including labels)
-  * Validation samples: 16% of TB and healthy CXR (including labels)
-  * Test samples: 20% of TB and healthy CXR (including labels)
-"""
+"""Specialized raw-data loaders for the Shenzhen dataset."""
 
 import os
 
+import PIL.Image
+
 from torchvision.transforms.functional import center_crop, to_tensor
 
 from ...utils.rc import load_rc
-from ..image_utils import load_pil_rgb, remove_black_borders
+from ..image_utils import remove_black_borders
 from ..typing import RawDataLoader as _BaseRawDataLoader
 from ..typing import Sample
 
@@ -69,10 +55,19 @@ class RawDataLoader(_BaseRawDataLoader):
         sample
             The sample representation
         """
-        tensor = load_pil_rgb(os.path.join(self.datadir, sample[0]))
-        tensor = remove_black_borders(tensor)
-        tensor = center_crop(tensor, min(tensor.size[0], tensor.size[1]))
-        tensor = to_tensor(tensor)
+        # N.B.: Image.convert("L") is required to normalize grayscale back to
+        # normal (instead of inverted).
+        image = PIL.Image.open(os.path.join(self.datadir, sample[0])).convert(
+            "L"
+        )
+        image = remove_black_borders(image)
+        tensor = to_tensor(image)
+        tensor = center_crop(tensor, min(*tensor.shape[1:]))
+
+        # use the code below to view generated images
+        # from torchvision.transforms.functional import to_pil_image
+        # to_pil_image(tensor).show()
+        # __import__("pdb").set_trace()
 
         return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]
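For reference, both loaders now perform the same three preprocessing steps: decode the PNG with PIL, trim the black borders, then square-crop the resulting tensor. The snippet below is a minimal standalone sketch of that pipeline, assuming a hypothetical grayscale chest X-ray at some_cxr.png and reusing remove_black_borders from ptbench.data.image_utils; it is an illustration, not part of the patch.

# Standalone sketch of the preprocessing implemented by both loaders.
# "some_cxr.png" is a hypothetical input; convert("L") mirrors the Shenzhen
# loader, while Montgomery PNGs are already grayscale and can skip it.
import PIL.Image
from torchvision.transforms.functional import center_crop, to_tensor

from ptbench.data.image_utils import remove_black_borders

path = "some_cxr.png"  # hypothetical grayscale chest X-ray

image = PIL.Image.open(path).convert("L")  # decode, force plain grayscale
image = remove_black_borders(image)        # trim the black frame around the CXR
tensor = to_tensor(image)                  # (1, H, W) float tensor in [0, 1]
tensor = center_crop(tensor, min(*tensor.shape[1:]))  # centered square crop

print(tensor.shape)  # square spatial dimensions, e.g. torch.Size([1, S, S])

Cropping after to_tensor (rather than on the PIL image, as before) keeps the whole pipeline in tensor space once decoding and border removal are done, which is why min() is now taken over tensor.shape[1:] instead of the PIL size tuple.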