From 14816eae60da76064a56458b31bcfc3b30816e59 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 28 May 2024 11:37:32 +0200 Subject: [PATCH] [tests] Minor fixes --- .../segmentation/config/data/drive/datamodule.py | 4 ++-- src/mednet/libs/segmentation/tests/test_drive.py | 15 ++++----------- src/mednet/libs/segmentation/tests/test_stare.py | 13 +++---------- 3 files changed, 9 insertions(+), 23 deletions(-) diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py index 893cbbb8..d7f2890a 100644 --- a/src/mednet/libs/segmentation/config/data/drive/datamodule.py +++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py @@ -25,7 +25,7 @@ database.""" class SegmentationRawDataLoader(_SegmentationRawDataLoader): - """A specialized raw-data-loader for the Montgomery dataset.""" + """A specialized raw-data-loader for the Drive dataset.""" datadir: str """This variable contains the base directory where the database raw data is @@ -73,7 +73,7 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader): def make_split(basename: str) -> DatabaseSplit: - """Return a database split for the Montgomery database. + """Return a database split for the Drive database. Parameters ---------- diff --git a/src/mednet/libs/segmentation/tests/test_drive.py b/src/mednet/libs/segmentation/tests/test_drive.py index a90f6318..d1a1301d 100644 --- a/src/mednet/libs/segmentation/tests/test_drive.py +++ b/src/mednet/libs/segmentation/tests/test_drive.py @@ -109,17 +109,10 @@ def test_raw_transforms_image_quality(database_checkers, datadir): ["lwnet"], ) def test_model_transforms_image_quality(database_checkers, datadir, model_name): - # Densenet's model.name is "densenet-212" and does not correspond to its module name. - if model_name == "densenet": - reference_histogram_file = str( - datadir - / "histograms/models/histograms_densenet-121_drive_default.json", - ) - else: - reference_histogram_file = str( - datadir - / f"histograms/models/histograms_{model_name}_drive_default.json", - ) + reference_histogram_file = str( + datadir + / f"histograms/models/histograms_{model_name}_drive_default.json", + ) datamodule = importlib.import_module( ".default", diff --git a/src/mednet/libs/segmentation/tests/test_stare.py b/src/mednet/libs/segmentation/tests/test_stare.py index 99b36913..9b003894 100644 --- a/src/mednet/libs/segmentation/tests/test_stare.py +++ b/src/mednet/libs/segmentation/tests/test_stare.py @@ -111,16 +111,9 @@ def test_raw_transforms_image_quality(database_checkers, datadir): ["lwnet"], ) def test_model_transforms_image_quality(database_checkers, datadir, model_name): - # Densenet's model.name is "densenet-212" and does not correspond to its module name. - if model_name == "densenet": - reference_histogram_file = str( - datadir / "histograms/models/histograms_densenet-121_stare_ah.json", - ) - else: - reference_histogram_file = str( - datadir - / f"histograms/models/histograms_{model_name}_stare_ah.json", - ) + reference_histogram_file = str( + datadir / f"histograms/models/histograms_{model_name}_stare_ah.json", + ) datamodule = importlib.import_module( ".ah", -- GitLab