From afda408c40ea4c98efbbac4532af63fa532f350a Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 14 May 2024 14:40:47 +0200
Subject: [PATCH] [tests] Fix tests

---
 src/mednet/libs/classification/data/typing.py             | 2 +-
 .../libs/classification/engine/saliency/completeness.py   | 2 +-
 .../classification/engine/saliency/interpretability.py    | 2 +-
 .../libs/classification/tests/test_montgomery_shenzhen.py | 4 ++--
 .../tests/test_montgomery_shenzhen_indian.py              | 6 +++---
 .../tests/test_montgomery_shenzhen_indian_padchest.py     | 8 ++++----
 .../tests/test_montgomery_shenzhen_indian_tbx11k.py       | 8 ++++----
 .../libs/classification/tests/test_nih_cxr14_padchest.py  | 4 ++--
 src/mednet/libs/classification/tests/test_tbx11k.py       | 8 ++++----
 src/mednet/libs/common/tests/conftest.py                  | 6 +++---
 10 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/src/mednet/libs/classification/data/typing.py b/src/mednet/libs/classification/data/typing.py
index b504f846..0d18220e 100644
--- a/src/mednet/libs/classification/data/typing.py
+++ b/src/mednet/libs/classification/data/typing.py
@@ -38,4 +38,4 @@ class ClassificationRawDataLoader(RawDataLoader):
             The label corresponding to the specified sample.
         """
 
-        return self.sample(k)[1]["label"]
+        return self.sample(k)[1]["target"]
diff --git a/src/mednet/libs/classification/engine/saliency/completeness.py b/src/mednet/libs/classification/engine/saliency/completeness.py
index e57ba8d1..c7f3be68 100644
--- a/src/mednet/libs/classification/engine/saliency/completeness.py
+++ b/src/mednet/libs/classification/engine/saliency/completeness.py
@@ -168,7 +168,7 @@ def _process_sample(
     """
 
     name: str = sample[1]["name"][0]
-    label: int = int(sample[1]["label"].item())
+    label: int = int(sample[1]["target"].item())
     image = sample[0]
 
     # in binary classification systems, negative labels may be skipped
diff --git a/src/mednet/libs/classification/engine/saliency/interpretability.py b/src/mednet/libs/classification/engine/saliency/interpretability.py
index 621aa9d6..da96650e 100644
--- a/src/mednet/libs/classification/engine/saliency/interpretability.py
+++ b/src/mednet/libs/classification/engine/saliency/interpretability.py
@@ -426,7 +426,7 @@ def run(
             disable=None,
         ):
             name = str(sample[1]["name"][0])
-            label = int(sample[1]["label"].item())
+            label = int(sample[1]["target"].item())
 
             if label != target_label:
                 # we add the entry for dataset completeness, but do not treat
diff --git a/src/mednet/libs/classification/tests/test_montgomery_shenzhen.py b/src/mednet/libs/classification/tests/test_montgomery_shenzhen.py
index 7d94a87a..57cac2cf 100644
--- a/src/mednet/libs/classification/tests/test_montgomery_shenzhen.py
+++ b/src/mednet/libs/classification/tests/test_montgomery_shenzhen.py
@@ -44,12 +44,12 @@ def test_split_consistency(name: str):
     montgomery_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.montgomery",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     shenzhen_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.shenzhen",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     for split in ("train", "validation", "test"):
         assert montgomery.splits[split][0][0] == combined.splits[split][0][0]
diff --git a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian.py b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian.py
index 3a95cd7b..b1f0ec09 100644
--- a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian.py
+++ b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian.py
@@ -49,17 +49,17 @@ def test_split_consistency(name: str):
     montgomery_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.montgomery",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     shenzhen_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.shenzhen",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     indian_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.indian",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     for split in ("train", "validation", "test"):
         assert montgomery.splits[split][0][0] == combined.splits[split][0][0]
diff --git a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_padchest.py b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_padchest.py
index bf7e913f..3872fb1b 100644
--- a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_padchest.py
+++ b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_padchest.py
@@ -44,22 +44,22 @@ def test_split_consistency(name: str, padchest_name: str):
     montgomery_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.montgomery",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     shenzhen_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.shenzhen",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     indian_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.indian",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     padchest_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.padchest",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     for split in ("train", "validation", "test"):
         assert montgomery.splits[split][0][0] == combined.splits[split][0][0]
diff --git a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_tbx11k.py b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_tbx11k.py
index 644d8073..fd9c88cf 100644
--- a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_tbx11k.py
+++ b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_tbx11k.py
@@ -65,22 +65,22 @@ def test_split_consistency(name: str, tbx11k_name: str):
     montgomery_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.montgomery",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     shenzhen_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.shenzhen",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     indian_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.indian",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     tbx11k_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.tbx11k",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     for split in ("train", "validation", "test"):
         assert montgomery.splits[split][0][0] == combined.splits[split][0][0]
diff --git a/src/mednet/libs/classification/tests/test_nih_cxr14_padchest.py b/src/mednet/libs/classification/tests/test_nih_cxr14_padchest.py
index bc5cde53..987c2ba7 100644
--- a/src/mednet/libs/classification/tests/test_nih_cxr14_padchest.py
+++ b/src/mednet/libs/classification/tests/test_nih_cxr14_padchest.py
@@ -34,12 +34,12 @@ def test_split_consistency(name: str, padchest_name: str, combined_name: str):
     cxr14_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.nih_cxr14",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     padchest_loader = importlib.import_module(
         ".datamodule",
         "mednet.libs.classification.config.data.padchest",
-    ).RawDataLoader
+    ).ClassificationRawDataLoader
 
     for split in ("train", "validation", "test"):
         assert nih_cxr14.splits[split][0][0] == combined.splits[split][0][0]
diff --git a/src/mednet/libs/classification/tests/test_tbx11k.py b/src/mednet/libs/classification/tests/test_tbx11k.py
index 19be118c..0f9a5884 100644
--- a/src/mednet/libs/classification/tests/test_tbx11k.py
+++ b/src/mednet/libs/classification/tests/test_tbx11k.py
@@ -192,11 +192,11 @@ def check_loaded_batch(
     assert isinstance(batch[1], dict)  # metadata
     assert len(batch[1]) == 3  # label, name and radiological sign bounding-boxes
 
-    assert "label" in batch[1]
-    assert all([k in possible_labels for k in batch[1]["label"]])
+    assert "target" in batch[1]
+    assert all([k in possible_labels for k in batch[1]["target"]])
 
     if expected_num_labels:
-        assert len(batch[1]["label"]) == expected_num_labels
+        assert len(batch[1]["target"]) == expected_num_labels
 
     assert "name" in batch[1]
     assert all(
@@ -207,7 +207,7 @@ def check_loaded_batch(
 
     for sample, label, bboxes in zip(
         batch[0],
-        batch[1]["label"],
+        batch[1]["target"],
         batch[1]["bounding_boxes"],
     ):
         # there must be a sign indicated on the image, if active TB is detected
diff --git a/src/mednet/libs/common/tests/conftest.py b/src/mednet/libs/common/tests/conftest.py
index a024dbbc..c732ed83 100644
--- a/src/mednet/libs/common/tests/conftest.py
+++ b/src/mednet/libs/common/tests/conftest.py
@@ -190,11 +190,11 @@ class DatabaseCheckers:
         assert isinstance(batch[1], dict)  # metadata
         assert len(batch[1]) == 2  # label and name
 
-        assert "label" in batch[1]
-        assert all([k in possible_labels for k in batch[1]["label"]])
+        assert "target" in batch[1]
+        assert all([k in possible_labels for k in batch[1]["target"]])
 
         if expected_num_labels:
-            assert len(batch[1]["label"]) == expected_num_labels
+            assert len(batch[1]["target"]) == expected_num_labels
 
         assert "name" in batch[1]
         assert all(
-- 
GitLab