diff --git a/src/mednet/libs/classification/data/typing.py b/src/mednet/libs/classification/data/typing.py index b504f8466e4633a39831f46fc9de108ada035d8b..0d18220e0f54b339cf296278eba7c9c38df6fd86 100644 --- a/src/mednet/libs/classification/data/typing.py +++ b/src/mednet/libs/classification/data/typing.py @@ -38,4 +38,4 @@ class ClassificationRawDataLoader(RawDataLoader): The label corresponding to the specified sample. """ - return self.sample(k)[1]["label"] + return self.sample(k)[1]["target"] diff --git a/src/mednet/libs/classification/engine/saliency/completeness.py b/src/mednet/libs/classification/engine/saliency/completeness.py index e57ba8d1e062383d4b2f9abbe5af52a89bbf845d..c7f3be68e0865b64b5c1c340d4c6a023d851775d 100644 --- a/src/mednet/libs/classification/engine/saliency/completeness.py +++ b/src/mednet/libs/classification/engine/saliency/completeness.py @@ -168,7 +168,7 @@ def _process_sample( """ name: str = sample[1]["name"][0] - label: int = int(sample[1]["label"].item()) + label: int = int(sample[1]["target"].item()) image = sample[0] # in binary classification systems, negative labels may be skipped diff --git a/src/mednet/libs/classification/engine/saliency/interpretability.py b/src/mednet/libs/classification/engine/saliency/interpretability.py index 621aa9d6138fc2fb6dcd2c85d2faf0eb8e502fb3..da96650ec1912d8ef19daa5c574b0f791b16d77b 100644 --- a/src/mednet/libs/classification/engine/saliency/interpretability.py +++ b/src/mednet/libs/classification/engine/saliency/interpretability.py @@ -426,7 +426,7 @@ def run( disable=None, ): name = str(sample[1]["name"][0]) - label = int(sample[1]["label"].item()) + label = int(sample[1]["target"].item()) if label != target_label: # we add the entry for dataset completeness, but do not treat diff --git a/src/mednet/libs/classification/tests/test_montgomery_shenzhen.py b/src/mednet/libs/classification/tests/test_montgomery_shenzhen.py index 7d94a87a6b7deb7888b55d1953503e123d6a0f36..57cac2cf889f8652689c48ab878bb54f831aeafa 100644 --- a/src/mednet/libs/classification/tests/test_montgomery_shenzhen.py +++ b/src/mednet/libs/classification/tests/test_montgomery_shenzhen.py @@ -44,12 +44,12 @@ def test_split_consistency(name: str): montgomery_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.montgomery", - ).RawDataLoader + ).ClassificationRawDataLoader shenzhen_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.shenzhen", - ).RawDataLoader + ).ClassificationRawDataLoader for split in ("train", "validation", "test"): assert montgomery.splits[split][0][0] == combined.splits[split][0][0] diff --git a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian.py b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian.py index 3a95cd7b1968330e668be743ccc1effa22f5e72a..b1f0ec09a781c614cab38a3028c3a8249fe538ff 100644 --- a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian.py +++ b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian.py @@ -49,17 +49,17 @@ def test_split_consistency(name: str): montgomery_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.montgomery", - ).RawDataLoader + ).ClassificationRawDataLoader shenzhen_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.shenzhen", - ).RawDataLoader + ).ClassificationRawDataLoader indian_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.indian", - ).RawDataLoader + ).ClassificationRawDataLoader for split in ("train", "validation", "test"): assert montgomery.splits[split][0][0] == combined.splits[split][0][0] diff --git a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_padchest.py b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_padchest.py index bf7e913f2a7260d85e4ecaa6b99248f6d226aad6..3872fb1b98578e2b4525798e0fc0b18131c4cdc4 100644 --- a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_padchest.py +++ b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_padchest.py @@ -44,22 +44,22 @@ def test_split_consistency(name: str, padchest_name: str): montgomery_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.montgomery", - ).RawDataLoader + ).ClassificationRawDataLoader shenzhen_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.shenzhen", - ).RawDataLoader + ).ClassificationRawDataLoader indian_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.indian", - ).RawDataLoader + ).ClassificationRawDataLoader padchest_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.padchest", - ).RawDataLoader + ).ClassificationRawDataLoader for split in ("train", "validation", "test"): assert montgomery.splits[split][0][0] == combined.splits[split][0][0] diff --git a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_tbx11k.py b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_tbx11k.py index 644d8073ae576d15f50e8fc3c936d1b6251ab5a7..fd9c88cfa87eb50423ab5116f3069f7ae92dfe20 100644 --- a/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_tbx11k.py +++ b/src/mednet/libs/classification/tests/test_montgomery_shenzhen_indian_tbx11k.py @@ -65,22 +65,22 @@ def test_split_consistency(name: str, tbx11k_name: str): montgomery_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.montgomery", - ).RawDataLoader + ).ClassificationRawDataLoader shenzhen_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.shenzhen", - ).RawDataLoader + ).ClassificationRawDataLoader indian_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.indian", - ).RawDataLoader + ).ClassificationRawDataLoader tbx11k_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.tbx11k", - ).RawDataLoader + ).ClassificationRawDataLoader for split in ("train", "validation", "test"): assert montgomery.splits[split][0][0] == combined.splits[split][0][0] diff --git a/src/mednet/libs/classification/tests/test_nih_cxr14_padchest.py b/src/mednet/libs/classification/tests/test_nih_cxr14_padchest.py index bc5cde53a550dd5f2dcd24abad0c80b2b0c95b63..987c2ba784f30760011d1112050cc8ee0fce54cc 100644 --- a/src/mednet/libs/classification/tests/test_nih_cxr14_padchest.py +++ b/src/mednet/libs/classification/tests/test_nih_cxr14_padchest.py @@ -34,12 +34,12 @@ def test_split_consistency(name: str, padchest_name: str, combined_name: str): cxr14_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.nih_cxr14", - ).RawDataLoader + ).ClassificationRawDataLoader padchest_loader = importlib.import_module( ".datamodule", "mednet.libs.classification.config.data.padchest", - ).RawDataLoader + ).ClassificationRawDataLoader for split in ("train", "validation", "test"): assert nih_cxr14.splits[split][0][0] == combined.splits[split][0][0] diff --git a/src/mednet/libs/classification/tests/test_tbx11k.py b/src/mednet/libs/classification/tests/test_tbx11k.py index 19be118c898a4f04f52687ddcd4b448fe7c5191c..0f9a5884cbb8dacbbf31f4170c9b4ad5c395c332 100644 --- a/src/mednet/libs/classification/tests/test_tbx11k.py +++ b/src/mednet/libs/classification/tests/test_tbx11k.py @@ -192,11 +192,11 @@ def check_loaded_batch( assert isinstance(batch[1], dict) # metadata assert len(batch[1]) == 3 # label, name and radiological sign bounding-boxes - assert "label" in batch[1] - assert all([k in possible_labels for k in batch[1]["label"]]) + assert "target" in batch[1] + assert all([k in possible_labels for k in batch[1]["target"]]) if expected_num_labels: - assert len(batch[1]["label"]) == expected_num_labels + assert len(batch[1]["target"]) == expected_num_labels assert "name" in batch[1] assert all( @@ -207,7 +207,7 @@ def check_loaded_batch( for sample, label, bboxes in zip( batch[0], - batch[1]["label"], + batch[1]["target"], batch[1]["bounding_boxes"], ): # there must be a sign indicated on the image, if active TB is detected diff --git a/src/mednet/libs/common/tests/conftest.py b/src/mednet/libs/common/tests/conftest.py index a024dbbc29345fde2110210804223f0b279d83fd..c732ed83b69ebd22da92c54208fccd330f792657 100644 --- a/src/mednet/libs/common/tests/conftest.py +++ b/src/mednet/libs/common/tests/conftest.py @@ -190,11 +190,11 @@ class DatabaseCheckers: assert isinstance(batch[1], dict) # metadata assert len(batch[1]) == 2 # label and name - assert "label" in batch[1] - assert all([k in possible_labels for k in batch[1]["label"]]) + assert "target" in batch[1] + assert all([k in possible_labels for k in batch[1]["target"]]) if expected_num_labels: - assert len(batch[1]["label"]) == expected_num_labels + assert len(batch[1]["target"]) == expected_num_labels assert "name" in batch[1] assert all(