diff --git a/src/ptbench/data/indian/datamodule.py b/src/ptbench/data/indian/datamodule.py index f68d184e4d10b0830bb3ef50d83527b4f2d88ac2..f6017cad4e7deac52c74c8478c3697f64dffa9c7 100644 --- a/src/ptbench/data/indian/datamodule.py +++ b/src/ptbench/data/indian/datamodule.py @@ -55,5 +55,5 @@ class DataModule(CachingDataModule): def __init__(self, split_filename: str): super().__init__( database_split=make_split(split_filename), - raw_data_loader=RawDataLoader(), + raw_data_loader=RawDataLoader(config_variable="datadir.indian"), ) diff --git a/src/ptbench/data/shenzhen/datamodule.py b/src/ptbench/data/shenzhen/datamodule.py index 7cf1833b27ca1e79382cb7e45fa41b5b4ee27292..0596007eaae5050b5691f1ebe1563f78e1507910 100644 --- a/src/ptbench/data/shenzhen/datamodule.py +++ b/src/ptbench/data/shenzhen/datamodule.py @@ -34,9 +34,9 @@ class RawDataLoader(_BaseRawDataLoader): datadir: str - def __init__(self): + def __init__(self, config_variable: str = "datadir.shenzhen"): self.datadir = load_rc().get( - "datadir.shenzhen", os.path.realpath(os.curdir) + config_variable, os.path.realpath(os.curdir) ) def sample(self, sample: tuple[str, int]) -> Sample: diff --git a/tests/test_indian.py b/tests/test_indian.py index 87660c1af4073fde09cc13d5a6f616d9180a3e4e..91adf0d42083cfa45869e41ccbe9fc1923edd039 100644 --- a/tests/test_indian.py +++ b/tests/test_indian.py @@ -1,142 +1,127 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for Indian dataset.""" - -import pytest +"""Tests for Indian (a.k.a. +database A/database B) dataset. 
+""" -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - from ptbench.data.indian import dataset +import pytest +import torch - # Default protocol - subset = dataset.subsets("default") - assert len(subset) == 3 +from ptbench.data.indian.datamodule import make_split - assert "train" in subset - assert len(subset["train"]) == 83 - for s in subset["train"]: - assert s.key.startswith("DatasetA/Training/") - assert "validation" in subset - assert len(subset["validation"]) == 20 - for s in subset["validation"]: - assert s.key.startswith("DatasetA/Training/") +def _check_split( + split_filename: str, + lengths: dict[str, int], + prefix: str = "Dataset", + possible_labels: list[int] = [0, 1], +): + """Runs a simple consistence check on the data split. - assert "test" in subset - assert len(subset["test"]) == 52 - for s in subset["test"]: - assert s.key.startswith("DatasetA/Testing/") + Parameters + ---------- - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] + split_filename + This is the split we will check - for s in subset["validation"]: - assert s.label in [0.0, 1.0] + lengths + A dictionary that contains keys matching those of the split (this will + be checked). The values of the dictionary should correspond to the + sizes of each of the datasets in the split. - for s in subset["test"]: - assert s.label in [0.0, 1.0] + prefix + Each file named in a split should start with this prefix. - # Cross-validation fold 0-4 - for f in range(5): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 + possible_labels + These are the list of possible labels contained in any split. 
+ """ - assert "train" in subset - assert len(subset["train"]) == 111 - for s in subset["train"]: - assert s.key.startswith("DatasetA") + split = make_split(split_filename) - assert "validation" in subset - assert len(subset["validation"]) == 28 - for s in subset["validation"]: - assert s.key.startswith("DatasetA") + assert len(split) == len(lengths) - assert "test" in subset - assert len(subset["test"]) == 16 - for s in subset["test"]: - assert s.key.startswith("DatasetA") + for k in lengths.keys(): + # dataset must have been declared + assert k in split - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] + assert len(split[k]) == lengths[k] + for s in split[k]: + assert s[0].startswith(prefix) + assert s[1] in possible_labels - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - for s in subset["test"]: - assert s.label in [0.0, 1.0] +def _check_loaded_batch( + batch, + size: int = 1, + prefix: str = "Dataset", + possible_labels: list[int] = [0, 1], +): + """Checks the consistence of an individual (loaded) batch. - # Cross-validation fold 5-9 - for f in range(5, 10): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 + Parameters + ---------- - assert "train" in subset - assert len(subset["train"]) == 112 - for s in subset["train"]: - assert s.key.startswith("DatasetA") + batch + The loaded batch to be checked. - assert "validation" in subset - assert len(subset["validation"]) == 28 - for s in subset["validation"]: - assert s.key.startswith("DatasetA") + prefix + Each file named in a split should start with this prefix. - assert "test" in subset - assert len(subset["test"]) == 15 - for s in subset["test"]: - assert s.key.startswith("DatasetA") + possible_labels + These are the list of possible labels contained in any split. 
+ """ - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] + assert len(batch) == 2 # data, metadata - for s in subset["validation"]: - assert s.label in [0.0, 1.0] + assert isinstance(batch[0], torch.Tensor) + assert batch[0].shape[0] == size # mini-batch size + assert batch[0].shape[1] == 1 # grayscale images + assert batch[0].shape[2] == batch[0].shape[3] # image is square - for s in subset["test"]: - assert s.label in [0.0, 1.0] + assert isinstance(batch[1], dict) # metadata + assert len(batch[1]) == 2 # label and name + assert "label" in batch[1] + assert all([k in possible_labels for k in batch[1]["label"]]) -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.indian") -def test_loading(): - from ptbench.data.indian import dataset + assert "name" in batch[1] + assert all([k.startswith(prefix) for k in batch[1]["name"]]) - def _check_size(size): - if ( - size[0] >= 1024 - and size[0] <= 2320 - and size[1] >= 1024 - and size[1] <= 2828 - ): - return True - return False - def _check_sample(s): - data = s.data - assert isinstance(data, dict) - assert len(data) == 2 - - assert "data" in data - assert _check_size(data["data"].size) # Check size - assert data["data"].mode == "L" # Check colors - - assert "label" in data - assert data["label"] in [0, 1] # Check labels +def test_protocol_consistency(): + _check_split( + "default.json", + lengths=dict(train=83, validation=20, test=52), + ) - limit = 30 # use this to limit testing to first images only, else None + # Cross-validation fold 0-4 + for k in range(5): + _check_split( + f"fold-{k}.json", + lengths=dict(train=111, validation=28, test=16), + ) - subset = dataset.subsets("default") - for s in subset["train"][:limit]: - _check_sample(s) + # Cross-validation fold 5-9 + for k in range(5, 10): + _check_split( + f"fold-{k}.json", + lengths=dict(train=112, validation=28, test=15), + ) -@pytest.mark.skip(reason="Test need to be updated") 
@pytest.mark.skip_if_rc_var_not_set("datadir.indian") -def test_check(): - from ptbench.data.indian import dataset - - assert dataset.check() == 0 +def test_loading(): + from ptbench.data.indian.default import datamodule + + datamodule.model_transforms = [] # should be done before setup() + datamodule.setup("predict") # sets up all datasets + + for loader in datamodule.predict_dataloader().values(): + limit = 5 # limit load checking + for batch in loader: + if limit == 0: + break + _check_loaded_batch(batch) + limit -= 1 diff --git a/tests/test_shenzhen.py b/tests/test_shenzhen.py index 6e577081a83a075c92a4808bbfb1c876d247fdde..30c69543a77d24715f2a815192d2e0abac41b9c5 100644 --- a/tests/test_shenzhen.py +++ b/tests/test_shenzhen.py @@ -3,202 +3,122 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Tests for Shenzhen dataset.""" -import importlib - import pytest +import torch +from ptbench.data.shenzhen.datamodule import make_split -def test_protocol_consistency(): - # Default protocol - - datamodule = getattr( - importlib.import_module("ptbench.data.shenzhen.datamodules"), "default" - ) - - subset = datamodule.splits - - assert len(subset) == 3 - - assert "train" in subset - train_samples = subset["train"][0][0] - assert len(train_samples) == 422 - for s in train_samples: - assert s[0].startswith("CXR_png/CHNCXR_0") - - assert "validation" in subset - validation_samples = subset["validation"][0][0] - assert len(validation_samples) == 107 - for s in validation_samples: - assert s[0].startswith("CXR_png/CHNCXR_0") - - assert "test" in subset - test_samples = subset["test"][0][0] - assert len(test_samples) == 133 - for s in test_samples: - assert s[0].startswith("CXR_png/CHNCXR_0") - # Check labels - for s in train_samples: - assert s[1] in [0.0, 1.0] +def _check_split( + split_filename: str, + lengths: dict[str, int], + prefix: str = "CXR_png/CHNCXR_0", + possible_labels: list[int] = [0, 1], +): + """Runs a simple consistence check on the data split. 
- for s in validation_samples: - assert s[1] in [0.0, 1.0] + Parameters + ---------- - for s in test_samples: - assert s[1] in [0.0, 1.0] + split_filename + This is the split we will check - # Cross-validation folds 0-1 - for f in range(2): - datamodule = getattr( - importlib.import_module("ptbench.data.shenzhen.datamodules"), - f"fold_{str(f)}", - ) - - subset = datamodule.splits - - assert len(subset) == 3 + lengths + A dictionary that contains keys matching those of the split (this will + be checked). The values of the dictionary should correspond to the + sizes of each of the datasets in the split. - assert "train" in subset - train_samples = subset["train"][0][0] - assert len(train_samples) == 476 - for s in train_samples: - assert s[0].startswith("CXR_png/CHNCXR_0") + prefix + Each file named in a split should start with this prefix. - assert "validation" in subset - validation_samples = subset["validation"][0][0] - assert len(validation_samples) == 119 - for s in validation_samples: - assert s[0].startswith("CXR_png/CHNCXR_0") + possible_labels + These are the list of possible labels contained in any split. 
+ """ - assert "test" in subset - test_samples = subset["test"][0][0] - assert len(test_samples) == 67 - for s in test_samples: - assert s[0].startswith("CXR_png/CHNCXR_0") + split = make_split(split_filename) - # Check labels - for s in train_samples: - assert s[1] in [0.0, 1.0] + assert len(split) == len(lengths) - for s in validation_samples: - assert s[1] in [0.0, 1.0] + for k in lengths.keys(): + # dataset must have been declared + assert k in split - for s in test_samples: - assert s[1] in [0.0, 1.0] - - # Cross-validation folds 2-9 - for f in range(2, 10): - datamodule = getattr( - importlib.import_module("ptbench.data.shenzhen.datamodules"), - f"fold_{str(f)}", - ) + assert len(split[k]) == lengths[k] + for s in split[k]: + assert s[0].startswith(prefix) + assert s[1] in possible_labels - subset = datamodule.splits - assert len(subset) == 3 +def _check_loaded_batch( + batch, + size: int = 1, + prefix: str = "CXR_png/CHNCXR_0", + possible_labels: list[int] = [0, 1], +): + """Checks the consistence of an individual (loaded) batch. - assert "train" in subset - train_samples = subset["train"][0][0] - assert len(train_samples) == 476 - for s in train_samples: - assert s[0].startswith("CXR_png/CHNCXR_0") + Parameters + ---------- - assert "validation" in subset - validation_samples = subset["validation"][0][0] - assert len(validation_samples) == 120 - for s in validation_samples: - assert s[0].startswith("CXR_png/CHNCXR_0") + batch + The loaded batch to be checked. - assert "test" in subset - test_samples = subset["test"][0][0] - assert len(test_samples) == 66 - for s in test_samples: - assert s[0].startswith("CXR_png/CHNCXR_0") + prefix + Each file named in a split should start with this prefix. - # Check labels - for s in train_samples: - assert s[1] in [0.0, 1.0] + possible_labels + These are the list of possible labels contained in any split. 
+ """ - for s in validation_samples: - assert s[1] in [0.0, 1.0] + assert len(batch) == 2 # data, metadata - for s in test_samples: - assert s[1] in [0.0, 1.0] - - -@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") -def test_loading(): - import torch - import torchvision.transforms + assert isinstance(batch[0], torch.Tensor) + assert batch[0].shape[0] == size # mini-batch size + assert batch[0].shape[1] == 1 # grayscale images + assert batch[0].shape[2] == batch[0].shape[3] # image is square - from ptbench.data.datamodule import _DelayedLoadingDataset + assert isinstance(batch[1], dict) # metadata + assert len(batch[1]) == 2 # label and name - def _check_sample(s): - assert len(s) == 2 + assert "label" in batch[1] + assert all([k in possible_labels for k in batch[1]["label"]]) - data = s[0] - metadata = s[1] + assert "name" in batch[1] + assert all([k.startswith(prefix) for k in batch[1]["name"]]) - assert isinstance(data, torch.Tensor) - assert data.size(0) == 1 # check 1 channel - assert data.size(1) == data.size(2) # check square image - - assert ( - torchvision.transforms.ToPILImage()(data).mode == "L" - ) # Check colors - - assert "label" in metadata - assert metadata["label"] in [0, 1] # Check labels - - limit = 30 # use this to limit testing to first images only, else None - - module = importlib.import_module("ptbench.data.shenzhen.datamodules") - datamodule = getattr(module, "default") - raw_data_loader = module.RawDataLoader() - subset = datamodule.splits - - # Need to use private function so we can limit the number of samples to use - dataset = _DelayedLoadingDataset( - subset["train"][0][0][:limit], - raw_data_loader, +def test_protocol_consistency(): + _check_split( + "default.json", + lengths=dict(train=422, validation=107, test=133), ) - for s in dataset: - _check_sample(s) - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") -def test_check(): - from ptbench.data.split import 
check_database_split_loading - - limit = 30 # use this to limit testing to first images only, else 0 - - # Default protocol - module = importlib.import_module("ptbench.data.shenzhen.datamodules") - datamodule = getattr(module, "default") - database_split = datamodule.splits - raw_data_loader = module.RawDataLoader() - - assert ( - check_database_split_loading( - database_split, raw_data_loader, limit=limit + # Cross-validation fold 0-1 + for k in range(2): + _check_split( + f"fold-{k}.json", + lengths=dict(train=476, validation=119, test=67), ) - == 0 - ) - # Folds - for f in range(10): - module = importlib.import_module("ptbench.data.shenzhen.datamodules") - datamodule = getattr(module, f"fold_{f}") + # Cross-validation fold 2-9 + for k in range(2, 10): + _check_split( + f"fold-{k}.json", + lengths=dict(train=476, validation=120, test=66), + ) - database_split = datamodule.splits - raw_data_loader = module.RawDataLoader() - assert ( - check_database_split_loading( - database_split, raw_data_loader, limit=limit - ) - == 0 - ) +@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") +def test_loading(): + from ptbench.data.shenzhen.default import datamodule + + datamodule.model_transforms = [] # should be done before setup() + datamodule.setup("predict") # sets up all datasets + + for loader in datamodule.predict_dataloader().values(): + limit = 5 # limit load checking + for batch in loader: + if limit == 0: + break + _check_loaded_batch(batch) + limit -= 1