Skip to content
Snippets Groups Projects
Commit 75f98d0c authored by Maxime DELITROZ's avatar Maxime DELITROZ
Browse files

updated HIV-TB dataset and related tests

parent a4ca903d
No related branches found
No related tags found
2 merge requests!10Update HIV-TB dataset,!6Making use of LightningDataModule and simplification of data loading
Pipeline #76721 failed
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""HIV-TB dataset for computer-aided diagnosis (only BMP files)
* Reference: [HIV-TB-2019]_
* Original resolution (height x width or width x height): 2048 x 2500
* Split reference: none
* Stratified kfold protocol:
* Training samples: 72% of TB and healthy CXR (including labels)
* Validation samples: 18% of TB and healthy CXR (including labels)
* Test samples: 10% of TB and healthy CXR (including labels)
"""
import importlib.resources
import os
from ...utils.rc import load_rc
from .. import make_dataset
from ..dataset import JSONDataset
from ..loader import load_pil_grayscale, make_delayed
_protocols = [
importlib.resources.files(__name__).joinpath("fold_0.json.bz2"),
importlib.resources.files(__name__).joinpath("fold_1.json.bz2"),
importlib.resources.files(__name__).joinpath("fold_2.json.bz2"),
importlib.resources.files(__name__).joinpath("fold_3.json.bz2"),
importlib.resources.files(__name__).joinpath("fold_4.json.bz2"),
importlib.resources.files(__name__).joinpath("fold_5.json.bz2"),
importlib.resources.files(__name__).joinpath("fold_6.json.bz2"),
importlib.resources.files(__name__).joinpath("fold_7.json.bz2"),
importlib.resources.files(__name__).joinpath("fold_8.json.bz2"),
importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
]
_datadir = load_rc().get("datadir.hivtb", os.path.realpath(os.curdir))
def _raw_data_loader(sample):
return dict(
data=load_pil_grayscale(os.path.join(_datadir, sample["data"])),
label=sample["label"],
)
def _loader(context, sample):
# "context" is ignored in this case - database is homogeneous
# we returned delayed samples to avoid loading all images at once
return make_delayed(sample, _raw_data_loader)
json_dataset = JSONDataset(
protocols=_protocols,
fieldnames=("data", "label"),
loader=_loader,
)
"""HIV-TB dataset object."""
def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
from torchvision import transforms
from ..augmentations import ElasticDeformation, RemoveBlackBorders
post_transforms = []
if RGB:
post_transforms = [
transforms.Lambda(lambda x: x.convert("RGB")),
transforms.ToTensor(),
]
return make_dataset(
[json_dataset.subsets(protocol)],
[
RemoveBlackBorders(),
transforms.Resize(resize_size),
transforms.CenterCrop(cc_size),
],
[ElasticDeformation(p=0.8)],
post_transforms,
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import importlib.resources
import os
import PIL.Image
from torchvision.transforms.functional import center_crop, to_tensor
from ...utils.rc import load_rc
from ..datamodule import CachingDataModule
from ..image_utils import load_pil_grayscale, remove_black_borders
from ..split import JSONDatabaseSplit
from ..typing import DatabaseSplit
from ..typing import RawDataLoader as _BaseRawDataLoader
from ..typing import Sample
class RawDataLoader(_BaseRawDataLoader):
"""A specialized raw-data-loader for the HIV-TB dataset.
Attributes
----------
datadir
This variable contains the base directory where the database raw data
is stored.
"""
datadir: str
def __init__(self):
self.datadir = load_rc().get(
"datadir.hivtb", os.path.realpath(os.curdir)
)
def sample(self, sample: tuple[str, int]) -> Sample:
"""Loads a single image sample from the disk.
Parameters
----------
sample:
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing the
sample label.
Returns
-------
sample
The sample representation
"""
image = load_pil_grayscale(os.path.join(self.datadir, sample[0]))
image = remove_black_borders(image)
tensor = to_tensor(image)
tensor = center_crop(tensor, min(*tensor.shape[1:]))
# use the code below to view generated images
# from torchvision.transforms.functional import to_pil_image
# to_pil_image(tensor).show()
# __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type]
def label(self, sample: tuple[str, int]) -> int:
"""Loads a single image sample label from the disk.
Parameters
----------
sample:
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer, representing the
sample label.
Returns
-------
label
The integer label associated with the sample
"""
return sample[1]
def make_split(basename: str) -> DatabaseSplit:
"""Returns a database split for the HIV-TB database."""
return JSONDatabaseSplit(
importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
)
class DataModule(CachingDataModule):
"""HIV-TB dataset for computer-aided diagnosis (only BMP files)
* Database reference: [HIV-TB-2019]_
* Original resolution (height x width or width x height): 2048 x 2500 pixels
or 2500 x 2048 pixels
Data specifications:
* Raw data input (on disk):
* BMP images 8 bit grayscale
* resolution fixed to one of the cases above
* Output image:
* Transforms:
* Load raw BMP with :py:mod:`PIL`
* Remove black borders
* Convert to torch tensor
* Torch center cropping to get square image
* Final specifications
* Grayscale, encoded as a single plane tensor, 32-bit floats,
square at 2048 x 2048 pixels
* Labels: 0 (healthy), 1 (active tuberculosis)
"""
def __init__(self, split_filename: str):
super().__init__(
database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(),
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""HIV-TB dataset for TB detection (cross validation fold 0) """HIV-TB dataset for TB detection (cross validation fold 0)
* Split reference: none (stratified kfolding) * Split reference: none (stratified kfolding)
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
from clapper.logging import setup
from .. import return_subsets * Stratified kfold protocol:
from ..base_datamodule import BaseDataModule * Training samples: 72% of TB and healthy CXR (including labels)
from . import _maker * Validation samples: 18% of TB and healthy CXR (including labels)
* Test samples: 10% of TB and healthy CXR (including labels)
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") * This configuration resolution: 2048 x 2048 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str): from .datamodule import DataModule
self.dataset = _maker("fold-0")
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DataModule("fold-0.json")
datamodule = DefaultModule
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""HIV-TB dataset for TB detection (cross validation fold 1) """HIV-TB dataset for TB detection (cross validation fold 1)
* Split reference: none (stratified kfolding) * Split reference: none (stratified kfolding)
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
from clapper.logging import setup
from .. import return_subsets * Stratified kfold protocol:
from ..base_datamodule import BaseDataModule * Training samples: 72% of TB and healthy CXR (including labels)
from . import _maker * Validation samples: 18% of TB and healthy CXR (including labels)
* Test samples: 10% of TB and healthy CXR (including labels)
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") * This configuration resolution: 2048 x 2048 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str): from .datamodule import DataModule
self.dataset = _maker("fold-1")
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DataModule("fold-1.json")
datamodule = DefaultModule
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""HIV-TB dataset for TB detection (cross validation fold 2) """HIV-TB dataset for TB detection (cross validation fold 2)
* Split reference: none (stratified kfolding) * Split reference: none (stratified kfolding)
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
from clapper.logging import setup
from .. import return_subsets * Stratified kfold protocol:
from ..base_datamodule import BaseDataModule * Training samples: 72% of TB and healthy CXR (including labels)
from . import _maker * Validation samples: 18% of TB and healthy CXR (including labels)
* Test samples: 10% of TB and healthy CXR (including labels)
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") * This configuration resolution: 2048 x 2048 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str): from .datamodule import DataModule
self.dataset = _maker("fold-2")
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DataModule("fold-2.json")
datamodule = DefaultModule
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""HIV-TB dataset for TB detection (cross validation fold 3) """HIV-TB dataset for TB detection (cross validation fold 3)
* Split reference: none (stratified kfolding) * Split reference: none (stratified kfolding)
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
from clapper.logging import setup
from .. import return_subsets * Stratified kfold protocol:
from ..base_datamodule import BaseDataModule * Training samples: 72% of TB and healthy CXR (including labels)
from . import _maker * Validation samples: 18% of TB and healthy CXR (including labels)
* Test samples: 10% of TB and healthy CXR (including labels)
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") * This configuration resolution: 2048 x 2048 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str): from .datamodule import DataModule
self.dataset = _maker("fold-3")
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DataModule("fold-3.json")
datamodule = DefaultModule
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""HIV-TB dataset for TB detection (cross validation fold 4) """HIV-TB dataset for TB detection (cross validation fold 4)
* Split reference: none (stratified kfolding) * Split reference: none (stratified kfolding)
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
from clapper.logging import setup
from .. import return_subsets * Stratified kfold protocol:
from ..base_datamodule import BaseDataModule * Training samples: 72% of TB and healthy CXR (including labels)
from . import _maker * Validation samples: 18% of TB and healthy CXR (including labels)
* Test samples: 10% of TB and healthy CXR (including labels)
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") * This configuration resolution: 2048 x 2048 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str): from .datamodule import DataModule
self.dataset = _maker("fold-4")
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DataModule("fold-4.json")
datamodule = DefaultModule
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""HIV-TB dataset for TB detection (cross validation fold 5) """HIV-TB dataset for TB detection (cross validation fold 5)
* Split reference: none (stratified kfolding) * Split reference: none (stratified kfolding)
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
from clapper.logging import setup
from .. import return_subsets * Stratified kfold protocol:
from ..base_datamodule import BaseDataModule * Training samples: 72% of TB and healthy CXR (including labels)
from . import _maker * Validation samples: 18% of TB and healthy CXR (including labels)
* Test samples: 10% of TB and healthy CXR (including labels)
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") * This configuration resolution: 2048 x 2048 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str): from .datamodule import DataModule
self.dataset = _maker("fold-5")
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DataModule("fold-5.json")
datamodule = DefaultModule
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""HIV-TB dataset for TB detection (cross validation fold 6) """HIV-TB dataset for TB detection (cross validation fold 6)
* Split reference: none (stratified kfolding) * Split reference: none (stratified kfolding)
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
from clapper.logging import setup
from .. import return_subsets * Stratified kfold protocol:
from ..base_datamodule import BaseDataModule * Training samples: 72% of TB and healthy CXR (including labels)
from . import _maker * Validation samples: 18% of TB and healthy CXR (including labels)
* Test samples: 10% of TB and healthy CXR (including labels)
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") * This configuration resolution: 2048 x 2048 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str): from .datamodule import DataModule
self.dataset = _maker("fold-6")
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DataModule("fold-6.json")
datamodule = DefaultModule
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""HIV-TB dataset for TB detection (cross validation fold 7) """HIV-TB dataset for TB detection (cross validation fold 7)
* Split reference: none (stratified kfolding) * Split reference: none (stratified kfolding)
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
from clapper.logging import setup
from .. import return_subsets * Stratified kfold protocol:
from ..base_datamodule import BaseDataModule * Training samples: 72% of TB and healthy CXR (including labels)
from . import _maker * Validation samples: 18% of TB and healthy CXR (including labels)
* Test samples: 10% of TB and healthy CXR (including labels)
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") * This configuration resolution: 2048 x 2048 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str): from .datamodule import DataModule
self.dataset = _maker("fold-7")
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DataModule("fold-7.json")
datamodule = DefaultModule
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""HIV-TB dataset for TB detection (cross validation fold 8) """HIV-TB dataset for TB detection (cross validation fold 8)
* Split reference: none (stratified kfolding) * Split reference: none (stratified kfolding)
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
from clapper.logging import setup
from .. import return_subsets * Stratified kfold protocol:
from ..base_datamodule import BaseDataModule * Training samples: 72% of TB and healthy CXR (including labels)
from . import _maker * Validation samples: 18% of TB and healthy CXR (including labels)
* Test samples: 10% of TB and healthy CXR (including labels)
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") * This configuration resolution: 2048 x 2048 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str): from .datamodule import DataModule
self.dataset = _maker("fold-8")
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DataModule("fold-8.json")
datamodule = DefaultModule
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""HIV-TB dataset for TB detection (cross validation fold 9) """HIV-TB dataset for TB detection (cross validation fold 9)
* Split reference: none (stratified kfolding) * Split reference: none (stratified kfolding)
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
from clapper.logging import setup
from .. import return_subsets * Stratified kfold protocol:
from ..base_datamodule import BaseDataModule * Training samples: 72% of TB and healthy CXR (including labels)
from . import _maker * Validation samples: 18% of TB and healthy CXR (including labels)
* Test samples: 10% of TB and healthy CXR (including labels)
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") * This configuration resolution: 2048 x 2048 (default)
* See :py:mod:`ptbench.data.hivtb` for dataset details
"""
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str): from .datamodule import DataModule
self.dataset = _maker("fold-9")
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DataModule("fold-9.json")
datamodule = DefaultModule
...@@ -4,106 +4,126 @@ ...@@ -4,106 +4,126 @@
"""Tests for HIV-TB dataset.""" """Tests for HIV-TB dataset."""
import pytest import pytest
import torch
dataset = None from ptbench.data.hivtb.datamodule import make_split
@pytest.mark.skip(reason="Test need to be updated") def _check_split(
def test_protocol_consistency(): split_filename: str,
# Cross-validation fold 0-2 lengths: dict[str, int],
for f in range(3): prefix: str = "HIV-TB_Algorithm_study_X-rays/",
subset = dataset.subsets("fold_" + str(f)) extension: str = ".BMP",
assert len(subset) == 3 possible_labels: list[int] = [0, 1],
):
"""Runs a simple consistence check on the data split.
assert "train" in subset Parameters
assert len(subset["train"]) == 174 ----------
for s in subset["train"]:
assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/")
assert "validation" in subset split_filename
assert len(subset["validation"]) == 44 This is the split we will check
for s in subset["validation"]:
assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/")
assert "test" in subset lenghts
assert len(subset["test"]) == 25 A dictionary that contains keys matching those of the split (this will
for s in subset["test"]: be checked). The values of the dictionary should correspond to the
assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") sizes of each of the datasets in the split.
# Check labels prefix
for s in subset["train"]: Each file named in a split should start with this prefix.
assert s.label in [0.0, 1.0]
for s in subset["validation"]: extension
assert s.label in [0.0, 1.0] Each file named in a split should end with this extension.
for s in subset["test"]: possible_labels
assert s.label in [0.0, 1.0] These are the list of possible labels contained in any split.
"""
# Cross-validation fold 3-9 split = make_split(split_filename)
for f in range(3, 10):
subset = dataset.subsets("fold_" + str(f))
assert len(subset) == 3
assert "train" in subset assert len(split) == len(lengths)
assert len(subset["train"]) == 175
for s in subset["train"]:
assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/")
assert "validation" in subset for k in lengths.keys():
assert len(subset["validation"]) == 44 # dataset must have been declared
for s in subset["validation"]: assert k in split
assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/")
assert "test" in subset assert len(split[k]) == lengths[k]
assert len(subset["test"]) == 24 for s in split[k]:
for s in subset["test"]: assert s[0].startswith(prefix)
assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") assert s[0].endswith(extension)
assert s[1] in possible_labels
# Check labels
for s in subset["train"]:
assert s.label in [0.0, 1.0]
for s in subset["validation"]: def _check_loaded_batch(
assert s.label in [0.0, 1.0] batch,
size: int = 1,
prefix: str = "HIV-TB_Algorithm_study_X-rays/",
extension: str = ".BMP",
possible_labels: list[int] = [0, 1],
):
"""Checks the consistence of an individual (loaded) batch.
for s in subset["test"]: Parameters
assert s.label in [0.0, 1.0] ----------
batch
The loaded batch to be checked.
@pytest.mark.skip(reason="Test need to be updated") prefix
@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") Each file named in a split should start with this prefix.
def test_loading():
image_size_portrait = (2048, 2500) extension
image_size_landscape = (2500, 2048) Each file named in a split should end with this extension.
def _check_size(size): possible_labels
if size == image_size_portrait: These are the list of possible labels contained in any split.
return True """
elif size == image_size_landscape:
return True
return False
def _check_sample(s): assert len(batch) == 2 # data, metadata
data = s.data
assert isinstance(data, dict)
assert len(data) == 2
assert "data" in data assert isinstance(batch[0], torch.Tensor)
assert _check_size(data["data"].size) # Check size assert batch[0].shape[0] == size # mini-batch size
assert data["data"].mode == "L" # Check colors assert batch[0].shape[1] == 1 # grayscale images
assert batch[0].shape[2] == batch[0].shape[3] # image is square
assert "label" in data assert isinstance(batch[1], dict) # metadata
assert data["label"] in [0, 1] # Check labels assert len(batch[1]) == 2 # label and name
limit = 30 # use this to limit testing to first images only, else None assert "label" in batch[1]
assert all([k in possible_labels for k in batch[1]["label"]])
subset = dataset.subsets("fold_0") assert "name" in batch[1]
for s in subset["train"][:limit]: assert all([k.startswith(prefix) for k in batch[1]["name"]])
_check_sample(s) assert all([k.endswith(extension) for k in batch[1]["name"]])
def test_protocol_consistency():
# Cross-validation fold 0-2
for k in range(3):
_check_split(
f"fold-{k}.json",
lengths=dict(train=174, validation=44, test=25),
)
# Cross-validation fold 3-9
for k in range(3, 10):
_check_split(
f"fold-{k}.json",
lengths=dict(train=175, validation=44, test=24),
)
@pytest.mark.skip(reason="Test need to be updated")
@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") @pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
def test_check(): def test_loading():
assert dataset.check() == 0 from ptbench.data.hivtb.fold_0 import datamodule
datamodule.model_transforms = [] # should be done before setup()
datamodule.setup("predict") # sets up all datasets
for loader in datamodule.predict_dataloader().values():
limit = 5 # limit load checking
for batch in loader:
if limit == 0:
break
_check_loaded_batch(batch)
limit -= 1
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment