From 14816eae60da76064a56458b31bcfc3b30816e59 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 28 May 2024 11:37:32 +0200
Subject: [PATCH] [tests] Minor fixes

---
 .../segmentation/config/data/drive/datamodule.py  |  4 ++--
 src/mednet/libs/segmentation/tests/test_drive.py  | 15 ++++-----------
 src/mednet/libs/segmentation/tests/test_stare.py  | 13 +++----------
 3 files changed, 9 insertions(+), 23 deletions(-)

diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
index 893cbbb8..d7f2890a 100644
--- a/src/mednet/libs/segmentation/config/data/drive/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
@@ -25,7 +25,7 @@ database."""
 
 
 class SegmentationRawDataLoader(_SegmentationRawDataLoader):
-    """A specialized raw-data-loader for the Montgomery dataset."""
+    """A specialized raw-data-loader for the Drive dataset."""
 
     datadir: str
     """This variable contains the base directory where the database raw data is
@@ -73,7 +73,7 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
 
 
 def make_split(basename: str) -> DatabaseSplit:
-    """Return a database split for the Montgomery database.
+    """Return a database split for the Drive database.
 
     Parameters
     ----------
diff --git a/src/mednet/libs/segmentation/tests/test_drive.py b/src/mednet/libs/segmentation/tests/test_drive.py
index a90f6318..d1a1301d 100644
--- a/src/mednet/libs/segmentation/tests/test_drive.py
+++ b/src/mednet/libs/segmentation/tests/test_drive.py
@@ -109,17 +109,10 @@ def test_raw_transforms_image_quality(database_checkers, datadir):
     ["lwnet"],
 )
 def test_model_transforms_image_quality(database_checkers, datadir, model_name):
-    # Densenet's model.name is "densenet-212" and does not correspond to its module name.
-    if model_name == "densenet":
-        reference_histogram_file = str(
-            datadir
-            / "histograms/models/histograms_densenet-121_drive_default.json",
-        )
-    else:
-        reference_histogram_file = str(
-            datadir
-            / f"histograms/models/histograms_{model_name}_drive_default.json",
-        )
+    reference_histogram_file = str(
+        datadir
+        / f"histograms/models/histograms_{model_name}_drive_default.json",
+    )
 
     datamodule = importlib.import_module(
         ".default",
diff --git a/src/mednet/libs/segmentation/tests/test_stare.py b/src/mednet/libs/segmentation/tests/test_stare.py
index 99b36913..9b003894 100644
--- a/src/mednet/libs/segmentation/tests/test_stare.py
+++ b/src/mednet/libs/segmentation/tests/test_stare.py
@@ -111,16 +111,9 @@ def test_raw_transforms_image_quality(database_checkers, datadir):
     ["lwnet"],
 )
 def test_model_transforms_image_quality(database_checkers, datadir, model_name):
-    # Densenet's model.name is "densenet-212" and does not correspond to its module name.
-    if model_name == "densenet":
-        reference_histogram_file = str(
-            datadir / "histograms/models/histograms_densenet-121_stare_ah.json",
-        )
-    else:
-        reference_histogram_file = str(
-            datadir
-            / f"histograms/models/histograms_{model_name}_stare_ah.json",
-        )
+    reference_histogram_file = str(
+        datadir / f"histograms/models/histograms_{model_name}_stare_ah.json",
+    )
 
     datamodule = importlib.import_module(
         ".ah",
-- 
GitLab