diff --git a/src/ptbench/data/tbpoc/__init__.py b/src/ptbench/data/tbpoc/__init__.py index 00f5f42c7f81b93c5ab4c037d289275e3477b00b..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/src/ptbench/data/tbpoc/__init__.py +++ b/src/ptbench/data/tbpoc/__init__.py @@ -1,83 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""TB-POC dataset for computer-aided diagnosis. - -* Reference: [TB-POC-2018]_ -* Original resolution (height x width or width x height): 2048 x 2500 -* Split reference: none -* Stratified kfold protocol: - - * Training samples: 72% of TB and healthy CXR (including labels) - * Validation samples: 18% of TB and healthy CXR (including labels) - * Test samples: 10% of TB and healthy CXR (including labels) -""" - -import importlib.resources -import os - -from ...utils.rc import load_rc -from .. import make_dataset -from ..dataset import JSONDataset -from ..loader import load_pil_grayscale, make_delayed - -_protocols = [ - importlib.resources.files(__name__).joinpath("fold_0.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_1.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_2.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_3.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_4.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_5.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_6.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_7.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_8.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_9.json.bz2"), -] - -_datadir = load_rc().get("datadir.tbpoc", os.path.realpath(os.curdir)) - - -def _raw_data_loader(sample): - return dict( - data=load_pil_grayscale(os.path.join(_datadir, sample["data"])), - label=sample["label"], - ) - - -def _loader(context, sample): - # "context" is ignored in this case - database is homogeneous - # we returned delayed samples to avoid loading all images at once - return make_delayed(sample, _raw_data_loader) - - -json_dataset = JSONDataset( - protocols=_protocols, - fieldnames=("data", "label"), - loader=_loader, -) -"""TB-POC dataset object.""" - - -def _maker(protocol, resize_size=512, cc_size=512, RGB=False): - from torchvision import transforms - - from ..augmentations import ElasticDeformation - from ..image_utils import RemoveBlackBorders - - post_transforms = [] - if RGB: - post_transforms = [ - transforms.Lambda(lambda x: x.convert("RGB")), - transforms.ToTensor(), - ] - - return make_dataset( - [json_dataset.subsets(protocol)], - [ - RemoveBlackBorders(), - transforms.Resize(resize_size), - transforms.CenterCrop(cc_size), - ], - [ElasticDeformation(p=0.8)], - post_transforms, - ) diff --git a/src/ptbench/data/tbpoc/datamodule.py b/src/ptbench/data/tbpoc/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..35465bac49ae78db05b769a8d63aee40156f1482 --- /dev/null +++ b/src/ptbench/data/tbpoc/datamodule.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import importlib.resources +import os + +import PIL.Image + +from torchvision.transforms.functional import center_crop, to_tensor + +from ...utils.rc import load_rc +from ..datamodule import CachingDataModule +from ..image_utils import load_pil_grayscale, remove_black_borders +from ..split import JSONDatabaseSplit +from ..typing import DatabaseSplit +from ..typing import RawDataLoader as _BaseRawDataLoader +from ..typing import Sample + + +class RawDataLoader(_BaseRawDataLoader): + """A specialized raw-data-loader for the Shenzen dataset. + + Attributes + ---------- + + datadir + This variable contains the base directory where the database raw data + is stored. + + transform + Transforms that are always applied to the loaded raw images. + """ + + datadir: str + + def __init__(self, config_variable: str = "datadir.tbpoc"): + self.datadir = load_rc().get( + config_variable, os.path.realpath(os.curdir) + ) + + def sample(self, sample: tuple[str, int]) -> Sample: + """Loads a single image sample from the disk. + + Parameters + ---------- + + sample: + A tuple containing the path suffix, within the dataset root folder, + where to find the image to be loaded, and an integer, representing the + sample label. + + + Returns + ------- + + sample + The sample representation + """ + image = load_pil_grayscale(os.path.join(self.datadir, sample[0])) + image = remove_black_borders(image) + tensor = to_tensor(image) + tensor = center_crop(tensor, min(*tensor.shape[1:])) + + # use the code below to view generated images + # from torchvision.transforms.functional import to_pil_image + # to_pil_image(tensor).show() + # __import__("pdb").set_trace() + + return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] + + def label(self, sample: tuple[str, int]) -> int: + """Loads a single image sample label from the disk. + + Parameters + ---------- + + sample: + A tuple containing the path suffix, within the dataset root folder, + where to find the image to be loaded, and an integer, representing the + sample label. + + + Returns + ------- + + label + The integer label associated with the sample + """ + return sample[1] + + +def make_split(basename: str) -> DatabaseSplit: + """Returns a database split for the TB-POC database.""" + + return JSONDatabaseSplit( + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) + ) + + +class DataModule(CachingDataModule): + """TB-POC dataset for computer-aided diagnosis. + + * Database reference: [TB-POC-2018]_ + * Original resolution (height x width or width x height): 2048 x 2500 pixels + or 2500 x 2048 pixels + + Data specifications: + + * Raw data input (on disk): + + * jpeg 8-bit grayscale images + * resolution: fixed to one of the cases above + + * Output image: + + * Transforms: + + * Load raw jpeg with :py:mod:`PIL` + * Remove black borders + * Convert to torch tensor + * Torch center cropping to get square image + + * Final specifications: + + * Grayscale, encoded as a single plane tensor, 32-bit floats, + square with varying resolutions, depending on black borders' sizes + on the input image + * Labels: 0 (healthy), 1 (active tuberculosis) + """ + + def __init__(self, split_filename: str): + super().__init__( + database_split=make_split(split_filename), + raw_data_loader=RawDataLoader(), + ) + + diff --git a/src/ptbench/data/tbpoc/fold_0.py b/src/ptbench/data/tbpoc/fold_0.py index 7a423deba65dfd1bc9c1fcfc5750e0cec1bd0563..972e7188f13a0b7e67b3581eb87c0d20acd38794 100644 --- a/src/ptbench/data/tbpoc/fold_0.py +++ b/src/ptbench/data/tbpoc/fold_0.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """TB-POC dataset for TB detection (cross validation fold 0) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.tbpoc` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) +* This configuration resolution: varying depending of black borders on original + image +* See :py:mod:`ptbench.data.tbpoc` for dataset details +""" -class Fold0Module(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-0") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - +from .datamodule import DataModule -datamodule = Fold0Module +datamodule = DataModule("fold-0.json") diff --git a/src/ptbench/data/tbpoc/fold_1.py b/src/ptbench/data/tbpoc/fold_1.py index cb4c59baba491d1e8edf666c71e373b7dbf6a5ec..79b9bfcaec144157770c3be12705f73fcb0f5c79 100644 --- a/src/ptbench/data/tbpoc/fold_1.py +++ b/src/ptbench/data/tbpoc/fold_1.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """TB-POC dataset for TB detection (cross validation fold 1) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.tbpoc` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) +* This configuration resolution: varying depending of black borders on original + image +* See :py:mod:`ptbench.data.tbpoc` for dataset details +""" -class Fold0Module(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-1") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - +from .datamodule import DataModule -datamodule = Fold0Module +datamodule = DataModule("fold-1.json") diff --git a/src/ptbench/data/tbpoc/fold_2.py b/src/ptbench/data/tbpoc/fold_2.py index 1bffecea5721496cc24731147d5b0d62db6c93d2..9d41fb595637dce8f31944b4aa2eeee2bd60e58d 100644 --- a/src/ptbench/data/tbpoc/fold_2.py +++ b/src/ptbench/data/tbpoc/fold_2.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """TB-POC dataset for TB detection (cross validation fold 2) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.tbpoc` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) +* This configuration resolution: varying depending of black borders on original + image +* See :py:mod:`ptbench.data.tbpoc` for dataset details +""" -class Fold0Module(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-2") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - +from .datamodule import DataModule -datamodule = Fold0Module +datamodule = DataModule("fold-2.json") diff --git a/src/ptbench/data/tbpoc/fold_3.py b/src/ptbench/data/tbpoc/fold_3.py index 1263d39b123c002d065c32a5de17d753a3d14121..08672b3f325a8e60e19b8254ea89bc59fc3ad78b 100644 --- a/src/ptbench/data/tbpoc/fold_3.py +++ b/src/ptbench/data/tbpoc/fold_3.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """TB-POC dataset for TB detection (cross validation fold 3) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.tbpoc` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) +* This configuration resolution: varying depending of black borders on original + image +* See :py:mod:`ptbench.data.tbpoc` for dataset details +""" -class Fold0Module(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-3") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - +from .datamodule import DataModule -datamodule = Fold0Module +datamodule = DataModule("fold-3.json") diff --git a/src/ptbench/data/tbpoc/fold_4.py b/src/ptbench/data/tbpoc/fold_4.py index 119adfa98e94ed165232b73f49e7998359376cf0..8354a4c2d7038c35620a7220afa7e9a8731d44fd 100644 --- a/src/ptbench/data/tbpoc/fold_4.py +++ b/src/ptbench/data/tbpoc/fold_4.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """TB-POC dataset for TB detection (cross validation fold 4) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.tbpoc` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) +* This configuration resolution: varying depending of black borders on original + image +* See :py:mod:`ptbench.data.tbpoc` for dataset details +""" -class Fold0Module(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-4") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - +from .datamodule import DataModule -datamodule = Fold0Module +datamodule = DataModule("fold-4.json") diff --git a/src/ptbench/data/tbpoc/fold_5.py b/src/ptbench/data/tbpoc/fold_5.py index 2a90cbdb4a706435f8f8589d863132a07191ab5c..cb7f95612e23dca6af8d3d06dfaf6ae76319ed6f 100644 --- a/src/ptbench/data/tbpoc/fold_5.py +++ b/src/ptbench/data/tbpoc/fold_5.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """TB-POC dataset for TB detection (cross validation fold 5) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.tbpoc` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) +* This configuration resolution: varying depending of black borders on original + image +* See :py:mod:`ptbench.data.tbpoc` for dataset details +""" -class Fold0Module(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-5") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - +from .datamodule import DataModule -datamodule = Fold0Module +datamodule = DataModule("fold-5.json") diff --git a/src/ptbench/data/tbpoc/fold_6.py b/src/ptbench/data/tbpoc/fold_6.py index 42ed763dde698b3ce98f350068142539d07532b0..379211aad631cf9beac280d598f52beb6746eac0 100644 --- a/src/ptbench/data/tbpoc/fold_6.py +++ b/src/ptbench/data/tbpoc/fold_6.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """TB-POC dataset for TB detection (cross validation fold 6) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.tbpoc` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) +* This configuration resolution: varying depending of black borders on original + image +* See :py:mod:`ptbench.data.tbpoc` for dataset details +""" -class Fold0Module(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-6") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - +from .datamodule import DataModule -datamodule = Fold0Module +datamodule = DataModule("fold-6.json") diff --git a/src/ptbench/data/tbpoc/fold_7.py b/src/ptbench/data/tbpoc/fold_7.py index ad7dbe14e752f8f3015c18af775064f051515a34..b846b88af5cf7375f578ee2ffbc24055a4a3ff85 100644 --- a/src/ptbench/data/tbpoc/fold_7.py +++ b/src/ptbench/data/tbpoc/fold_7.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """TB-POC dataset for TB detection (cross validation fold 7) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.tbpoc` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) +* This configuration resolution: varying depending of black borders on original + image +* See :py:mod:`ptbench.data.tbpoc` for dataset details +""" -class Fold0Module(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-7") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - +from .datamodule import DataModule -datamodule = Fold0Module +datamodule = DataModule("fold-7.json") diff --git a/src/ptbench/data/tbpoc/fold_8.py b/src/ptbench/data/tbpoc/fold_8.py index 4bcea788633542d7bc9f52ef35966542891e5f91..acfd42964fe21cf15c1d47a5bc5df794fbcba961 100644 --- a/src/ptbench/data/tbpoc/fold_8.py +++ b/src/ptbench/data/tbpoc/fold_8.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """TB-POC dataset for TB detection (cross validation fold 8) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.tbpoc` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) +* This configuration resolution: varying depending of black borders on original + image +* See :py:mod:`ptbench.data.tbpoc` for dataset details +""" -class Fold0Module(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-8") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - +from .datamodule import DataModule -datamodule = Fold0Module +datamodule = DataModule("fold-8.json") diff --git a/src/ptbench/data/tbpoc/fold_9.py b/src/ptbench/data/tbpoc/fold_9.py index c33eb6bd3d2dbe8635e2b3cc2a2f7e81e859e180..4634068e5942bf9d7062876ee2007702083de1ed 100644 --- a/src/ptbench/data/tbpoc/fold_9.py +++ b/src/ptbench/data/tbpoc/fold_9.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """TB-POC dataset for TB detection (cross validation fold 9) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.tbpoc` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) +* This configuration resolution: varying depending of black borders on original + image +* See :py:mod:`ptbench.data.tbpoc` for dataset details +""" -class Fold0Module(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-9") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - +from .datamodule import DataModule -datamodule = Fold0Module +datamodule = DataModule("fold-9.json") diff --git a/tests/test_tbpoc.py b/tests/test_tbpoc.py index 9609ea66be7f59a655f744cba8b7f94e8cafb382..ee34d8d09d14379dfc08dcc4adebc33ffedd18b9 100644 --- a/tests/test_tbpoc.py +++ b/tests/test_tbpoc.py @@ -4,106 +4,126 @@ """Tests for TB-POC dataset.""" import pytest +import torch -dataset = None +from ptbench.data.tbpoc.datamodule import make_split -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - # Cross-validation fold 0-6 - for f in range(7): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 +def _check_split( + split_filename: str, + lengths: dict[str, int], + prefix: str = "TBPOC_CXR/", + extension: str = ".jpeg", + possible_labels: list[int] = [0, 1], +): + """Runs a simple consistence check on the data split. - assert "train" in subset - assert len(subset["train"]) == 292 - for s in subset["train"]: - assert s.key.upper().startswith("TBPOC_CXR/TBPOC-") + Parameters + ---------- - assert "validation" in subset - assert len(subset["validation"]) == 74 - for s in subset["validation"]: - assert s.key.upper().startswith("TBPOC_CXR/TBPOC-") + split_filename + This is the split we will check - assert "test" in subset - assert len(subset["test"]) == 41 - for s in subset["test"]: - assert s.key.upper().startswith("TBPOC_CXR/TBPOC-") + lenghts + A dictionary that contains keys matching those of the split (this will + be checked). The values of the dictionary should correspond to the + sizes of each of the datasets in the split. - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] + prefix + Each file named in a split should start with this prefix. - for s in subset["validation"]: - assert s.label in [0.0, 1.0] + extension + Each file named in a split should end with this extension. - for s in subset["test"]: - assert s.label in [0.0, 1.0] + possible_labels + These are the list of possible labels contained in any split. + """ - # Cross-validation fold 7-9 - for f in range(7, 10): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 + split = make_split(split_filename) - assert "train" in subset - assert len(subset["train"]) == 293 - for s in subset["train"]: - assert s.key.upper().startswith("TBPOC_CXR/TBPOC-") + assert len(split) == len(lengths) - assert "validation" in subset - assert len(subset["validation"]) == 74 - for s in subset["validation"]: - assert s.key.upper().startswith("TBPOC_CXR/TBPOC-") + for k in lengths.keys(): + # dataset must have been declared + assert k in split - assert "test" in subset - assert len(subset["test"]) == 40 - for s in subset["test"]: - assert s.key.upper().startswith("TBPOC_CXR/TBPOC-") + assert len(split[k]) == lengths[k] + for s in split[k]: + # assert s[0].startswith(prefix) + assert s[0].endswith(extension) + assert s[1] in possible_labels - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - for s in subset["validation"]: - assert s.label in [0.0, 1.0] +def _check_loaded_batch( + batch, + size: int = 1, + prefix: str = "TBPOC_CXR/", + extension: str = ".jpeg", + possible_labels: list[int] = [0, 1], +): + """Checks the consistence of an individual (loaded) batch. - for s in subset["test"]: - assert s.label in [0.0, 1.0] + Parameters + ---------- + batch + The loaded batch to be checked. -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc") -def test_loading(): - image_size_portrait = (2048, 2500) - image_size_landscape = (2500, 2048) + prefix + Each file named in a split should start with this prefix. + + extension + Each file named in a split should end with this extension. + + possible_labels + These are the list of possible labels contained in any split. + """ - def _check_size(size): - if size == image_size_portrait: - return True - elif size == image_size_landscape: - return True - return False + assert len(batch) == 2 # data, metadata - def _check_sample(s): - data = s.data - assert isinstance(data, dict) - assert len(data) == 2 + assert isinstance(batch[0], torch.Tensor) + assert batch[0].shape[0] == size # mini-batch size + assert batch[0].shape[1] == 1 # grayscale images + assert batch[0].shape[2] == batch[0].shape[3] # image is square - assert "data" in data - assert _check_size(data["data"].size) # Check size - assert data["data"].mode, "L" # Check colors + assert isinstance(batch[1], dict) # metadata + assert len(batch[1]) == 2 # label and name - assert "label" in data - assert data["label"] in [0, 1] # Check labels + assert "label" in batch[1] + assert all([k in possible_labels for k in batch[1]["label"]]) - limit = 30 # use this to limit testing to first images only, else None + assert "name" in batch[1] + # assert all([k.startswith(prefix) for k in batch[1]["name"]]) + assert all([k.endswith(extension) for k in batch[1]["name"]]) - subset = dataset.subsets("fold_0") - for s in subset["train"][:limit]: - _check_sample(s) + +def test_protocol_consistency(): + # Cross-validation fold 0-6 + for k in range(7): + _check_split( + f"fold-{k}.json", + lengths=dict(train=292, validation=74, test=41), + ) + + # Cross-validation fold 7-9 + for k in range(7, 10): + _check_split( + f"fold-{k}.json", + lengths=dict(train=293, validation=74, test=40), + ) -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc") -def test_check(): - assert dataset.check() == 0 +@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") +def test_loading(): + from ptbench.data.tbpoc.fold_0 import datamodule + + datamodule.model_transforms = [] # should be done before setup() + datamodule.setup("predict") # sets up all datasets + + for loader in datamodule.predict_dataloader().values(): + limit = 5 # limit load checking + for batch in loader: + if limit == 0: + break + _check_loaded_batch(batch) + limit -= 1