Skip to content
Snippets Groups Projects

Making use of LightningDataModule and simplification of data loading

Merged Daniel CARRON requested to merge add-datamodule into main
Compare and Show latest version
60 files
+ 854
2065
Compare changes
  • Side-by-side
  • Inline
Files
60
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Montgomery dataset for computer-aided diagnosis.
The Montgomery database has been established to foster research
in computer-aided diagnosis of pulmonary diseases with a special
focus on pulmonary tuberculosis (TB).

* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 4020 x 4892
* Split reference: none
* Protocol ``default``:

  * Training samples: 64% of TB and healthy CXR (including labels)
  * Validation samples: 16% of TB and healthy CXR (including labels)
  * Test samples: 20% of TB and healthy CXR (including labels)
"""
import importlib.resources
import os
from ...utils.rc import load_rc
from .. import make_dataset
from ..dataset import JSONDataset
from ..loader import load_pil_baw, make_delayed
# Protocol definition files shipped next to this module: the ``default``
# split followed by the ten numbered folds (``fold_0`` .. ``fold_9``).
_protocols = [
    importlib.resources.files(__name__).joinpath(filename)
    for filename in (
        "default.json.bz2",
        "fold_0.json.bz2",
        "fold_1.json.bz2",
        "fold_2.json.bz2",
        "fold_3.json.bz2",
        "fold_4.json.bz2",
        "fold_5.json.bz2",
        "fold_6.json.bz2",
        "fold_7.json.bz2",
        "fold_8.json.bz2",
        "fold_9.json.bz2",
    )
]

# Base directory against which sample paths are resolved.  It is read from
# the ``datadir.montgomery`` key of the configuration returned by
# ``load_rc()``; when that key is absent, the resolved current working
# directory (as of import time) is used instead.
_datadir = load_rc().get(
    "datadir.montgomery",
    os.path.realpath(os.curdir),
)
def _raw_data_loader(sample):
    """Load the image of one sample and pair it with its label.

    Parameters
    ----------
    sample
        Mapping with (at least) the keys ``"data"`` -- an image path that is
        joined onto ``_datadir`` -- and ``"label"``.

    Returns
    -------
    dict
        ``{"data": <loaded image>, "label": <label>}``, where the image is
        whatever ``load_pil_baw`` returns for the resolved path.
    """
    # NOTE(review): load_pil_baw presumably yields a grayscale PIL image --
    # confirm against ..loader.
    image_path = os.path.join(_datadir, sample["data"])
    return {
        "data": load_pil_baw(image_path),  # type: ignore
        "label": sample["label"],
    }
def _loader(context, sample):
    """Wrap ``sample`` into a delayed sample backed by ``_raw_data_loader``.

    Parameters
    ----------
    context
        Unused: the database is homogeneous, so no per-context handling is
        required.
    sample
        Raw sample description, forwarded unchanged to ``make_delayed``.

    Returns
    -------
    The object produced by ``make_delayed(sample, _raw_data_loader)``.
    """
    # Delaying the load avoids reading every image into memory at once.
    delayed_sample = make_delayed(sample, _raw_data_loader)
    return delayed_sample
# Module-level dataset built from the bundled protocol files; every sample
# goes through ``_loader``.
# NOTE(review): ``fieldnames`` presumably names the columns of each JSON
# entry -- confirm against JSONDataset.
json_dataset = JSONDataset(
    loader=_loader,
    fieldnames=("data", "label"),
    protocols=_protocols,
)
"""Montgomery dataset object."""
def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
    """Assemble the dataset for one protocol with its transform pipeline.

    Parameters
    ----------
    protocol
        Protocol identifier handed to ``json_dataset.subsets``.
    resize_size
        Argument given to ``torchvision.transforms.Resize`` (default 512).
    cc_size
        Argument given to ``torchvision.transforms.CenterCrop`` (default 512).
    RGB
        When truthy, each sample is additionally converted with
        ``.convert("RGB")`` and then passed through ``ToTensor``; otherwise
        the list of post-transforms is empty.

    Returns
    -------
    The value returned by ``make_dataset`` for the selected subsets.
    """
    # Imported lazily, as in the rest of this module's makers, so that merely
    # importing the module does not pull in torchvision.
    from torchvision import transforms

    from ..transforms import ElasticDeformation, RemoveBlackBorders

    post_transforms = (
        [
            transforms.Lambda(lambda img: img.convert("RGB")),
            transforms.ToTensor(),
        ]
        if RGB
        else []
    )

    subsets = [json_dataset.subsets(protocol)]
    base_transforms = [
        RemoveBlackBorders(),
        transforms.Resize(resize_size),
        transforms.CenterCrop(cc_size),
    ]
    # NOTE(review): the third positional list is presumably the
    # augmentation stage -- confirm against make_dataset's signature.
    extra_transforms = [ElasticDeformation(p=0.8)]

    return make_dataset(
        subsets,
        base_transforms,
        extra_transforms,
        post_transforms,
    )
Loading