From ccc390cc1affdb4d85bf3e4f4a62385e7fca7a81 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 18 Jul 2023 15:58:50 +0200
Subject: [PATCH] Fixed issue with check_database_split_loading

---
 src/ptbench/data/split.py | 4 ++--
 tests/test_ch.py          | 6 ++----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/src/ptbench/data/split.py b/src/ptbench/data/split.py
index 78fe7e33..606e40cb 100644
--- a/src/ptbench/data/split.py
+++ b/src/ptbench/data/split.py
@@ -238,8 +238,8 @@ def check_database_split_loading(
         "Checking if can load all samples in all subsets of this split..."
     )
     errors = 0
-    for subset in database_split.keys():
-        samples = subset if not limit else subset[:limit]
+    for subset, samples in database_split.items():
+        samples = samples if not limit else samples[:limit]
         for pos, sample in enumerate(samples):
             try:
                 data, _ = loader.sample(sample)
diff --git a/tests/test_ch.py b/tests/test_ch.py
index 510b1171..659e2c35 100644
--- a/tests/test_ch.py
+++ b/tests/test_ch.py
@@ -132,8 +132,6 @@ def test_loading():
         metadata = s[1]
 
         assert isinstance(data, torch.Tensor)
-
-        print(data.shape)
         assert _check_size(data.shape)  # Check size
 
         assert (
@@ -176,7 +174,7 @@ def test_check():
 
     assert (
         check_database_split_loading(
-            database_split, raw_data_loader, limit=limit
+            database_split.subsets, raw_data_loader, limit=limit
         )
         == 0
     )
@@ -191,7 +189,7 @@ def test_check():
 
         assert (
             check_database_split_loading(
-                database_split, raw_data_loader, limit=limit
+                database_split.subsets, raw_data_loader, limit=limit
             )
             == 0
         )
-- 
GitLab