Skip to content
Snippets Groups Projects

updated TB-POC dataset and corresponding tests

Merged Maxime DELITROZ requested to merge update-tbpoc into add-datamodule
13 files
+ 285
550
Compare changes
  • Side-by-side
  • Inline
Files
13
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""TB-POC dataset for computer-aided diagnosis.
* Reference: [TB-POC-2018]_
* Original resolution (height x width or width x height): 2048 x 2500
* Split reference: none
* Stratified k-fold protocol (10 folds):
* Training samples: 72% of TB and healthy CXR (including labels)
* Validation samples: 18% of TB and healthy CXR (including labels)
* Test samples: 10% of TB and healthy CXR (including labels)
"""
import importlib.resources
import os
from ...utils.rc import load_rc
from .. import make_dataset
from ..dataset import JSONDataset
from ..loader import load_pil_grayscale, make_delayed
# One ".json.bz2" protocol file per cross-validation fold (fold_0 ... fold_9),
# each resolved relative to the resources shipped with this package.
_protocols = [
    importlib.resources.files(__name__).joinpath(f"fold_{fold}.json.bz2")
    for fold in range(10)
]

# Base directory against which sample paths are resolved: the "datadir.tbpoc"
# entry of the configuration returned by load_rc(), or the real path of the
# current working directory (evaluated at import time) if that entry is unset.
_datadir = load_rc().get(
    "datadir.tbpoc",
    os.path.realpath(os.curdir),
)
def _raw_data_loader(sample):
    """Load the image and label of a single sample entry.

    Parameters
    ----------
    sample
        Mapping with a ``"data"`` key (image path, joined onto ``_datadir``)
        and a ``"label"`` key.

    Returns
    -------
    dict
        ``{"data": <image from load_pil_grayscale>, "label": <label>}``.
    """
    image_path = os.path.join(_datadir, sample["data"])
    return {
        "data": load_pil_grayscale(image_path),
        "label": sample["label"],
    }
def _loader(context, sample):
    """Wrap ``sample`` so that its image is only loaded on demand.

    Parameters
    ----------
    context
        Unused: the database is homogeneous, so no per-context handling is
        needed.
    sample
        Sample entry forwarded to ``make_delayed`` together with
        ``_raw_data_loader``.

    Returns
    -------
    The object produced by ``make_delayed(sample, _raw_data_loader)``.
    """
    # Return a delayed sample so that not all images are loaded at once.
    delayed_sample = make_delayed(sample, _raw_data_loader)
    return delayed_sample
# Dataset built from the per-fold protocol files; every entry carries the two
# fields "data" and "label" and is wrapped by ``_loader``.
json_dataset = JSONDataset(
    protocols=_protocols,  # fold_0 ... fold_9 protocol files
    fieldnames=("data", "label"),  # fields expected for each sample entry
    loader=_loader,  # returns delayed samples
)
"""TB-POC dataset object."""
def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
    """Assemble the TB-POC dataset for one protocol via ``make_dataset``.

    Parameters
    ----------
    protocol
        Protocol identifier passed to ``json_dataset.subsets``.
    resize_size
        Size handed to ``transforms.Resize``.
    cc_size
        Size handed to ``transforms.CenterCrop``.
    RGB
        When true, images are additionally converted to RGB and turned into
        tensors (``transforms.ToTensor``); otherwise no post-transform is
        applied.

    Returns
    -------
    The result of ``make_dataset`` called with the protocol subsets, the
    border-removal/resize/crop transforms, the elastic-deformation
    augmentation and the optional post-transforms.
    """
    # Imported lazily, inside the function, as in the original layout.
    from torchvision import transforms
    from ..augmentations import ElasticDeformation
    from ..image_utils import RemoveBlackBorders

    # Optional final stage: RGB conversion followed by tensor conversion.
    post_transforms = (
        [
            transforms.Lambda(lambda image: image.convert("RGB")),
            transforms.ToTensor(),
        ]
        if RGB
        else []
    )

    subsets = [json_dataset.subsets(protocol)]

    # Applied in order: strip black borders, resize, then center-crop.
    base_transforms = [
        RemoveBlackBorders(),
        transforms.Resize(resize_size),
        transforms.CenterCrop(cc_size),
    ]

    # p=0.8 is presumably the probability of applying the deformation -
    # TODO confirm against ElasticDeformation.
    augmentations = [ElasticDeformation(p=0.8)]

    return make_dataset(
        subsets,
        base_transforms,
        augmentations,
        post_transforms,
    )
Loading