Skip to content
Snippets Groups Projects
Commit dc220169 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Updated shenzhen tests

parent 87a65d97
Branches
Tags
1 merge request: !6 "Making use of LightningDataModule and simplification of data loading"
@@ -4,133 +4,194 @@
"""Tests for Shenzhen dataset.""" """Tests for Shenzhen dataset."""
import pytest import importlib
from ptbench.data.shenzhen import dataset import pytest
def test_protocol_consistency(): def test_protocol_consistency():
# Default protocol # Default protocol
subset = dataset.subsets("default")
datamodule = importlib.import_module(
"ptbench.data.shenzhen.default"
).datamodule
subset = datamodule.database_split.subsets
assert len(subset) == 3 assert len(subset) == 3
assert "train" in subset assert "train" in subset
assert len(subset["train"]) == 422 assert len(subset["train"]) == 422
for s in subset["train"]: for s in subset["train"]:
assert s.key.startswith("CXR_png/CHNCXR_0") assert s[0].startswith("CXR_png/CHNCXR_0")
assert "validation" in subset assert "validation" in subset
assert len(subset["validation"]) == 107 assert len(subset["validation"]) == 107
for s in subset["validation"]: for s in subset["validation"]:
assert s.key.startswith("CXR_png/CHNCXR_0") assert s[0].startswith("CXR_png/CHNCXR_0")
assert "test" in subset assert "test" in subset
assert len(subset["test"]) == 133 assert len(subset["test"]) == 133
for s in subset["test"]: for s in subset["test"]:
assert s.key.startswith("CXR_png/CHNCXR_0") assert s[0].startswith("CXR_png/CHNCXR_0")
# Check labels # Check labels
for s in subset["train"]: for s in subset["train"]:
assert s.label in [0.0, 1.0] assert s[1] in [0.0, 1.0]
for s in subset["validation"]: for s in subset["validation"]:
assert s.label in [0.0, 1.0] assert s[1] in [0.0, 1.0]
for s in subset["test"]: for s in subset["test"]:
assert s.label in [0.0, 1.0] assert s[1] in [0.0, 1.0]
# Cross-validation folds 0-1 # Cross-validation folds 0-1
for f in range(2): for f in range(2):
subset = dataset.subsets("fold_" + str(f)) datamodule = importlib.import_module(
f"ptbench.data.shenzhen.fold_{str(f)}"
).datamodule
subset = datamodule.database_split.subsets
assert len(subset) == 3 assert len(subset) == 3
assert "train" in subset assert "train" in subset
assert len(subset["train"]) == 476 assert len(subset["train"]) == 476
for s in subset["train"]: for s in subset["train"]:
assert s.key.startswith("CXR_png/CHNCXR_0") assert s[0].startswith("CXR_png/CHNCXR_0")
assert "validation" in subset assert "validation" in subset
assert len(subset["validation"]) == 119 assert len(subset["validation"]) == 119
for s in subset["validation"]: for s in subset["validation"]:
assert s.key.startswith("CXR_png/CHNCXR_0") assert s[0].startswith("CXR_png/CHNCXR_0")
assert "test" in subset assert "test" in subset
assert len(subset["test"]) == 67 assert len(subset["test"]) == 67
for s in subset["test"]: for s in subset["test"]:
assert s.key.startswith("CXR_png/CHNCXR_0") assert s[0].startswith("CXR_png/CHNCXR_0")
# Check labels # Check labels
for s in subset["train"]: for s in subset["train"]:
assert s.label in [0.0, 1.0] assert s[1] in [0.0, 1.0]
for s in subset["validation"]: for s in subset["validation"]:
assert s.label in [0.0, 1.0] assert s[1] in [0.0, 1.0]
for s in subset["test"]: for s in subset["test"]:
assert s.label in [0.0, 1.0] assert s[1] in [0.0, 1.0]
# Cross-validation folds 2-9 # Cross-validation folds 2-9
for f in range(2, 10): for f in range(2, 10):
subset = dataset.subsets("fold_" + str(f)) datamodule = importlib.import_module(
f"ptbench.data.shenzhen.fold_{str(f)}"
).datamodule
subset = datamodule.database_split.subsets
assert len(subset) == 3 assert len(subset) == 3
assert "train" in subset assert "train" in subset
assert len(subset["train"]) == 476 assert len(subset["train"]) == 476
for s in subset["train"]: for s in subset["train"]:
assert s.key.startswith("CXR_png/CHNCXR_0") assert s[0].startswith("CXR_png/CHNCXR_0")
assert "validation" in subset assert "validation" in subset
assert len(subset["validation"]) == 120 assert len(subset["validation"]) == 120
for s in subset["validation"]: for s in subset["validation"]:
assert s.key.startswith("CXR_png/CHNCXR_0") assert s[0].startswith("CXR_png/CHNCXR_0")
assert "test" in subset assert "test" in subset
assert len(subset["test"]) == 66 assert len(subset["test"]) == 66
for s in subset["test"]: for s in subset["test"]:
assert s.key.startswith("CXR_png/CHNCXR_0") assert s[0].startswith("CXR_png/CHNCXR_0")
# Check labels # Check labels
for s in subset["train"]: for s in subset["train"]:
assert s.label in [0.0, 1.0] assert s[1] in [0.0, 1.0]
for s in subset["validation"]: for s in subset["validation"]:
assert s.label in [0.0, 1.0] assert s[1] in [0.0, 1.0]
for s in subset["test"]: for s in subset["test"]:
assert s.label in [0.0, 1.0] assert s[1] in [0.0, 1.0]
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
def test_loading(): def test_loading():
def _check_size(size): import torch
if ( import torchvision.transforms
size[0] >= 1130
and size[0] <= 3001 from ptbench.data.datamodule import _DelayedLoadingDataset
and size[1] >= 948
and size[1] <= 3001 def _check_size(shape):
): if shape[0] == 1 and shape[1] == 512 and shape[2] == 512:
return True return True
return False return False
def _check_sample(s): def _check_sample(s):
data = s.data assert len(s) == 2
assert isinstance(data, dict)
assert len(data) == 2
assert "data" in data data = s[0]
assert _check_size(data["data"].size) # Check size metadata = s[1]
assert data["data"].mode == "L" # Check colors
assert "label" in data assert isinstance(data, torch.Tensor)
assert data["label"] in [0, 1] # Check labels
print(data.shape)
assert _check_size(data.shape) # Check size
assert (
torchvision.transforms.ToPILImage()(data).mode == "L"
) # Check colors
assert "label" in metadata
assert metadata["label"] in [0, 1] # Check labels
limit = 30 # use this to limit testing to first images only, else None limit = 30 # use this to limit testing to first images only, else None
subset = dataset.subsets("default") datamodule = importlib.import_module(
for s in subset["train"][:limit]: "ptbench.data.shenzhen.default"
).datamodule
subset = datamodule.database_split.subsets
raw_data_loader = datamodule.raw_data_loader
# Need to use private function so we can limit the number of samples to use
dataset = _DelayedLoadingDataset(
subset["train"][:limit],
raw_data_loader,
)
for s in dataset:
_check_sample(s) _check_sample(s)
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
def test_check(): def test_check():
assert dataset.check() == 0 from ptbench.data.split import check_database_split_loading
limit = 30 # use this to limit testing to first images only, else 0
# Default protocol
datamodule = importlib.import_module(
"ptbench.data.shenzhen.default"
).datamodule
database_split = datamodule.database_split
raw_data_loader = datamodule.raw_data_loader
assert (
check_database_split_loading(
database_split, raw_data_loader, limit=limit
)
== 0
)
# Folds
for f in range(10):
datamodule = importlib.import_module(
f"ptbench.data.shenzhen.fold_{f}"
).datamodule
database_split = datamodule.database_split
raw_data_loader = datamodule.raw_data_loader
assert (
check_database_split_loading(
database_split, raw_data_loader, limit=limit
)
== 0
)
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment