Skip to content
Snippets Groups Projects
Commit 50fce2cd authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[data.montgomery/shenzhen.loader] Simplify loaders, adjust some variable...

[data.montgomery/shenzhen.loader] Simplify loaders, adjust some variable names, comments and add some (commented out) test code to visualize generated images
parent adf82385
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
......@@ -2,37 +2,16 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Montgomery datamodule for TB detection (default protocol)
* See :py:mod:`ptbench.data.montgomery` for more database details.
This configuration:
* Raw data input (on disk):
* PNG images 8 bit grayscale
* resolution: 4020 x 4892 px or 4892 x 4020 px
* Output image:
* Transforms:
* Load raw PNG with :py:mod:`PIL`
* Remove black borders
* Torch center cropping to get square image
* Final specifications
* Grayscale (single channel), 8 bits
* Varying resolutions
"""
"""Specialized raw-data loaders for the Montgomery dataset."""
import os
import PIL.Image
from torchvision.transforms.functional import center_crop, to_tensor
from ...utils.rc import load_rc
from ..image_utils import load_pil, remove_black_borders
from ..image_utils import remove_black_borders
from ..typing import RawDataLoader as _BaseRawDataLoader
from ..typing import Sample
......@@ -73,10 +52,17 @@ class RawDataLoader(_BaseRawDataLoader):
sample
The sample representation
"""
tensor = load_pil(os.path.join(self.datadir, sample[0]))
tensor = remove_black_borders(tensor)
tensor = center_crop(tensor, min(tensor.size[0], tensor.size[1]))
tensor = to_tensor(tensor)
# N.B.: Montgomery images are encoded as grayscale PNGs, so no need to
# convert them again with Image.convert("L").
image = PIL.Image.open(os.path.join(self.datadir, sample[0]))
image = remove_black_borders(image)
tensor = to_tensor(image)
tensor = center_crop(tensor, min(*tensor.shape[1:]))
# use the code below to view generated images
# from torchvision.transforms.functional import to_pil_image
# to_pil_image(tensor).show()
# __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type]
......
......@@ -2,30 +2,16 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for computer-aided diagnosis.
The standard digital image database for Tuberculosis is created by the National
Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s
Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from
out-patient clinics, and were captured as part of the daily routine using
Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels)
"""
"""Specialized raw-data loaders for the Shenzen dataset."""
import os
import PIL.Image
from torchvision.transforms.functional import center_crop, to_tensor
from ...utils.rc import load_rc
from ..image_utils import load_pil_rgb, remove_black_borders
from ..image_utils import remove_black_borders
from ..typing import RawDataLoader as _BaseRawDataLoader
from ..typing import Sample
......@@ -69,10 +55,19 @@ class RawDataLoader(_BaseRawDataLoader):
sample
The sample representation
"""
tensor = load_pil_rgb(os.path.join(self.datadir, sample[0]))
tensor = remove_black_borders(tensor)
tensor = center_crop(tensor, min(tensor.size[0], tensor.size[1]))
tensor = to_tensor(tensor)
# N.B.: Image.convert("L") is required to normalize grayscale back to
# normal (instead of inverted).
image = PIL.Image.open(os.path.join(self.datadir, sample[0])).convert(
"L"
)
image = remove_black_borders(image)
tensor = to_tensor(image)
tensor = center_crop(tensor, min(*tensor.shape[1:]))
# use the code below to view generated images
# from torchvision.transforms.functional import to_pil_image
# to_pil_image(tensor).show()
# __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment