From 94d25ec88d7a28ddbaff36ec7aeaedc30d6c63a8 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 26 Jul 2023 09:58:07 +0200
Subject: [PATCH] Fixed access to splits subsets in tests

---
 tests/test_ch.py | 19 +++++++++++--------
 tests/test_mc.py | 11 +++++++----
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/tests/test_ch.py b/tests/test_ch.py
index c678e087..787fa95e 100644
--- a/tests/test_ch.py
+++ b/tests/test_ch.py
@@ -15,7 +15,8 @@ def test_protocol_consistency():
     datamodule = importlib.import_module(
         "ptbench.data.shenzhen.default"
     ).datamodule
-    subset = datamodule.database_split.subsets
+
+    subset = datamodule.database_split
 
     assert len(subset) == 3
 
@@ -49,7 +50,8 @@ def test_protocol_consistency():
         datamodule = importlib.import_module(
             f"ptbench.data.shenzhen.fold_{str(f)}"
         ).datamodule
-        subset = datamodule.database_split.subsets
+
+        subset = datamodule.database_split
 
         assert len(subset) == 3
 
@@ -83,7 +85,8 @@ def test_protocol_consistency():
         datamodule = importlib.import_module(
             f"ptbench.data.shenzhen.fold_{str(f)}"
         ).datamodule
-        subset = datamodule.database_split.subsets
+
+        subset = datamodule.database_split
 
         assert len(subset) == 3
 
@@ -128,11 +131,11 @@ def test_loading():
 
         assert isinstance(data, torch.Tensor)
 
-        assert data.size(0) == 3  # check 3 channels
+        assert data.size(0) == 1  # check 1 channel
         assert data.size(1) == data.size(2)  # check square image
 
         assert (
-            torchvision.transforms.ToPILImage()(data).mode == "RGB"
+            torchvision.transforms.ToPILImage()(data).mode == "L"
         )  # Check colors
 
         assert "label" in metadata
@@ -143,7 +146,7 @@ def test_loading():
     datamodule = importlib.import_module(
         "ptbench.data.shenzhen.default"
     ).datamodule
-    subset = datamodule.database_split.subsets
+    subset = datamodule.database_split
     raw_data_loader = datamodule.raw_data_loader
 
     # Need to use private function so we can limit the number of samples to use
@@ -171,7 +174,7 @@ def test_check():
 
     assert (
         check_database_split_loading(
-            database_split.subsets, raw_data_loader, limit=limit
+            database_split, raw_data_loader, limit=limit
         )
         == 0
     )
@@ -186,7 +189,7 @@ def test_check():
 
         assert (
             check_database_split_loading(
-                database_split.subsets, raw_data_loader, limit=limit
+                database_split, raw_data_loader, limit=limit
             )
             == 0
         )
diff --git a/tests/test_mc.py b/tests/test_mc.py
index 25bd4709..3cc6adb0 100644
--- a/tests/test_mc.py
+++ b/tests/test_mc.py
@@ -14,7 +14,8 @@ def test_protocol_consistency():
     datamodule = importlib.import_module(
         "ptbench.data.montgomery.default"
     ).datamodule
-    subset = datamodule.database_split.subsets
+
+    subset = datamodule.database_split
 
     assert len(subset) == 3
 
@@ -48,7 +49,7 @@ def test_protocol_consistency():
         datamodule = importlib.import_module(
             f"ptbench.data.montgomery.fold_{str(f)}"
         ).datamodule
-        subset = datamodule.database_split.subsets
+        subset = datamodule.database_split
 
         assert len(subset) == 3
 
@@ -82,7 +83,7 @@ def test_protocol_consistency():
         datamodule = importlib.import_module(
             f"ptbench.data.montgomery.fold_{str(f)}"
         ).datamodule
-        subset = datamodule.database_split.subsets
+        subset = datamodule.database_split
 
         assert len(subset) == 3
 
@@ -120,6 +121,8 @@ def test_loading():
     from ptbench.data.datamodule import _DelayedLoadingDataset
 
     def _check_sample(s):
+        assert len(s) == 2
+
         data = s[0]
         metadata = s[1]
 
@@ -140,7 +143,7 @@ def test_loading():
     datamodule = importlib.import_module(
         "ptbench.data.montgomery.default"
     ).datamodule
-    subset = datamodule.database_split.subsets
+    subset = datamodule.database_split
     raw_data_loader = datamodule.raw_data_loader
 
     # Need to use private function so we can limit the number of samples to use
-- 
GitLab