"""Tests for Shenzhen dataset."""

import importlib

import pytest


def _check_split_consistency(module_name, train_size, validation_size, test_size):
    """Load the named datamodule and sanity-check its database split.

    Verifies the split has exactly the ``train``/``validation``/``test``
    subsets with the expected sizes, that every sample key carries the
    Shenzhen file prefix, and that every label is binary.
    """
    datamodule = importlib.import_module(module_name).datamodule
    subset = datamodule.database_split.subsets

    assert len(subset) == 3

    expected_sizes = {
        "train": train_size,
        "validation": validation_size,
        "test": test_size,
    }
    for name, size in expected_sizes.items():
        assert name in subset
        assert len(subset[name]) == size
        for s in subset[name]:
            # each sample is a (key, label) pair
            assert s[0].startswith("CXR_png/CHNCXR_0")
            assert s[1] in [0.0, 1.0]


def test_protocol_consistency():
    """Check subset sizes and sample sanity of every declared protocol."""

    # Default protocol
    _check_split_consistency("ptbench.data.shenzhen.default", 422, 107, 133)

    # Cross-validation folds 0-1
    for f in range(2):
        _check_split_consistency(f"ptbench.data.shenzhen.fold_{f}", 476, 119, 67)

    # Cross-validation folds 2-9
    for f in range(2, 10):
        _check_split_consistency(f"ptbench.data.shenzhen.fold_{f}", 476, 120, 66)


@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
def test_loading():
    """Check raw-data loading produces 1x512x512 grayscale tensors."""

    import torch
    import torchvision.transforms

    from ptbench.data.datamodule import _DelayedLoadingDataset

    def _check_sample(s):
        # each loaded sample is a (tensor, metadata-dict) pair
        assert len(s) == 2

        data, metadata = s

        assert isinstance(data, torch.Tensor)
        assert data.shape == (1, 512, 512)  # Check size

        # Check colors: a single-channel tensor must convert to mode "L"
        assert torchvision.transforms.ToPILImage()(data).mode == "L"

        assert "label" in metadata
        assert metadata["label"] in [0, 1]  # Check labels

    limit = 30  # use this to limit testing to first images only, else None

    datamodule = importlib.import_module(
        "ptbench.data.shenzhen.default"
    ).datamodule
    subset = datamodule.database_split.subsets
    raw_data_loader = datamodule.raw_data_loader

    # Need to use private function so we can limit the number of samples used
    dataset = _DelayedLoadingDataset(
        subset["train"][:limit],
        raw_data_loader,
    )

    for s in dataset:
        _check_sample(s)


@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
def test_check():
    """Check every sample of every protocol can be loaded without error."""

    from ptbench.data.split import check_database_split_loading

    limit = 30  # use this to limit testing to first images only, else 0

    # default protocol plus all ten cross-validation folds
    for name in ["default"] + [f"fold_{f}" for f in range(10)]:
        datamodule = importlib.import_module(
            f"ptbench.data.shenzhen.{name}"
        ).datamodule

        assert (
            check_database_split_loading(
                datamodule.database_split,
                datamodule.raw_data_loader,
                limit=limit,
            )
            == 0
        )