diff --git a/tests/conftest.py b/tests/conftest.py index 7b4fe9e5cc2f949b738926086d8415e08fb56708..ef45a883f7b7797facac85b4446574c64b806ebe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,9 @@ import pytest import tomli_w import torch +from torchvision.transforms.functional import to_pil_image + +from mednet.data.split import JSONDatabaseSplit from mednet.data.typing import DatabaseSplit @@ -163,6 +166,7 @@ class DatabaseCheckers: split An instance of DatabaseSplit. lengths + A dictionary that contains keys matching those of the split (this will be checked). The values of the dictionary should correspond to the sizes of each of the datasets in the split. @@ -251,6 +255,33 @@ class DatabaseCheckers: # to_pil_image(batch[0][0]).show() # __import__("pdb").set_trace() + @staticmethod + def check_image_quality(datamodule, reference_histogram_file): + ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) + + for split_name in ref_histogram_splits: + raw_samples = datamodule.splits[split_name][0][0] + + # It is not possible to get a sample from a Dataset by name/path, only by index. + # This creates a dict of sample name to dataset index. + raw_samples_indices = {} + for idx, rs in enumerate(raw_samples): + raw_samples_indices[rs[0]] = idx + + for ref_hist_path, ref_hist_data in ref_histogram_splits[ + split_name + ]: + # Get index in the dataset that will return the data corresponding to the specified sample name + dataset_sample_index = raw_samples_indices[ref_hist_path] + + image_tensor = datamodule._datasets[split_name][ + dataset_sample_index + ][0] + img = to_pil_image(image_tensor) + histogram = img.histogram() + + assert histogram == ref_hist_data + @pytest.fixture def database_checkers(): diff --git a/tests/test_hivtb.py b/tests/test_hivtb.py index 807b00aa2812545cf3f3dfa02fa0f3a4359f8007..44c1cca3f702dc8776395c2702bc7c0b6533f105 100644 --- a/tests/test_hivtb.py +++ b/tests/test_hivtb.py @@ -7,10 +7,6 @@ import importlib import pytest -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, dict): @@ -97,11 +93,10 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") -def test_loaded_image_quality(datadir): +def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_hivtb_fold_0.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( ".fold_0", "mednet.config.data.hivtb" @@ -110,23 +105,4 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) diff --git a/tests/test_indian.py b/tests/test_indian.py index f54b6bce1fd2098e5e070ad7f2123fecb74cf8e9..3a959e60783c4f6534c55352b0b933095ed9c093 100644 --- a/tests/test_indian.py +++ b/tests/test_indian.py @@ -10,10 +10,6 @@ import importlib import pytest -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, dict): @@ -102,11 +98,10 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.indian") -def test_loaded_image_quality(datadir): +def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_indian_default.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( ".default", "mednet.config.data.indian" @@ -115,23 +110,4 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py index 9494ce17c64e8c431c8f5d0f13167e33057257c4..c756c7252c91a2c31b9dfdde1027fbf86546a620 100644 --- a/tests/test_montgomery.py +++ b/tests/test_montgomery.py @@ -7,10 +7,6 @@ import importlib import pytest -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, dict): @@ -99,11 +95,10 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_loaded_image_quality(datadir): +def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_montgomery_default.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( ".default", "mednet.config.data.montgomery" @@ -112,23 +107,4 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py index aaa0fbfdfec5b45a3c07b56f911320a510452d30..1790dbc77007c77a320745251ea0b93b3ffcd3a8 100644 --- a/tests/test_nih_cxr14.py +++ b/tests/test_nih_cxr14.py @@ -7,10 +7,6 @@ import importlib import pytest -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, dict): @@ -77,11 +73,10 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int): @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14") -def test_loaded_image_quality(datadir): +def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_nih_cxr14_default.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( ".default", "mednet.config.data.nih_cxr14" @@ -90,23 +85,4 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) diff --git a/tests/test_padchest.py b/tests/test_padchest.py index dc97d609638338dc7e77ddd19aee360ae0d45aaf..68e24077725172230bc7ecd76acf367eb1ffde0d 100644 --- a/tests/test_padchest.py +++ b/tests/test_padchest.py @@ -7,10 +7,6 @@ import importlib import pytest -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, dict): @@ -82,11 +78,10 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int): @pytest.mark.skip_if_rc_var_not_set("datadir.padchest") -def test_loaded_image_quality(datadir): +def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_padchest_idiap.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( ".idiap", "mednet.config.data.padchest" @@ -95,23 +90,4 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) diff --git a/tests/test_shenzhen.py b/tests/test_shenzhen.py index 4748f76ad666f64414a53b0e724259c0586cd5c4..42b23ce16b50d7303bdd1e14bfe6d0199c57623e 100644 --- a/tests/test_shenzhen.py +++ b/tests/test_shenzhen.py @@ -7,10 +7,6 @@ import importlib import pytest -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, dict): @@ -99,11 +95,10 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") -def test_loaded_image_quality(datadir): +def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_shenzhen_default.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( ".default", "mednet.config.data.shenzhen" @@ -112,23 +107,4 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) diff --git a/tests/test_tbpoc.py b/tests/test_tbpoc.py index edc251caa3fbf558c7fad878a75a70392cc51423..44e742789684bf823cccdaeca539e0ee0f2f0482 100644 --- a/tests/test_tbpoc.py +++ b/tests/test_tbpoc.py @@ -7,10 +7,6 @@ import importlib import pytest -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, dict): @@ -103,11 +99,10 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc") -def test_loaded_image_quality(datadir): +def test_loaded_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_tbpoc_fold_0.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( ".fold_0", "mednet.config.data.tbpoc" @@ -116,23 +111,4 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file) diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py index 6ddacdfeee9ebca206c0b8c9667d35551a3974de..5d1f584c1c79e0b420ec813d6dcc71d11ca3ecc1 100644 --- a/tests/test_tbx11k.py +++ b/tests/test_tbx11k.py @@ -9,10 +9,6 @@ import typing import pytest import torch -from torchvision.transforms.functional import to_pil_image - -from mednet.data.split import JSONDatabaseSplit - def id_function(val): if isinstance(val, (dict, tuple)): @@ -305,11 +301,10 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]): ], ) @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k") -def test_loaded_image_quality(datadir, split): +def test_loaded_image_quality(database_checkers, datadir, split): reference_histogram_file = str( datadir / f"histograms/raw_data/histograms_tbx11k_{split}.json" ) - ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) datamodule = importlib.import_module( f".{split}", "mednet.config.data.tbx11k" @@ -318,23 +313,4 @@ def test_loaded_image_quality(datadir, split): datamodule.model_transforms = [] datamodule.setup("predict") - for split_name in ref_histogram_splits: - datamodule_split = datamodule.splits[split_name] - - loader = datamodule_split[0][1] - - for ref_data in ref_histogram_splits[split_name]: - sample_path = ref_data[0] - ref_histogram = ref_data[1] - - test_sample = ( - sample_path, - -1, - ) # Need to specify a label even if not used. - image_data = loader.sample(test_sample)[0] - - img = to_pil_image(image_data) - - histogram = img.histogram() - - assert histogram == ref_histogram + database_checkers.check_image_quality(datamodule, reference_histogram_file)