From 5e6d38bccf1e2da6f284a3449f12f2e13f5f7441 Mon Sep 17 00:00:00 2001 From: mdelitroz <maxime.delitroz@idiap.ch> Date: Wed, 19 Jul 2023 13:39:26 +0200 Subject: [PATCH] updated tests for Montgomery dataset --- tests/test_mc.py | 137 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 98 insertions(+), 39 deletions(-) diff --git a/tests/test_mc.py b/tests/test_mc.py index 1b2aa4fd..87de46ea 100644 --- a/tests/test_mc.py +++ b/tests/test_mc.py @@ -4,131 +4,190 @@ """Tests for Montgomery dataset.""" +import importlib + import pytest def test_protocol_consistency(): - from ptbench.data.montgomery import dataset # Default protocol - subset = dataset.subsets("default") + datamodule = importlib.import_module( + "ptbench.data.montgomery.default" + ).datamodule + subset = datamodule.database_split.subsets + assert len(subset) == 3 assert "train" in subset assert len(subset["train"]) == 88 for s in subset["train"]: - assert s.key.startswith("CXR_png/MCUCXR_0") + assert s[0].startswith("CXR_png/MCUCXR_0") assert "validation" in subset assert len(subset["validation"]) == 22 for s in subset["validation"]: - assert s.key.startswith("CXR_png/MCUCXR_0") + assert s[0].startswith("CXR_png/MCUCXR_0") assert "test" in subset assert len(subset["test"]) == 28 for s in subset["test"]: - assert s.key.startswith("CXR_png/MCUCXR_0") + assert s[0].startswith("CXR_png/MCUCXR_0") # Check labels for s in subset["train"]: - assert s.label in [0.0, 1.0] + assert s[1] in [0.0, 1.0] for s in subset["validation"]: - assert s.label in [0.0, 1.0] + assert s[1] in [0.0, 1.0] for s in subset["test"]: - assert s.label in [0.0, 1.0] + assert s[1] in [0.0, 1.0] # Cross-validation fold 0-7 for f in range(8): - subset = dataset.subsets("fold_" + str(f)) + datamodule = importlib.import_module( + f"ptbench.data.montgomery.fold_{str(f)}" + ).datamodule + subset = datamodule.database_split.subsets + assert len(subset) == 3 assert "train" in subset assert len(subset["train"]) == 99 for s in 
subset["train"]: - assert s.key.startswith("CXR_png/MCUCXR_0") + assert s[0].startswith("CXR_png/MCUCXR_0") assert "validation" in subset assert len(subset["validation"]) == 25 for s in subset["validation"]: - assert s.key.startswith("CXR_png/MCUCXR_0") + assert s[0].startswith("CXR_png/MCUCXR_0") assert "test" in subset assert len(subset["test"]) == 14 for s in subset["test"]: - assert s.key.startswith("CXR_png/MCUCXR_0") + assert s[0].startswith("CXR_png/MCUCXR_0") # Check labels for s in subset["train"]: - assert s.label in [0.0, 1.0] + assert s[1] in [0.0, 1.0] for s in subset["validation"]: - assert s.label in [0.0, 1.0] + assert s[1] in [0.0, 1.0] for s in subset["test"]: - assert s.label in [0.0, 1.0] + assert s[1] in [0.0, 1.0] # Cross-validation fold 8-9 for f in range(8, 10): - subset = dataset.subsets("fold_" + str(f)) + datamodule = importlib.import_module( + f"ptbench.data.montgomery.fold_{str(f)}" + ).datamodule + subset = datamodule.database_split.subsets + assert len(subset) == 3 assert "train" in subset assert len(subset["train"]) == 100 for s in subset["train"]: - assert s.key.startswith("CXR_png/MCUCXR_0") + assert s[0].startswith("CXR_png/MCUCXR_0") assert "validation" in subset assert len(subset["validation"]) == 25 for s in subset["validation"]: - assert s.key.startswith("CXR_png/MCUCXR_0") + assert s[0].startswith("CXR_png/MCUCXR_0") assert "test" in subset assert len(subset["test"]) == 13 for s in subset["test"]: - assert s.key.startswith("CXR_png/MCUCXR_0") + assert s[0].startswith("CXR_png/MCUCXR_0") # Check labels for s in subset["train"]: - assert s.label in [0.0, 1.0] + assert s[1] in [0.0, 1.0] for s in subset["validation"]: - assert s.label in [0.0, 1.0] + assert s[1] in [0.0, 1.0] for s in subset["test"]: - assert s.label in [0.0, 1.0] + assert s[1] in [0.0, 1.0] @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_loading(): - from ptbench.data.montgomery import dataset + import torch + import torchvision.transforms + 
+ from ptbench.data.datamodule import _DelayedLoadingDataset def _check_sample(s): - data = s.data - assert isinstance(data, dict) - assert len(data) == 2 - - assert "data" in data - assert data["data"].size in ( - (4020, 4892), # portrait - (4892, 4020), # landscape - (512, 512), # test database @ CI + data = s[0] + metadata = s[1] + + assert isinstance(data, torch.Tensor) + + assert data.shape in ( + (1, 4020, 4892), # portrait + (1, 4892, 4020), # landscape + (1, 512, 512), # test database @ CI ) - assert data["data"].mode == "L" # Check colors + assert ( + torchvision.transforms.ToPILImage()(data).mode == "L" + ) # Check colors - assert "label" in data - assert data["label"] in [0, 1] # Check labels + assert "label" in metadata + assert metadata["label"] in [0, 1] # Check labels limit = 30 # use this to limit testing to first images only, else None - subset = dataset.subsets("default") - for s in subset["train"][:limit]: + datamodule = importlib.import_module( + "ptbench.data.montgomery.default" + ).datamodule + subset = datamodule.database_split.subsets + raw_data_loader = datamodule.raw_data_loader + + # Need to use private class so we can limit the number of samples to use + dataset = _DelayedLoadingDataset( + subset["train"][:limit], + raw_data_loader + ) + + for s in dataset: _check_sample(s) @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_check(): - from ptbench.data.montgomery import dataset + from ptbench.data.split import check_database_split_loading + + limit = 30 # use this to limit testing to first images only, else 0 + + # Default protocol + datamodule = importlib.import_module( + "ptbench.data.montgomery.default" + ).datamodule + database_split = datamodule.database_split + raw_data_loader = datamodule.raw_data_loader + + assert ( + check_database_split_loading( + database_split, raw_data_loader, limit=limit + ) + == 0 + ) + + # Folds + for f in range(10): + datamodule = importlib.import_module( + 
f"ptbench.data.montgomery.fold_{f}" + ).datamodule + database_split = datamodule.database_split + raw_data_loader = datamodule.raw_data_loader + + assert ( + check_database_split_loading( + database_split, raw_data_loader, limit=limit + ) + == 0 + ) - assert dataset.check() == 0 -- GitLab