From 75f98d0c95cb91e5333b134e40e6ad414b4c026b Mon Sep 17 00:00:00 2001 From: mdelitroz <maxime.delitroz@idiap.ch> Date: Wed, 2 Aug 2023 15:09:48 +0200 Subject: [PATCH] updated HIV-TB dataset and related tests --- src/ptbench/data/hivtb/__init__.py | 82 ------------- src/ptbench/data/hivtb/datamodule.py | 132 ++++++++++++++++++++ src/ptbench/data/hivtb/fold_0.py | 44 ++----- src/ptbench/data/hivtb/fold_1.py | 44 ++----- src/ptbench/data/hivtb/fold_2.py | 44 ++----- src/ptbench/data/hivtb/fold_3.py | 44 ++----- src/ptbench/data/hivtb/fold_4.py | 44 ++----- src/ptbench/data/hivtb/fold_5.py | 44 ++----- src/ptbench/data/hivtb/fold_6.py | 44 ++----- src/ptbench/data/hivtb/fold_7.py | 44 ++----- src/ptbench/data/hivtb/fold_8.py | 44 ++----- src/ptbench/data/hivtb/fold_9.py | 44 ++----- tests/test_hivtb.py | 172 +++++++++++++++------------ 13 files changed, 328 insertions(+), 498 deletions(-) create mode 100644 src/ptbench/data/hivtb/datamodule.py diff --git a/src/ptbench/data/hivtb/__init__.py b/src/ptbench/data/hivtb/__init__.py index 0fca31ae..e69de29b 100644 --- a/src/ptbench/data/hivtb/__init__.py +++ b/src/ptbench/data/hivtb/__init__.py @@ -1,82 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""HIV-TB dataset for computer-aided diagnosis (only BMP files) - -* Reference: [HIV-TB-2019]_ -* Original resolution (height x width or width x height): 2048 x 2500 -* Split reference: none -* Stratified kfold protocol: - - * Training samples: 72% of TB and healthy CXR (including labels) - * Validation samples: 18% of TB and healthy CXR (including labels) - * Test samples: 10% of TB and healthy CXR (including labels) -""" - -import importlib.resources -import os - -from ...utils.rc import load_rc -from .. 
import make_dataset -from ..dataset import JSONDataset -from ..loader import load_pil_grayscale, make_delayed - -_protocols = [ - importlib.resources.files(__name__).joinpath("fold_0.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_1.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_2.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_3.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_4.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_5.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_6.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_7.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_8.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_9.json.bz2"), -] - -_datadir = load_rc().get("datadir.hivtb", os.path.realpath(os.curdir)) - - -def _raw_data_loader(sample): - return dict( - data=load_pil_grayscale(os.path.join(_datadir, sample["data"])), - label=sample["label"], - ) - - -def _loader(context, sample): - # "context" is ignored in this case - database is homogeneous - # we returned delayed samples to avoid loading all images at once - return make_delayed(sample, _raw_data_loader) - - -json_dataset = JSONDataset( - protocols=_protocols, - fieldnames=("data", "label"), - loader=_loader, -) -"""HIV-TB dataset object.""" - - -def _maker(protocol, resize_size=512, cc_size=512, RGB=False): - from torchvision import transforms - - from ..augmentations import ElasticDeformation, RemoveBlackBorders - - post_transforms = [] - if RGB: - post_transforms = [ - transforms.Lambda(lambda x: x.convert("RGB")), - transforms.ToTensor(), - ] - - return make_dataset( - [json_dataset.subsets(protocol)], - [ - RemoveBlackBorders(), - transforms.Resize(resize_size), - transforms.CenterCrop(cc_size), - ], - [ElasticDeformation(p=0.8)], - post_transforms, - ) diff --git a/src/ptbench/data/hivtb/datamodule.py b/src/ptbench/data/hivtb/datamodule.py new 
file mode 100644 index 00000000..63075c61 --- /dev/null +++ b/src/ptbench/data/hivtb/datamodule.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import importlib.resources +import os + +import PIL.Image + +from torchvision.transforms.functional import center_crop, to_tensor + +from ...utils.rc import load_rc +from ..datamodule import CachingDataModule +from ..image_utils import load_pil_grayscale, remove_black_borders +from ..split import JSONDatabaseSplit +from ..typing import DatabaseSplit +from ..typing import RawDataLoader as _BaseRawDataLoader +from ..typing import Sample + + +class RawDataLoader(_BaseRawDataLoader): + """A specialized raw-data-loader for the HIV-TB dataset. + + Attributes + ---------- + + datadir + This variable contains the base directory where the database raw data + is stored. + """ + + datadir: str + + def __init__(self): + self.datadir = load_rc().get( + "datadir.hivtb", os.path.realpath(os.curdir) + ) + + def sample(self, sample: tuple[str, int]) -> Sample: + """Loads a single image sample from the disk. + + Parameters + ---------- + + sample: + A tuple containing the path suffix, within the dataset root folder, + where to find the image to be loaded, and an integer, representing the + sample label. + + + Returns + ------- + + sample + The sample representation + """ + image = load_pil_grayscale(os.path.join(self.datadir, sample[0])) + image = remove_black_borders(image) + tensor = to_tensor(image) + tensor = center_crop(tensor, min(*tensor.shape[1:])) + + # use the code below to view generated images + # from torchvision.transforms.functional import to_pil_image + # to_pil_image(tensor).show() + # __import__("pdb").set_trace() + + return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] + + def label(self, sample: tuple[str, int]) -> int: + """Returns the label of a single sample without reading the image from disk.
+ + Parameters + ---------- + + sample: + A tuple containing the path suffix, within the dataset root folder, + where to find the image to be loaded, and an integer, representing the + sample label. + + + Returns + ------- + + label + The integer label associated with the sample + """ + return sample[1] + + +def make_split(basename: str) -> DatabaseSplit: + """Returns a database split for the HIV-TB database.""" + + return JSONDatabaseSplit( + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) + ) + + +class DataModule(CachingDataModule): + """HIV-TB dataset for computer-aided diagnosis (only BMP files) + + * Database reference: [HIV-TB-2019]_ + * Original resolution (height x width or width x height): 2048 x 2500 pixels + or 2500 x 2048 pixels + + Data specifications: + + * Raw data input (on disk): + + * BMP images 8 bit grayscale + * resolution fixed to one of the cases above + + * Output image: + + * Transforms: + + * Load raw BMP with :py:mod:`PIL` + * Remove black borders + * Convert to torch tensor + * Torch center cropping to get square image + + * Final specifications + + * Grayscale, encoded as a single plane tensor, 32-bit floats, + square at 2048 x 2048 pixels + * Labels: 0 (healthy), 1 (active tuberculosis) + """ + + def __init__(self, split_filename: str): + super().__init__( + database_split=make_split(split_filename), + raw_data_loader=RawDataLoader(), + ) diff --git a/src/ptbench/data/hivtb/fold_0.py b/src/ptbench/data/hivtb/fold_0.py index e8caee65..ba9e9150 100644 --- a/src/ptbench/data/hivtb/fold_0.py +++ b/src/ptbench/data/hivtb/fold_0.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """HIV-TB dataset for TB detection (cross validation fold 0) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from 
clapper.logging import setup -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) +* This configuration resolution: 2048 x 2048 (default) +* See :py:mod:`ptbench.data.hivtb` for dataset details +""" - def setup(self, stage: str): - self.dataset = _maker("fold-0") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) +from .datamodule import DataModule +datamodule = DataModule("fold-0.json") -datamodule = DefaultModule diff --git a/src/ptbench/data/hivtb/fold_1.py b/src/ptbench/data/hivtb/fold_1.py index bb12b311..84fb7581 100644 --- a/src/ptbench/data/hivtb/fold_1.py +++ b/src/ptbench/data/hivtb/fold_1.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """HIV-TB dataset for TB detection (cross validation fold 1) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . 
import _maker +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) +* This configuration resolution: 2048 x 2048 (default) +* See :py:mod:`ptbench.data.hivtb` for dataset details +""" - def setup(self, stage: str): - self.dataset = _maker("fold-1") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) +from .datamodule import DataModule +datamodule = DataModule("fold-1.json") -datamodule = DefaultModule diff --git a/src/ptbench/data/hivtb/fold_2.py b/src/ptbench/data/hivtb/fold_2.py index 7bd3703e..a5f5e97a 100644 --- a/src/ptbench/data/hivtb/fold_2.py +++ b/src/ptbench/data/hivtb/fold_2.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """HIV-TB dataset for TB detection (cross validation fold 2) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . 
import _maker +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) +* This configuration resolution: 2048 x 2048 (default) +* See :py:mod:`ptbench.data.hivtb` for dataset details +""" - def setup(self, stage: str): - self.dataset = _maker("fold-2") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) +from .datamodule import DataModule +datamodule = DataModule("fold-2.json") -datamodule = DefaultModule diff --git a/src/ptbench/data/hivtb/fold_3.py b/src/ptbench/data/hivtb/fold_3.py index cac94f67..1b643ae4 100644 --- a/src/ptbench/data/hivtb/fold_3.py +++ b/src/ptbench/data/hivtb/fold_3.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """HIV-TB dataset for TB detection (cross validation fold 3) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . 
import _maker +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) +* This configuration resolution: 2048 x 2048 (default) +* See :py:mod:`ptbench.data.hivtb` for dataset details +""" - def setup(self, stage: str): - self.dataset = _maker("fold-3") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) +from .datamodule import DataModule +datamodule = DataModule("fold-3.json") -datamodule = DefaultModule diff --git a/src/ptbench/data/hivtb/fold_4.py b/src/ptbench/data/hivtb/fold_4.py index c5952356..581eb85c 100644 --- a/src/ptbench/data/hivtb/fold_4.py +++ b/src/ptbench/data/hivtb/fold_4.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """HIV-TB dataset for TB detection (cross validation fold 4) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . 
import _maker +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) +* This configuration resolution: 2048 x 2048 (default) +* See :py:mod:`ptbench.data.hivtb` for dataset details +""" - def setup(self, stage: str): - self.dataset = _maker("fold-4") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) +from .datamodule import DataModule +datamodule = DataModule("fold-4.json") -datamodule = DefaultModule diff --git a/src/ptbench/data/hivtb/fold_5.py b/src/ptbench/data/hivtb/fold_5.py index bc80b9ff..47ae66d1 100644 --- a/src/ptbench/data/hivtb/fold_5.py +++ b/src/ptbench/data/hivtb/fold_5.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """HIV-TB dataset for TB detection (cross validation fold 5) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . 
import _maker +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) +* This configuration resolution: 2048 x 2048 (default) +* See :py:mod:`ptbench.data.hivtb` for dataset details +""" - def setup(self, stage: str): - self.dataset = _maker("fold-5") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) +from .datamodule import DataModule +datamodule = DataModule("fold-5.json") -datamodule = DefaultModule diff --git a/src/ptbench/data/hivtb/fold_6.py b/src/ptbench/data/hivtb/fold_6.py index d1a646dc..c93232f4 100644 --- a/src/ptbench/data/hivtb/fold_6.py +++ b/src/ptbench/data/hivtb/fold_6.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """HIV-TB dataset for TB detection (cross validation fold 6) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . 
import _maker +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) +* This configuration resolution: 2048 x 2048 (default) +* See :py:mod:`ptbench.data.hivtb` for dataset details +""" - def setup(self, stage: str): - self.dataset = _maker("fold-6") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) +from .datamodule import DataModule +datamodule = DataModule("fold-6.json") -datamodule = DefaultModule diff --git a/src/ptbench/data/hivtb/fold_7.py b/src/ptbench/data/hivtb/fold_7.py index de29f234..33d5cc83 100644 --- a/src/ptbench/data/hivtb/fold_7.py +++ b/src/ptbench/data/hivtb/fold_7.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """HIV-TB dataset for TB detection (cross validation fold 7) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . 
import _maker +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) +* This configuration resolution: 2048 x 2048 (default) +* See :py:mod:`ptbench.data.hivtb` for dataset details +""" - def setup(self, stage: str): - self.dataset = _maker("fold-7") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) +from .datamodule import DataModule +datamodule = DataModule("fold-7.json") -datamodule = DefaultModule diff --git a/src/ptbench/data/hivtb/fold_8.py b/src/ptbench/data/hivtb/fold_8.py index 9370dcea..91d89557 100644 --- a/src/ptbench/data/hivtb/fold_8.py +++ b/src/ptbench/data/hivtb/fold_8.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """HIV-TB dataset for TB detection (cross validation fold 8) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . 
import _maker +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) +* This configuration resolution: 2048 x 2048 (default) +* See :py:mod:`ptbench.data.hivtb` for dataset details +""" - def setup(self, stage: str): - self.dataset = _maker("fold-8") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) +from .datamodule import DataModule +datamodule = DataModule("fold-8.json") -datamodule = DefaultModule diff --git a/src/ptbench/data/hivtb/fold_9.py b/src/ptbench/data/hivtb/fold_9.py index 70605f8d..0e0063e8 100644 --- a/src/ptbench/data/hivtb/fold_9.py +++ b/src/ptbench/data/hivtb/fold_9.py @@ -1,45 +1,21 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later + """HIV-TB dataset for TB detection (cross validation fold 9) * Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . 
import _maker +* Stratified kfold protocol: + * Training samples: 72% of TB and healthy CXR (including labels) + * Validation samples: 18% of TB and healthy CXR (including labels) + * Test samples: 10% of TB and healthy CXR (including labels) -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) +* This configuration resolution: 2048 x 2048 (default) +* See :py:mod:`ptbench.data.hivtb` for dataset details +""" - def setup(self, stage: str): - self.dataset = _maker("fold-9") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) +from .datamodule import DataModule +datamodule = DataModule("fold-9.json") -datamodule = DefaultModule diff --git a/tests/test_hivtb.py b/tests/test_hivtb.py index 37876051..9e814138 100644 --- a/tests/test_hivtb.py +++ b/tests/test_hivtb.py @@ -4,106 +4,126 @@ """Tests for HIV-TB dataset.""" import pytest +import torch -dataset = None +from ptbench.data.hivtb.datamodule import make_split -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - # Cross-validation fold 0-2 - for f in range(3): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 +def _check_split( + split_filename: str, + lengths: dict[str, int], + prefix: str = "HIV-TB_Algorithm_study_X-rays/", + extension: str = ".BMP", + possible_labels: list[int] = [0, 1], +): + """Runs a simple consistency check on the data split.
- assert "train" in subset - assert len(subset["train"]) == 174 - for s in subset["train"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") + Parameters + ---------- - assert "validation" in subset - assert len(subset["validation"]) == 44 - for s in subset["validation"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") + split_filename + This is the split we will check - assert "test" in subset - assert len(subset["test"]) == 25 - for s in subset["test"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") + lengths + A dictionary that contains keys matching those of the split (this will + be checked). The values of the dictionary should correspond to the + sizes of each of the datasets in the split. - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] + prefix + Each file named in a split should start with this prefix. - for s in subset["validation"]: - assert s.label in [0.0, 1.0] + extension + Each file named in a split should end with this extension. - for s in subset["test"]: - assert s.label in [0.0, 1.0] + possible_labels + This is the list of possible labels contained in any split.
+ """ - # Cross-validation fold 3-9 - for f in range(3, 10): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 + split = make_split(split_filename) - assert "train" in subset - assert len(subset["train"]) == 175 - for s in subset["train"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") + assert len(split) == len(lengths) - assert "validation" in subset - assert len(subset["validation"]) == 44 - for s in subset["validation"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") + for k in lengths.keys(): + # dataset must have been declared + assert k in split - assert "test" in subset - assert len(subset["test"]) == 24 - for s in subset["test"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") + assert len(split[k]) == lengths[k] + for s in split[k]: + assert s[0].startswith(prefix) + assert s[0].endswith(extension) + assert s[1] in possible_labels - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - for s in subset["validation"]: - assert s.label in [0.0, 1.0] +def _check_loaded_batch( + batch, + size: int = 1, + prefix: str = "HIV-TB_Algorithm_study_X-rays/", + extension: str = ".BMP", + possible_labels: list[int] = [0, 1], +): + """Checks the consistency of an individual (loaded) batch. - for s in subset["test"]: - assert s.label in [0.0, 1.0] + Parameters + ---------- + batch + The loaded batch to be checked. -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") -def test_loading(): - image_size_portrait = (2048, 2500) - image_size_landscape = (2500, 2048) + prefix + Each file named in a split should start with this prefix. + + extension + Each file named in a split should end with this extension. - def _check_size(size): - if size == image_size_portrait: - return True - elif size == image_size_landscape: - return True - return False + possible_labels + This is the list of possible labels contained in any split.
+ """ - def _check_sample(s): - data = s.data - assert isinstance(data, dict) - assert len(data) == 2 + assert len(batch) == 2 # data, metadata - assert "data" in data - assert _check_size(data["data"].size) # Check size - assert data["data"].mode == "L" # Check colors + assert isinstance(batch[0], torch.Tensor) + assert batch[0].shape[0] == size # mini-batch size + assert batch[0].shape[1] == 1 # grayscale images + assert batch[0].shape[2] == batch[0].shape[3] # image is square - assert "label" in data - assert data["label"] in [0, 1] # Check labels + assert isinstance(batch[1], dict) # metadata + assert len(batch[1]) == 2 # label and name - limit = 30 # use this to limit testing to first images only, else None + assert "label" in batch[1] + assert all([k in possible_labels for k in batch[1]["label"]]) - subset = dataset.subsets("fold_0") - for s in subset["train"][:limit]: - _check_sample(s) + assert "name" in batch[1] + assert all([k.startswith(prefix) for k in batch[1]["name"]]) + assert all([k.endswith(extension) for k in batch[1]["name"]]) + + +def test_protocol_consistency(): + # Cross-validation fold 0-2 + for k in range(3): + _check_split( + f"fold-{k}.json", + lengths=dict(train=174, validation=44, test=25), + ) + + # Cross-validation fold 3-9 + for k in range(3, 10): + _check_split( + f"fold-{k}.json", + lengths=dict(train=175, validation=44, test=24), + ) -@pytest.mark.skip(reason="Test need to be updated") @pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") -def test_check(): - assert dataset.check() == 0 +def test_loading(): + from ptbench.data.hivtb.fold_0 import datamodule + + datamodule.model_transforms = [] # should be done before setup() + datamodule.setup("predict") # sets up all datasets + + for loader in datamodule.predict_dataloader().values(): + limit = 5 # limit load checking + for batch in loader: + if limit == 0: + break + _check_loaded_batch(batch) + limit -= 1 \ No newline at end of file -- GitLab