From 50fce2cda942a97dfbbe4eaf9ff401af78736ef5 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Thu, 20 Jul 2023 23:26:54 +0200 Subject: [PATCH] [data.montgomery/shenzhen.loader] Simplify loaders, adjust some variable names, comments and add some (commented out) test code to visualize generated images --- src/ptbench/data/montgomery/loader.py | 44 +++++++++------------------ src/ptbench/data/shenzhen/loader.py | 39 +++++++++++------------- 2 files changed, 32 insertions(+), 51 deletions(-) diff --git a/src/ptbench/data/montgomery/loader.py b/src/ptbench/data/montgomery/loader.py index 45733dad..ad856d5f 100644 --- a/src/ptbench/data/montgomery/loader.py +++ b/src/ptbench/data/montgomery/loader.py @@ -2,37 +2,16 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""Montgomery datamodule for TB detection (default protocol) - -* See :py:mod:`ptbench.data.montgomery` for more database details. - -This configuration: - -* Raw data input (on disk): - - * PNG images 8 bit grayscale - * resolution: 4020 x 4892 px or 4892 x 4020 px - -* Output image: - - * Transforms: - - * Load raw PNG with :py:mod:`PIL` - * Remove black borders - * Torch center cropping to get square image - - * Final specifications - - * Grayscale (single channel), 8 bits - * Varying resolutions -""" +"""Specialized raw-data loaders for the Montgomery dataset.""" import os +import PIL.Image + from torchvision.transforms.functional import center_crop, to_tensor from ...utils.rc import load_rc -from ..image_utils import load_pil, remove_black_borders +from ..image_utils import remove_black_borders from ..typing import RawDataLoader as _BaseRawDataLoader from ..typing import Sample @@ -73,10 +52,17 @@ class RawDataLoader(_BaseRawDataLoader): sample The sample representation """ - tensor = load_pil(os.path.join(self.datadir, sample[0])) - tensor = remove_black_borders(tensor) - tensor = center_crop(tensor, min(tensor.size[0], tensor.size[1])) - tensor = to_tensor(tensor) + # N.B.: 
Montgomery images are encoded as grayscale PNGs, so no need to + # convert them again with Image.convert("L"). + image = PIL.Image.open(os.path.join(self.datadir, sample[0])) + image = remove_black_borders(image) + tensor = to_tensor(image) + tensor = center_crop(tensor, min(*tensor.shape[1:])) + + # use the code below to view generated images + # from torchvision.transforms.functional import to_pil_image + # to_pil_image(tensor).show() + # __import__("pdb").set_trace() return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/loader.py index df4d4afb..3409fed2 100644 --- a/src/ptbench/data/shenzhen/loader.py +++ b/src/ptbench/data/shenzhen/loader.py @@ -2,30 +2,16 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen dataset for computer-aided diagnosis. - -The standard digital image database for Tuberculosis is created by the National -Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s -Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from -out-patient clinics, and were captured as part of the daily routine using -Philips DR Digital Diagnose systems. 
- -* Reference: [MONTGOMERY-SHENZHEN-2014]_ -* Original resolution (height x width or width x height): 3000 x 3000 or less -* Split reference: none -* Protocol ``default``: - - * Training samples: 64% of TB and healthy CXR (including labels) - * Validation samples: 16% of TB and healthy CXR (including labels) - * Test samples: 20% of TB and healthy CXR (including labels) -""" +"""Specialized raw-data loaders for the Shenzhen dataset.""" import os +import PIL.Image + from torchvision.transforms.functional import center_crop, to_tensor from ...utils.rc import load_rc -from ..image_utils import load_pil_rgb, remove_black_borders +from ..image_utils import remove_black_borders from ..typing import RawDataLoader as _BaseRawDataLoader from ..typing import Sample @@ -69,10 +55,19 @@ class RawDataLoader(_BaseRawDataLoader): sample The sample representation """ - tensor = load_pil_rgb(os.path.join(self.datadir, sample[0])) - tensor = remove_black_borders(tensor) - tensor = center_crop(tensor, min(tensor.size[0], tensor.size[1])) - tensor = to_tensor(tensor) + # N.B.: Image.convert("L") is required to normalize grayscale back to + # normal (instead of inverted). + image = PIL.Image.open(os.path.join(self.datadir, sample[0])).convert( + "L" + ) + image = remove_black_borders(image) + tensor = to_tensor(image) + tensor = center_crop(tensor, min(*tensor.shape[1:])) + + # use the code below to view generated images + # from torchvision.transforms.functional import to_pil_image + # to_pil_image(tensor).show() + # __import__("pdb").set_trace() return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] -- GitLab