Skip to content
Snippets Groups Projects
Commit 5e6d38bc authored by Maxime DELITROZ's avatar Maxime DELITROZ Committed by André Anjos
Browse files

updated tests for Montgomery dataset

parent 4c71da47
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
......@@ -4,131 +4,190 @@
"""Tests for Montgomery dataset."""
import importlib
import pytest
def test_protocol_consistency():
    """Check split sizes and sample structure for all Montgomery protocols.

    Each datamodule exposes a ``database_split`` whose ``subsets`` maps
    split names (``train``/``validation``/``test``) to lists of samples.
    A sample is a sequence: ``s[0]`` is the relative file path,
    ``s[1]`` is the binary label.
    """

    def _check_subsets(subset, train_len, validation_len, test_len):
        # One helper for all protocols: same structural checks, different sizes.
        assert len(subset) == 3

        expected = {
            "train": train_len,
            "validation": validation_len,
            "test": test_len,
        }
        for name, size in expected.items():
            assert name in subset
            assert len(subset[name]) == size
            for s in subset[name]:
                # All Montgomery images live under CXR_png/ and start MCUCXR_0
                assert s[0].startswith("CXR_png/MCUCXR_0")
                # Binary labels only (tuberculosis yes/no)
                assert s[1] in [0.0, 1.0]

    # Default protocol
    datamodule = importlib.import_module(
        "ptbench.data.montgomery.default"
    ).datamodule
    # NOTE(review): the diff shows `dataset_split` here but `database_split`
    # in the fold branches and sibling tests — normalized to `database_split`;
    # confirm against the datamodule definition.
    subset = datamodule.database_split.subsets
    _check_subsets(subset, train_len=88, validation_len=22, test_len=28)

    # Cross-validation folds 0-7
    for f in range(8):
        datamodule = importlib.import_module(
            f"ptbench.data.montgomery.fold_{f}"
        ).datamodule
        subset = datamodule.database_split.subsets
        _check_subsets(subset, train_len=99, validation_len=25, test_len=14)

    # Cross-validation folds 8-9 (one extra train sample, one fewer test)
    for f in range(8, 10):
        datamodule = importlib.import_module(
            f"ptbench.data.montgomery.fold_{f}"
        ).datamodule
        subset = datamodule.database_split.subsets
        _check_subsets(subset, train_len=100, validation_len=25, test_len=13)
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_loading():
    """Load a limited number of raw Montgomery samples and validate them.

    Requires the ``datadir.montgomery`` RC variable to point at a local
    copy of the dataset (skipped otherwise).
    """
    import torch
    import torchvision.transforms

    from ptbench.data.datamodule import _DelayedLoadingDataset

    def _check_sample(s):
        # A loaded sample is (tensor, metadata-dict).
        data = s[0]
        metadata = s[1]

        assert isinstance(data, torch.Tensor)
        # BUGFIX: `data.size` is a bound method on torch.Tensor, so
        # `data.size in (...)` could never be True; `data.shape` returns a
        # torch.Size, which compares equal to plain tuples.
        assert data.shape in (
            (1, 4020, 4892),  # portrait
            (1, 4892, 4020),  # landscape
            (1, 512, 512),  # test database @ CI
        )
        assert (
            torchvision.transforms.ToPILImage()(data).mode == "L"
        )  # Check colors (single-channel grayscale)

        assert "label" in metadata
        assert metadata["label"] in [0, 1]  # Check labels

    limit = 30  # use this to limit testing to first images only, else None

    datamodule = importlib.import_module(
        "ptbench.data.montgomery.default"
    ).datamodule
    subset = datamodule.database_split.subsets  # fixed `subsetss` typo
    raw_data_loader = datamodule.raw_data_loader

    # Need to use private function so we can limit the number of samples to use
    dataset = _DelayedLoadingDataset(
        subset["train"][:limit],
        raw_data_loader,
    )
    for s in dataset:
        _check_sample(s)
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_check():
    """Verify every sample of every Montgomery protocol loads cleanly.

    ``check_database_split_loading`` returns the number of samples that
    failed to load, so 0 means success.  Requires the
    ``datadir.montgomery`` RC variable (skipped otherwise).
    """
    from ptbench.data.split import check_database_split_loading

    limit = 30  # use this to limit testing to first images only, else 0

    def _check_protocol(module_name):
        # One helper shared by the default protocol and all folds.
        datamodule = importlib.import_module(module_name).datamodule
        assert (
            check_database_split_loading(
                datamodule.database_split,
                datamodule.raw_data_loader,
                limit=limit,
            )
            == 0
        )

    # Default protocol
    _check_protocol("ptbench.data.montgomery.default")

    # Folds
    for f in range(10):
        _check_protocol(f"ptbench.data.montgomery.fold_{f}")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment