From 5e6d38bccf1e2da6f284a3449f12f2e13f5f7441 Mon Sep 17 00:00:00 2001
From: mdelitroz <maxime.delitroz@idiap.ch>
Date: Wed, 19 Jul 2023 13:39:26 +0200
Subject: [PATCH] updated tests for Montgomery dataset

---
 tests/test_mc.py | 137 +++++++++++++++++++++++++++++++++--------------
 1 file changed, 98 insertions(+), 39 deletions(-)

diff --git a/tests/test_mc.py b/tests/test_mc.py
index 1b2aa4fd..87de46ea 100644
--- a/tests/test_mc.py
+++ b/tests/test_mc.py
@@ -4,131 +4,190 @@
 
 """Tests for Montgomery dataset."""
 
+import importlib
+
 import pytest
 
 
 def test_protocol_consistency():
-    from ptbench.data.montgomery import dataset
 
     # Default protocol
-    subset = dataset.subsets("default")
+    datamodule = importlib.import_module(
+        "ptbench.data.montgomery.default"
+    ).datamodule
+    subset = datamodule.database_split.subsets
+
     assert len(subset) == 3
 
     assert "train" in subset
     assert len(subset["train"]) == 88
     for s in subset["train"]:
-        assert s.key.startswith("CXR_png/MCUCXR_0")
+        assert s[0].startswith("CXR_png/MCUCXR_0")
 
     assert "validation" in subset
     assert len(subset["validation"]) == 22
     for s in subset["validation"]:
-        assert s.key.startswith("CXR_png/MCUCXR_0")
+        assert s[0].startswith("CXR_png/MCUCXR_0")
 
     assert "test" in subset
     assert len(subset["test"]) == 28
     for s in subset["test"]:
-        assert s.key.startswith("CXR_png/MCUCXR_0")
+        assert s[0].startswith("CXR_png/MCUCXR_0")
 
     # Check labels
     for s in subset["train"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     for s in subset["validation"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     for s in subset["test"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     # Cross-validation fold 0-7
     for f in range(8):
-        subset = dataset.subsets("fold_" + str(f))
+        datamodule = importlib.import_module(
+            f"ptbench.data.montgomery.fold_{str(f)}"
+        ).datamodule
+        subset = datamodule.database_split.subsets
+
         assert len(subset) == 3
 
         assert "train" in subset
         assert len(subset["train"]) == 99
         for s in subset["train"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         assert "validation" in subset
         assert len(subset["validation"]) == 25
         for s in subset["validation"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         assert "test" in subset
         assert len(subset["test"]) == 14
         for s in subset["test"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         # Check labels
         for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
     # Cross-validation fold 8-9
     for f in range(8, 10):
-        subset = dataset.subsets("fold_" + str(f))
+        datamodule = importlib.import_module(
+            f"ptbench.data.montgomery.fold_{str(f)}"
+        ).datamodule
+        subset = datamodule.database_split.subsets
+
         assert len(subset) == 3
 
         assert "train" in subset
         assert len(subset["train"]) == 100
         for s in subset["train"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         assert "validation" in subset
         assert len(subset["validation"]) == 25
         for s in subset["validation"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         assert "test" in subset
         assert len(subset["test"]) == 13
         for s in subset["test"]:
-            assert s.key.startswith("CXR_png/MCUCXR_0")
+            assert s[0].startswith("CXR_png/MCUCXR_0")
 
         # Check labels
         for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_loading():
-    from ptbench.data.montgomery import dataset
+    import torch
+    import torchvision.transforms
+
+    from ptbench.data.datamodule import _DelayedLoadingDataset
 
     def _check_sample(s):
-        data = s.data
-        assert isinstance(data, dict)
-        assert len(data) == 2
-
-        assert "data" in data
-        assert data["data"].size in (
-            (4020, 4892),  # portrait
-            (4892, 4020),  # landscape
-            (512, 512),  # test database @ CI
+        data = s[0]
+        metadata = s[1]
+
+        assert isinstance(data, torch.Tensor)
+
+        assert data.size() in (
+            (1, 4020, 4892),  # portrait
+            (1, 4892, 4020),  # landscape
+            (1, 512, 512),  # test database @ CI
         )
-        assert data["data"].mode == "L"  # Check colors
+        assert (
+            torchvision.transforms.ToPILImage()(data).mode == "L"
+        )  # Check colors
 
-        assert "label" in data
-        assert data["label"] in [0, 1]  # Check labels
+        assert "label" in metadata
+        assert metadata["label"] in [0, 1]  # Check labels
 
     limit = 30  # use this to limit testing to first images only, else None
 
-    subset = dataset.subsets("default")
-    for s in subset["train"][:limit]:
+    datamodule = importlib.import_module(
+        "ptbench.data.montgomery.default"
+    ).datamodule
+    subset = datamodule.database_split.subsets
+    raw_data_loader = datamodule.raw_data_loader
+
+    # Need to use private function so we can limit the number of samples to use
+    dataset = _DelayedLoadingDataset(
+        subset["train"][:limit],
+        raw_data_loader
+    )
+
+    for s in dataset:
         _check_sample(s)
 
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_check():
-    from ptbench.data.montgomery import dataset
+    from ptbench.data.split import check_database_split_loading
+
+    limit = 30  # use this to limit testing to first images only, else 0
+
+    # Default protocol
+    datamodule = importlib.import_module(
+        "ptbench.data.montgomery.default"
+    ).datamodule
+    database_split = datamodule.database_split
+    raw_data_loader = datamodule.raw_data_loader
+
+    assert (
+        check_database_split_loading(
+            database_split, raw_data_loader, limit=limit
+        )
+        == 0
+    )
+
+    # Folds
+    for f in range(10):
+        datamodule = importlib.import_module(
+            f"ptbench.data.montgomery.fold_{f}"
+        ).datamodule
+        database_split = datamodule.database_split
+        raw_data_loader = datamodule.raw_data_loader
+
+        assert (
+            check_database_split_loading(
+                database_split, raw_data_loader, limit=limit
+            )
+            == 0
+        )
 
-    assert dataset.check() == 0
-- 
GitLab