From afda408c40ea4c98efbbac4532af63fa532f350a Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 14 May 2024 14:40:47 +0200 Subject: [PATCH] [tests] Fix tests --- src/mednet/libs/classification/data/typing.py | 2 +- .../libs/classification/engine/saliency/completeness.py | 2 +- .../classification/engine/saliency/interpretability.py | 2 +- .../libs/classification/tests/test_montgomery_shenzhen.py | 4 ++-- .../tests/test_montgomery_shenzhen_indian.py | 6 +++--- .../tests/test_montgomery_shenzhen_indian_padchest.py | 8 ++++---- .../tests/test_montgomery_shenzhen_indian_tbx11k.py | 8 ++++---- .../libs/classification/tests/test_nih_cxr14_padchest.py | 4 ++-- src/mednet/libs/classification/tests/test_tbx11k.py | 8 ++++---- src/mednet/libs/common/tests/conftest.py | 6 +++--- 10 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/mednet/libs/classification/data/typing.py b/src/mednet/libs/classification/data/typing.py index b504f846..0d18220e 100644 --- a/src/mednet/libs/classification/data/typing.py +++ b/src/mednet/libs/classification/data/typing.py @@ -38,4 +38,4 @@ class ClassificationRawDataLoader(RawDataLoader): The label corresponding to the specified sample. """ - return self.sample(k)[1]["label"] + return self.sample(k)[1]["target"] diff --git a/src/mednet/libs/classification/engine/saliency/completeness.py b/src/mednet/libs/classification/engine/saliency/completeness.py index e57ba8d1..c7f3be68 100644 --- a/src/mednet/libs/classification/engine/saliency/completeness.py +++ b/src/mednet/libs/classification/engine/saliency/completeness.py @@ -168,7 +168,7 @@ def _process_sample( """ name: str = sample[1]["name"][0] - label: int = int(sample[1]["label"].item()) + label: int = int(sample[1]["target"].item()) image = sample[0] # in binary classification systems, negative labels may be skipped diff --git a/src/mednet/libs/classification/engine/saliency/interpretability.py b/src/mednet/libs/classification/engine/saliency/interpretability.py index 621aa9d6..da96650e 100644 --- a/src/mednet/libs/classification/engine/saliency/interpretability.py +++ b/src/mednet/libs/classification/engine/saliency/interpretability.py @@ -426,7 +426,7 @@ def run( disable=None, ): name = str(sample[1]["name"][0]) - label = int(sample[1]["label"].item()) + label = int(sample[1]["target"].item()) if label != target_label: # we add the entry for dataset completeness, but do not treat diff --git a/src/mednet/libs/classification/tests/test_montgomery_shenzhen.py b/src/mednet/libs/classification/tests/test_montgomery_shenzhen.py index 7d94a87a..57cac2cf 100644 --- a/src/mednet/libs/classification/tests/test_montgomery_shenzhen.py +++ b/src/mednet/libs/classification/tests/test_montgomery_shenzhen.py @@ -44,12 +44,12 @@ def test_split_consistency(name: str): montgomery_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.montgomery", - ).RawDataLoader + ).ClassificationRawDataLoader shenzhen_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.shenzhen", - ).RawDataLoader + ).ClassificationRawDataLoader for split in ("train", "validation", "test"): assert montgomery.splits[split][0][0] == combined.splits[split][0][0] diff --git a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian.py b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian.py index 3a95cd7b..b1f0ec09 100644 --- a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian.py +++ b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian.py @@ -49,17 +49,17 @@ def test_split_consistency(name: str): montgomery_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.montgomery", - ).RawDataLoader + ).ClassificationRawDataLoader shenzhen_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.shenzhen", - ).RawDataLoader + ).ClassificationRawDataLoader indian_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.indian", - ).RawDataLoader + ).ClassificationRawDataLoader for split in ("train", "validation", "test"): assert montgomery.splits[split][0][0] == combined.splits[split][0][0] diff --git a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_padchest.py b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_padchest.py index bf7e913f..3872fb1b 100644 --- a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_padchest.py +++ b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_padchest.py @@ -44,22 +44,22 @@ def test_split_consistency(name: str, padchest_name: str): montgomery_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.montgomery", - ).RawDataLoader + ).ClassificationRawDataLoader shenzhen_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.shenzhen", - ).RawDataLoader + ).ClassificationRawDataLoader indian_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.indian", - ).RawDataLoader + ).ClassificationRawDataLoader padchest_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.padchest", - ).RawDataLoader + ).ClassificationRawDataLoader for split in ("train", "validation", "test"): assert montgomery.splits[split][0][0] == combined.splits[split][0][0] diff --git a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_tbx11k.py b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_tbx11k.py index 644d8073..fd9c88cf 100644 --- a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_tbx11k.py +++ b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_tbx11k.py @@ -65,22 +65,22 @@ def test_split_consistency(name: str, tbx11k_name: str): montgomery_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.montgomery", - ).RawDataLoader + ).ClassificationRawDataLoader shenzhen_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.shenzhen", - ).RawDataLoader + ).ClassificationRawDataLoader indian_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.indian", - ).RawDataLoader + ).ClassificationRawDataLoader tbx11k_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.tbx11k", - ).RawDataLoader + ).ClassificationRawDataLoader for split in ("train", "validation", "test"): assert montgomery.splits[split][0][0] == combined.splits[split][0][0] diff --git a/src/mednet/libs/classification/tests/test_nih_cxr14_padchest.py b/src/mednet/libs/classification/tests/test_nih_cxr14_padchest.py index bc5cde53..987c2ba7 100644 --- a/src/mednet/libs/classification/tests/test_nih_cxr14_padchest.py +++ b/src/mednet/libs/classification/tests/test_nih_cxr14_padchest.py @@ -34,12 +34,12 @@ def test_split_consistency(name: str, padchest_name: str, combined_name: str): cxr14_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.nih_cxr14", - ).RawDataLoader + ).ClassificationRawDataLoader padchest_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.padchest", - ).RawDataLoader + ).ClassificationRawDataLoader for split in ("train", "validation", "test"): assert nih_cxr14.splits[split][0][0] == combined.splits[split][0][0] diff --git a/src/mednet/libs/classification/tests/test_tbx11k.py b/src/mednet/libs/classification/tests/test_tbx11k.py index 19be118c..0f9a5884 100644 --- a/src/mednet/libs/classification/tests/test_tbx11k.py +++ b/src/mednet/libs/classification/tests/test_tbx11k.py @@ -192,11 +192,11 @@ def check_loaded_batch( assert isinstance(batch[1], dict) # metadata assert len(batch[1]) == 3 # label, name and radiological sign bounding-boxes - assert "label" in batch[1] - assert all([k in possible_labels for k in batch[1]["label"]]) + assert "target" in batch[1] + assert all([k in possible_labels for k in batch[1]["target"]]) if expected_num_labels: - assert len(batch[1]["label"]) == expected_num_labels + assert len(batch[1]["target"]) == expected_num_labels assert "name" in batch[1] assert all( @@ -207,7 +207,7 @@ def check_loaded_batch( for sample, label, bboxes in zip( batch[0], - batch[1]["label"], + batch[1]["target"], batch[1]["bounding_boxes"], ): # there must be a sign indicated on the image, if active TB is detected diff --git a/src/mednet/libs/common/tests/conftest.py b/src/mednet/libs/common/tests/conftest.py index a024dbbc..c732ed83 100644 --- a/src/mednet/libs/common/tests/conftest.py +++ b/src/mednet/libs/common/tests/conftest.py @@ -190,11 +190,11 @@ class DatabaseCheckers: assert isinstance(batch[1], dict) # metadata assert len(batch[1]) == 2 # label and name - assert "label" in batch[1] - assert all([k in possible_labels for k in batch[1]["label"]]) + assert "target" in batch[1] + assert all([k in possible_labels for k in batch[1]["target"]]) if expected_num_labels: - assert len(batch[1]["label"]) == expected_num_labels + assert len(batch[1]["target"]) == expected_num_labels assert "name" in batch[1] assert all( -- GitLab