From 1e040288249693b1088d061bc73b6e221a24b224 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 12 Feb 2024 16:41:22 +0100 Subject: [PATCH] [test] Fix and generalize histogram tests --- tests/conftest.py | 31 +++++++++++++++++++++++++++++++ tests/test_hivtb.py | 28 ++-------------------------- tests/test_indian.py | 28 ++-------------------------- tests/test_montgomery.py | 28 ++-------------------------- tests/test_nih_cxr14.py | 28 ++-------------------------- tests/test_padchest.py | 28 ++-------------------------- tests/test_shenzhen.py | 28 ++-------------------------- tests/test_tbpoc.py | 28 ++-------------------------- tests/test_tbx11k.py | 28 ++-------------------------- 9 files changed, 47 insertions(+), 208 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 7b4fe9e5..ef45a883 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,9 @@ import pytest import tomli_w import torch +from torchvision.transforms.functional import to_pil_image + +from mednet.data.split import JSONDatabaseSplit from mednet.data.typing import DatabaseSplit @@ -163,6 +166,7 @@ class DatabaseCheckers: split An instance of DatabaseSplit. lengths + A dictionary that contains keys matching those of the split (this will be checked). The values of the dictionary should correspond to the sizes of each of the datasets in the split. @@ -251,6 +255,33 @@ class DatabaseCheckers: # to_pil_image(batch[0][0]).show() # __import__("pdb").set_trace() + @staticmethod + def check_image_quality(datamodule, reference_histogram_file): + ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) + + for split_name in ref_histogram_splits: + raw_samples = datamodule.splits[split_name][0][0] + + # It is not possible to get a sample from a Dataset by name/path, only by index. + # This creates a dict of sample name to dataset index. + raw_samples_indices = {} + for idx, rs in enumerate(raw_samples): + raw_samples_indices[rs[0]] = idx + + for ref_hist_path, ref_hist_data in ref_histogram_splits[ + split_name + ]: + # Get index in the dataset that will return the data corresponding to the specified sample name + dataset_sample_index = raw_samples_indices[ref_hist_path] + + image_tensor = datamodule._datasets[split_name][ + dataset_sample_index + ][0] + img = to_pil_image(image_tensor) + histogram = img.histogram() + + assert histogram == ref_hist_data + @pytest.fixture def database_checkers(): diff --git a/tests/test_hivtb.py b/tests/test_hivtb.py index 807b00aa..44c1cca3 100644 --- a/tests/test_hivtb.py +++ b/tests/test_hivtb.py @@ -7,10 +7,6 @@ import importlib import pytest -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, dict): @@ -97,11 +93,10 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") -def test_loaded_image_quality(datadir): +def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_hivtb_fold_0.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( ".fold_0", "mednet.config.data.hivtb" @@ -110,23 +105,4 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) diff --git a/tests/test_indian.py b/tests/test_indian.py index f54b6bce..3a959e60 100644 --- a/tests/test_indian.py +++ b/tests/test_indian.py @@ -10,10 +10,6 @@ import importlib import pytest -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, dict): @@ -102,11 +98,10 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.indian") -def test_loaded_image_quality(datadir): +def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_indian_default.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( ".default", "mednet.config.data.indian" @@ -115,23 +110,4 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py index 9494ce17..c756c725 100644 --- a/tests/test_montgomery.py +++ b/tests/test_montgomery.py @@ -7,10 +7,6 @@ import importlib import pytest -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, dict): @@ -99,11 +95,10 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_loaded_image_quality(datadir): +def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_montgomery_default.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( ".default", "mednet.config.data.montgomery" @@ -112,23 +107,4 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py index aaa0fbfd..1790dbc7 100644 --- a/tests/test_nih_cxr14.py +++ b/tests/test_nih_cxr14.py @@ -7,10 +7,6 @@ import importlib import pytest -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, dict): @@ -77,11 +73,10 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int): @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14") -def test_loaded_image_quality(datadir): +def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_nih_cxr14_default.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( ".default", "mednet.config.data.nih_cxr14" @@ -90,23 +85,4 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) diff --git a/tests/test_padchest.py b/tests/test_padchest.py index dc97d609..68e24077 100644 --- a/tests/test_padchest.py +++ b/tests/test_padchest.py @@ -7,10 +7,6 @@ import importlib import pytest -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, dict): @@ -82,11 +78,10 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int): @pytest.mark.skip_if_rc_var_not_set("datadir.padchest") -def test_loaded_image_quality(datadir): +def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_padchest_idiap.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( ".idiap", "mednet.config.data.padchest" @@ -95,23 +90,4 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) diff --git a/tests/test_shenzhen.py b/tests/test_shenzhen.py index 4748f76a..42b23ce1 100644 --- a/tests/test_shenzhen.py +++ b/tests/test_shenzhen.py @@ -7,10 +7,6 @@ import importlib import pytest -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, dict): @@ -99,11 +95,10 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") -def test_loaded_image_quality(datadir): +def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_shenzhen_default.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( ".default", "mednet.config.data.shenzhen" @@ -112,23 +107,4 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) diff --git a/tests/test_tbpoc.py b/tests/test_tbpoc.py index edc251ca..44e74278 100644 --- a/tests/test_tbpoc.py +++ b/tests/test_tbpoc.py @@ -7,10 +7,6 @@ import importlib import pytest -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, dict): @@ -103,11 +99,10 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc") -def test_loaded_image_quality(datadir): +def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_tbpoc_fold_0.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( ".fold_0", "mednet.config.data.tbpoc" @@ -116,23 +111,4 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py index 6ddacdfe..5d1f584c 100644 --- a/tests/test_tbx11k.py +++ b/tests/test_tbx11k.py @@ -9,10 +9,6 @@ import typing import pytest import torch -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, (dict, tuple)): @@ -305,11 +301,10 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]): ], ) @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k") -def test_loaded_image_quality(datadir, split): +def test_loaded_image_quality(database_checkers, datadir, split): reference_histogram_file = str( datadir / f"histograms/raw_data/histograms_tbx11k_{split}.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( f".{split}", "mednet.config.data.tbx11k" @@ -318,23 +313,4 @@ def test_loaded_image_quality(datadir, split): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) -- GitLab