From 0d2f0163bc08ff87daa0569fb71ffd1d42969c38 Mon Sep 17 00:00:00 2001
From: Oscar Alfonso Jimenez Del Toro <oscar.jimenez@idiap.ch>
Date: Mon, 24 Jun 2024 16:53:01 +0200
Subject: [PATCH] Test batch uniformity for combined datasets

---
 .../test_montgomery_shenzhen_indian_tbx11k.py | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/tests/test_montgomery_shenzhen_indian_tbx11k.py b/tests/test_montgomery_shenzhen_indian_tbx11k.py
index 75464104..8e45c6d2 100644
--- a/tests/test_montgomery_shenzhen_indian_tbx11k.py
+++ b/tests/test_montgomery_shenzhen_indian_tbx11k.py
@@ -100,6 +100,37 @@ def test_split_consistency(name: str, tbx11k_name: str):
         assert isinstance(combined.splits[split][3][1], tbx11k_loader)
 
 
+@pytest.mark.parametrize(
+    "dataset",
+    [
+        "train",
+    ],
+)
+@pytest.mark.parametrize(
+    "tbx11k_name",
+    [
+        ("v1_healthy_vs_atb"),
+    ],
+)    
+def test_batch_uniformity(tbx11k_name: str, dataset: str):
+
+    combined = importlib.import_module(
+        f".{tbx11k_name}",
+        "mednet.config.data.montgomery_shenzhen_indian_tbx11k",
+    ).datamodule
+    
+    combined.model_transforms = []  # should be done before setup()
+    combined.setup("predict")  # sets up all datasets
+
+    loader = combined.predict_dataloader()[dataset]
+
+    limit = 5  # limit load checking
+    for batch in loader:
+        if limit == 0:
+            break
+        assert len(batch[1]) == 2  # label, name. No radiological sign bounding-boxes
+
+  
 @pytest.mark.slow
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
-- 
GitLab