Skip to content
Snippets Groups Projects
Commit 5e6d38bc authored by Maxime DELITROZ's avatar Maxime DELITROZ Committed by André Anjos
Browse files

updated tests for Montgomery dataset

parent 4c71da47
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -4,131 +4,190 @@ ...@@ -4,131 +4,190 @@
"""Tests for Montgomery dataset.""" """Tests for Montgomery dataset."""
import importlib
import pytest import pytest
def test_protocol_consistency():
    """Check subset composition and labels for every Montgomery protocol.

    Validates the ``default`` protocol and the ten cross-validation folds:
    each must define exactly the ``train``/``validation``/``test`` subsets,
    with the expected number of samples, keys rooted at
    ``CXR_png/MCUCXR_0`` and binary (0.0/1.0) labels.
    """

    def _check_subsets(subsets, train_size, validation_size, test_size):
        # One protocol == exactly these three subsets, nothing else.
        assert len(subsets) == 3
        for name, size in (
            ("train", train_size),
            ("validation", validation_size),
            ("test", test_size),
        ):
            assert name in subsets
            assert len(subsets[name]) == size
            for sample in subsets[name]:
                # sample is a (key, label) pair
                assert sample[0].startswith("CXR_png/MCUCXR_0")
                assert sample[1] in [0.0, 1.0]

    # Default protocol
    datamodule = importlib.import_module(
        "ptbench.data.montgomery.default"
    ).datamodule
    # NOTE(review): the original read ``dataset_split`` here but
    # ``database_split`` for the folds; unified on ``database_split`` —
    # confirm against the datamodule implementation.
    _check_subsets(datamodule.database_split.subsets, 88, 22, 28)

    # Cross-validation folds 0-7
    for f in range(8):
        datamodule = importlib.import_module(
            f"ptbench.data.montgomery.fold_{f}"
        ).datamodule
        _check_subsets(datamodule.database_split.subsets, 99, 25, 14)

    # Cross-validation folds 8-9 (one extra training sample, one fewer test)
    for f in range(8, 10):
        datamodule = importlib.import_module(
            f"ptbench.data.montgomery.fold_{f}"
        ).datamodule
        _check_subsets(datamodule.database_split.subsets, 100, 25, 13)
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_loading():
    """Load a capped number of training samples and sanity-check each one.

    Requires the ``datadir.montgomery`` RC variable (real data on disk).
    Each loaded sample must be a single-channel tensor of one of the known
    image geometries, with a binary label in its metadata.
    """
    import torch
    import torchvision.transforms

    from ptbench.data.datamodule import _DelayedLoadingDataset

    def _check_sample(s):
        # s is a (tensor, metadata-dict) pair produced by the raw loader.
        data, metadata = s[0], s[1]

        assert isinstance(data, torch.Tensor)
        # BUGFIX: ``data.size`` is a bound method on a Tensor (the membership
        # test would always be False); compare the shape instead.  torch.Size
        # compares equal to plain tuples.
        assert data.shape in (
            (1, 4020, 4892),  # portrait
            (1, 4892, 4020),  # landscape
            (1, 512, 512),  # test database @ CI
        )
        # Round-trip through PIL to confirm the image is grayscale ("L").
        assert torchvision.transforms.ToPILImage()(data).mode == "L"

        assert "label" in metadata
        assert metadata["label"] in [0, 1]  # check labels

    limit = 30  # use this to limit testing to first images only, else None

    datamodule = importlib.import_module(
        "ptbench.data.montgomery.default"
    ).datamodule
    subsets = datamodule.database_split.subsets  # fixed typo: was "subsetss"
    raw_data_loader = datamodule.raw_data_loader

    # Need to use private function so we can limit the number of samples used
    dataset = _DelayedLoadingDataset(subsets["train"][:limit], raw_data_loader)
    for s in dataset:
        _check_sample(s)
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_check():
    """Run the split-loading checker on every Montgomery protocol.

    Requires the ``datadir.montgomery`` RC variable (real data on disk).
    ``check_database_split_loading`` returns the number of samples that
    failed to load; zero means every (capped) sample loaded cleanly.
    """
    from ptbench.data.split import check_database_split_loading

    limit = 30  # use this to limit testing to first images only, else 0

    # Default protocol plus all ten cross-validation folds share the exact
    # same check, so iterate over the module names instead of copy-pasting.
    protocols = ["default"] + [f"fold_{f}" for f in range(10)]
    for protocol in protocols:
        datamodule = importlib.import_module(
            f"ptbench.data.montgomery.{protocol}"
        ).datamodule
        assert (
            check_database_split_loading(
                datamodule.database_split,
                datamodule.raw_data_loader,
                limit=limit,
            )
            == 0
        )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment