diff --git a/src/ptbench/data/hivtb/__init__.py b/src/ptbench/data/hivtb/__init__.py index 0fca31ae5ad1553328baa81b59282fb1c8622391..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/src/ptbench/data/hivtb/__init__.py +++ b/src/ptbench/data/hivtb/__init__.py @@ -1,82 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""HIV-TB dataset for computer-aided diagnosis (only BMP files) - -* Reference: [HIV-TB-2019]_ -* Original resolution (height x width or width x height): 2048 x 2500 -* Split reference: none -* Stratified kfold protocol: - - * Training samples: 72% of TB and healthy CXR (including labels) - * Validation samples: 18% of TB and healthy CXR (including labels) - * Test samples: 10% of TB and healthy CXR (including labels) -""" - -import importlib.resources -import os - -from ...utils.rc import load_rc -from .. import make_dataset -from ..dataset import JSONDataset -from ..loader import load_pil_grayscale, make_delayed - -_protocols = [ - importlib.resources.files(__name__).joinpath("fold_0.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_1.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_2.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_3.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_4.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_5.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_6.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_7.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_8.json.bz2"), - importlib.resources.files(__name__).joinpath("fold_9.json.bz2"), -] - -_datadir = load_rc().get("datadir.hivtb", os.path.realpath(os.curdir)) - - -def _raw_data_loader(sample): - return dict( - data=load_pil_grayscale(os.path.join(_datadir, sample["data"])), - label=sample["label"], - ) - - -def _loader(context, sample): - # "context" is ignored in this case - database is homogeneous - # we returned delayed samples to avoid loading all images at once - return make_delayed(sample, _raw_data_loader) - - -json_dataset = JSONDataset( - protocols=_protocols, - fieldnames=("data", "label"), - loader=_loader, -) -"""HIV-TB dataset object.""" - - -def _maker(protocol, resize_size=512, cc_size=512, RGB=False): - from torchvision import transforms - - from ..augmentations import ElasticDeformation, RemoveBlackBorders - - post_transforms = [] - if RGB: - post_transforms = [ - transforms.Lambda(lambda x: x.convert("RGB")), - transforms.ToTensor(), - ] - - return make_dataset( - [json_dataset.subsets(protocol)], - [ - RemoveBlackBorders(), - transforms.Resize(resize_size), - transforms.CenterCrop(cc_size), - ], - [ElasticDeformation(p=0.8)], - post_transforms, - ) diff --git a/src/ptbench/data/hivtb/datamodule.py b/src/ptbench/data/hivtb/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b84ec434ea4ce4145119bf565d0d8ee8a091d9 --- /dev/null +++ b/src/ptbench/data/hivtb/datamodule.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import importlib.resources +import os + +import PIL.Image + +from torchvision.transforms.functional import center_crop, to_tensor + +from ...utils.rc import load_rc +from ..datamodule import CachingDataModule +from ..image_utils import remove_black_borders +from ..split import JSONDatabaseSplit +from ..typing import DatabaseSplit +from ..typing import RawDataLoader as _BaseRawDataLoader +from ..typing import Sample + + +class RawDataLoader(_BaseRawDataLoader): + """A specialized raw-data-loader for the HIV-TB dataset. + + Attributes + ---------- + + datadir + This variable contains the base directory where the database raw data + is stored. + """ + + datadir: str + + def __init__(self): + self.datadir = load_rc().get( + "datadir.hivtb", os.path.realpath(os.curdir) + ) + + def sample(self, sample: tuple[str, int]) -> Sample: + """Loads a single image sample from the disk. + + Parameters + ---------- + + sample: + A tuple containing the path suffix, within the dataset root folder, + where to find the image to be loaded, and an integer, representing the + sample label. + + + Returns + ------- + + sample + The sample representation + """ + image = PIL.Image.open(os.path.join(self.datadir, sample[0])).convert( + "L" + ) + image = remove_black_borders(image) + tensor = to_tensor(image) + tensor = center_crop(tensor, min(*tensor.shape[1:])) + + # use the code below to view generated images + # from torchvision.transforms.functional import to_pil_image + # to_pil_image(tensor).show() + # __import__("pdb").set_trace() + + return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] + + def label(self, sample: tuple[str, int]) -> int: + """Loads a single image sample label from the disk. + + Parameters + ---------- + + sample: + A tuple containing the path suffix, within the dataset root folder, + where to find the image to be loaded, and an integer, representing the + sample label. + + + Returns + ------- + + label + The integer label associated with the sample + """ + return sample[1] + + +def make_split(basename: str) -> DatabaseSplit: + """Returns a database split for the HIV-TB database.""" + + return JSONDatabaseSplit( + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename) + ) + + +class DataModule(CachingDataModule): + """HIV-TB dataset for computer-aided diagnosis (only BMP files) + + * Database reference: [HIV-TB-2019]_ + * Original resolution, varying with most images being 2048 x 2500 pixels + or 2500 x 2048 pixels, but not all. + + Data specifications: + + * Raw data input (on disk): + + * BMP (BMP3) and JPEG grayscale images encoded as 8-bit RGB, with + varying resolution + + * Output image: + + * Transforms: + + * Load raw BMP or JPEG with :py:mod:`PIL` + * Remove black borders + * Convert to torch tensor + * Torch center cropping to get square image + + * Final specifications + + * Grayscale, encoded as a single plane tensor, 32-bit floats, + square at 2048 x 2048 pixels + * Labels: 0 (healthy), 1 (active tuberculosis) + """ + + def __init__(self, split_filename: str): + super().__init__( + database_split=make_split(split_filename), + raw_data_loader=RawDataLoader(), + ) diff --git a/src/ptbench/data/hivtb/fold_0.py b/src/ptbench/data/hivtb/fold_0.py index e8caee654fc076f31b2ed7abffcedfc5096ea568..57d77952ad0f012f4c7224f38dc293aa58d72dcd 100644 --- a/src/ptbench/data/hivtb/fold_0.py +++ b/src/ptbench/data/hivtb/fold_0.py @@ -1,45 +1,11 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""HIV-TB dataset for TB detection (cross validation fold 0) -* Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +from .datamodule import DataModule +datamodule = DataModule("fold-0.json") +"""HIV-TB dataset for TB detection (cross validation fold 0). -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-0") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule +See :py:class:`DataModule` for technical details. +""" diff --git a/src/ptbench/data/hivtb/fold_1.py b/src/ptbench/data/hivtb/fold_1.py index bb12b3114f251cbbed1bc23700621b111de0398a..c91a968f500204bd1fa30e43e168dbf3e7f0edab 100644 --- a/src/ptbench/data/hivtb/fold_1.py +++ b/src/ptbench/data/hivtb/fold_1.py @@ -1,45 +1,11 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""HIV-TB dataset for TB detection (cross validation fold 1) -* Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +from .datamodule import DataModule +datamodule = DataModule("fold-1.json") +"""HIV-TB dataset for TB detection (cross validation fold 1). -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-1") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule +See :py:class:`DataModule` for technical details. +""" diff --git a/src/ptbench/data/hivtb/fold_2.py b/src/ptbench/data/hivtb/fold_2.py index 7bd3703ef4383e1c6898ef55b1cda1329ead3af3..323e80a02a0b44b5691d13abc971679182e2d97f 100644 --- a/src/ptbench/data/hivtb/fold_2.py +++ b/src/ptbench/data/hivtb/fold_2.py @@ -1,45 +1,11 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""HIV-TB dataset for TB detection (cross validation fold 2) -* Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +from .datamodule import DataModule +datamodule = DataModule("fold-2.json") +"""HIV-TB dataset for TB detection (cross validation fold 2). -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-2") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule +See :py:class:`DataModule` for technical details. +""" diff --git a/src/ptbench/data/hivtb/fold_3.py b/src/ptbench/data/hivtb/fold_3.py index cac94f6721c73e355e3eca23e48721a5ff27a272..1eed4c056648bce88f174ccbce8a71efe69fc136 100644 --- a/src/ptbench/data/hivtb/fold_3.py +++ b/src/ptbench/data/hivtb/fold_3.py @@ -1,45 +1,11 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""HIV-TB dataset for TB detection (cross validation fold 3) -* Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +from .datamodule import DataModule +datamodule = DataModule("fold-3.json") +"""HIV-TB dataset for TB detection (cross validation fold 3). -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-3") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule +See :py:class:`DataModule` for technical details. +""" diff --git a/src/ptbench/data/hivtb/fold_4.py b/src/ptbench/data/hivtb/fold_4.py index c59523565f6028bebd28b32ec248a5e314df7d95..9cfa6186d6dc7d44f8bcfa56d7c978e7bf346c54 100644 --- a/src/ptbench/data/hivtb/fold_4.py +++ b/src/ptbench/data/hivtb/fold_4.py @@ -1,45 +1,11 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""HIV-TB dataset for TB detection (cross validation fold 4) -* Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +from .datamodule import DataModule +datamodule = DataModule("fold-4.json") +"""HIV-TB dataset for TB detection (cross validation fold 4). -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-4") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule +See :py:class:`DataModule` for technical details. +""" diff --git a/src/ptbench/data/hivtb/fold_5.py b/src/ptbench/data/hivtb/fold_5.py index bc80b9fff624e2341f073255afcacbe6974652cc..591fef3732b522569a92082cb7e3c208c16bf2da 100644 --- a/src/ptbench/data/hivtb/fold_5.py +++ b/src/ptbench/data/hivtb/fold_5.py @@ -1,45 +1,11 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""HIV-TB dataset for TB detection (cross validation fold 5) -* Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +from .datamodule import DataModule +datamodule = DataModule("fold-5.json") +"""HIV-TB dataset for TB detection (cross validation fold 5). -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-5") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule +See :py:class:`DataModule` for technical details. +""" diff --git a/src/ptbench/data/hivtb/fold_6.py b/src/ptbench/data/hivtb/fold_6.py index d1a646dc045fb6eeb8e500d8c4e4c8a360cb04ea..fb5e1614b349779d42771bd165a9a1d96c6cb83d 100644 --- a/src/ptbench/data/hivtb/fold_6.py +++ b/src/ptbench/data/hivtb/fold_6.py @@ -1,45 +1,11 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""HIV-TB dataset for TB detection (cross validation fold 6) -* Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +from .datamodule import DataModule +datamodule = DataModule("fold-6.json") +"""HIV-TB dataset for TB detection (cross validation fold 6). -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-6") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule +See :py:class:`DataModule` for technical details. +""" diff --git a/src/ptbench/data/hivtb/fold_7.py b/src/ptbench/data/hivtb/fold_7.py index de29f23467139c94734af762ea89ba00069504d8..d64db4837f24058d34b2daf4c8383595aee7be21 100644 --- a/src/ptbench/data/hivtb/fold_7.py +++ b/src/ptbench/data/hivtb/fold_7.py @@ -1,45 +1,11 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""HIV-TB dataset for TB detection (cross validation fold 7) -* Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +from .datamodule import DataModule +datamodule = DataModule("fold-7.json") +"""HIV-TB dataset for TB detection (cross validation fold 7). -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-7") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule +See :py:class:`DataModule` for technical details. +""" diff --git a/src/ptbench/data/hivtb/fold_8.py b/src/ptbench/data/hivtb/fold_8.py index 9370dcea622c82aeec72015f56c537541f98a0ad..8a0f87d10c934f08249ed4f0206c09b5bbc6a7a9 100644 --- a/src/ptbench/data/hivtb/fold_8.py +++ b/src/ptbench/data/hivtb/fold_8.py @@ -1,45 +1,11 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""HIV-TB dataset for TB detection (cross validation fold 8) -* Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +from .datamodule import DataModule +datamodule = DataModule("fold-8.json") +"""HIV-TB dataset for TB detection (cross validation fold 8). -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-8") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule +See :py:class:`DataModule` for technical details. +""" diff --git a/src/ptbench/data/hivtb/fold_9.py b/src/ptbench/data/hivtb/fold_9.py index 70605f8d6d22a63bbf8d54f611b8f9be3b850abe..d92de50e75cd18ea9b99d1bdb010f6f88872b9cc 100644 --- a/src/ptbench/data/hivtb/fold_9.py +++ b/src/ptbench/data/hivtb/fold_9.py @@ -1,45 +1,11 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""HIV-TB dataset for TB detection (cross validation fold 9) -* Split reference: none (stratified kfolding) -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.hivtb` for dataset details -""" - -from clapper.logging import setup - -from .. import return_subsets -from ..base_datamodule import BaseDataModule -from . import _maker - -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") +from .datamodule import DataModule +datamodule = DataModule("fold-9.json") +"""HIV-TB dataset for TB detection (cross validation fold 9). -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - multiproc_kwargs=None, - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - def setup(self, stage: str): - self.dataset = _maker("fold-9") - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - self.predict_dataset, - ) = return_subsets(self.dataset) - - -datamodule = DefaultModule +See :py:class:`DataModule` for technical details. +""" diff --git a/tests/test_hivtb.py b/tests/test_hivtb.py index 37876051415f25dcc7d43072af08c27c9f8ff0a9..9e8141386649140a072fcff82adb38cd7fc85bb1 100644 --- a/tests/test_hivtb.py +++ b/tests/test_hivtb.py @@ -4,106 +4,126 @@ """Tests for HIV-TB dataset.""" import pytest +import torch -dataset = None +from ptbench.data.hivtb.datamodule import make_split -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - # Cross-validation fold 0-2 - for f in range(3): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 +def _check_split( + split_filename: str, + lengths: dict[str, int], + prefix: str = "HIV-TB_Algorithm_study_X-rays/", + extension: str = ".BMP", + possible_labels: list[int] = [0, 1], +): + """Runs a simple consistence check on the data split. - assert "train" in subset - assert len(subset["train"]) == 174 - for s in subset["train"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") + Parameters + ---------- - assert "validation" in subset - assert len(subset["validation"]) == 44 - for s in subset["validation"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") + split_filename + This is the split we will check - assert "test" in subset - assert len(subset["test"]) == 25 - for s in subset["test"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") + lenghts + A dictionary that contains keys matching those of the split (this will + be checked). The values of the dictionary should correspond to the + sizes of each of the datasets in the split. - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] + prefix + Each file named in a split should start with this prefix. - for s in subset["validation"]: - assert s.label in [0.0, 1.0] + extension + Each file named in a split should end with this extension. - for s in subset["test"]: - assert s.label in [0.0, 1.0] + possible_labels + These are the list of possible labels contained in any split. + """ - # Cross-validation fold 3-9 - for f in range(3, 10): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 + split = make_split(split_filename) - assert "train" in subset - assert len(subset["train"]) == 175 - for s in subset["train"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") + assert len(split) == len(lengths) - assert "validation" in subset - assert len(subset["validation"]) == 44 - for s in subset["validation"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") + for k in lengths.keys(): + # dataset must have been declared + assert k in split - assert "test" in subset - assert len(subset["test"]) == 24 - for s in subset["test"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") + assert len(split[k]) == lengths[k] + for s in split[k]: + assert s[0].startswith(prefix) + assert s[0].endswith(extension) + assert s[1] in possible_labels - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - for s in subset["validation"]: - assert s.label in [0.0, 1.0] +def _check_loaded_batch( + batch, + size: int = 1, + prefix: str = "HIV-TB_Algorithm_study_X-rays/", + extension: str = ".BMP", + possible_labels: list[int] = [0, 1], +): + """Checks the consistence of an individual (loaded) batch. - for s in subset["test"]: - assert s.label in [0.0, 1.0] + Parameters + ---------- + batch + The loaded batch to be checked. -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") -def test_loading(): - image_size_portrait = (2048, 2500) - image_size_landscape = (2500, 2048) + prefix + Each file named in a split should start with this prefix. + + extension + Each file named in a split should end with this extension. - def _check_size(size): - if size == image_size_portrait: - return True - elif size == image_size_landscape: - return True - return False + possible_labels + These are the list of possible labels contained in any split. + """ - def _check_sample(s): - data = s.data - assert isinstance(data, dict) - assert len(data) == 2 + assert len(batch) == 2 # data, metadata - assert "data" in data - assert _check_size(data["data"].size) # Check size - assert data["data"].mode == "L" # Check colors + assert isinstance(batch[0], torch.Tensor) + assert batch[0].shape[0] == size # mini-batch size + assert batch[0].shape[1] == 1 # grayscale images + assert batch[0].shape[2] == batch[0].shape[3] # image is square - assert "label" in data - assert data["label"] in [0, 1] # Check labels + assert isinstance(batch[1], dict) # metadata + assert len(batch[1]) == 2 # label and name - limit = 30 # use this to limit testing to first images only, else None + assert "label" in batch[1] + assert all([k in possible_labels for k in batch[1]["label"]]) - subset = dataset.subsets("fold_0") - for s in subset["train"][:limit]: - _check_sample(s) + assert "name" in batch[1] + assert all([k.startswith(prefix) for k in batch[1]["name"]]) + assert all([k.endswith(extension) for k in batch[1]["name"]]) + + +def test_protocol_consistency(): + # Cross-validation fold 0-2 + for k in range(3): + _check_split( + f"fold-{k}.json", + lengths=dict(train=174, validation=44, test=25), + ) + + # Cross-validation fold 3-9 + for k in range(3, 10): + _check_split( + f"fold-{k}.json", + lengths=dict(train=175, validation=44, test=24), + ) -@pytest.mark.skip(reason="Test need to be updated") @pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") -def test_check(): - assert dataset.check() == 0 +def test_loading(): + from ptbench.data.hivtb.fold_0 import datamodule + + datamodule.model_transforms = [] # should be done before setup() + datamodule.setup("predict") # sets up all datasets + + for loader in datamodule.predict_dataloader().values(): + limit = 5 # limit load checking + for batch in loader: + if limit == 0: + break + _check_loaded_batch(batch) + limit -= 1 \ No newline at end of file