From 74d90ccb6d889a76c18001555547aa1e451027d7 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 7 Feb 2024 16:08:25 +0100 Subject: [PATCH] [test] Add tests for raw data loading image quality --- tests/data/lfs | 2 +- tests/test_hivtb.py | 40 ++++++++++++++++++++++++++++++++++++++ tests/test_indian.py | 38 +++++++++++++++++++++++------------- tests/test_montgomery.py | 38 +++++++++++++++++++++++------------- tests/test_nih_cxr14.py | 42 +++++++++++++++++++++++++++++++++++++++- tests/test_padchest.py | 40 ++++++++++++++++++++++++++++++++++++++ tests/test_shenzhen.py | 38 +++++++++++++++++++++++------------- tests/test_tbpoc.py | 40 ++++++++++++++++++++++++++++++++++++++ tests/test_tbx11k.py | 40 ++++++++++++++++++++++++++++++++++++++ 9 files changed, 277 insertions(+), 41 deletions(-) diff --git a/tests/data/lfs b/tests/data/lfs index 05344d20..e2e4a98d 160000 --- a/tests/data/lfs +++ b/tests/data/lfs @@ -1 +1 @@ -Subproject commit 05344d20182ad4169fc5b9c38052d629aded30ed +Subproject commit e2e4a98d675ec61dac44c339c28d91fcb180b398 diff --git a/tests/test_hivtb.py b/tests/test_hivtb.py index 03066c2f..6373d0ee 100644 --- a/tests/test_hivtb.py +++ b/tests/test_hivtb.py @@ -7,6 +7,10 @@ import importlib import pytest +from torchvision.transforms.functional import to_pil_image + +from mednet.data.split import JSONDatabaseSplit + def id_function(val): if isinstance(val, dict): @@ -90,3 +94,39 @@ def test_loading(database_checkers, name: str, dataset: str): expected_num_labels=1, ) limit -= 1 + + +@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") +def test_loaded_image_quality(datadir): + reference_histogram_file = str( + datadir / "lfs/histograms/raw_data/histograms_hivtb_fold_0.json" + ) + ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) + + datamodule = importlib.import_module( + ".fold_0", "mednet.config.data.hivtb" + ).datamodule + + datamodule.model_transforms = [] + datamodule.setup("predict") + + for split_name in ref_histogram_splits: + datamodule_split = datamodule.splits[split_name] + + loader = datamodule_split[0][1] + + for ref_data in ref_histogram_splits[split_name]: + sample_path = ref_data[0] + ref_histogram = ref_data[1] + + test_sample = ( + sample_path, + -1, + ) # Need to specify a label even if not used. + image_data = loader.sample(test_sample)[0] + + img = to_pil_image(image_data) + + histogram = img.histogram() + + assert histogram == ref_histogram diff --git a/tests/test_indian.py b/tests/test_indian.py index a35cfdb3..452aa6c6 100644 --- a/tests/test_indian.py +++ b/tests/test_indian.py @@ -7,11 +7,12 @@ dataset A/dataset B) dataset. """ import importlib -import json import pytest -from PIL import Image +from torchvision.transforms.functional import to_pil_image + +from mednet.data.split import JSONDatabaseSplit def id_function(val): @@ -102,6 +103,11 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.indian") def test_loaded_image_quality(datadir): + reference_histogram_file = str( + datadir / "lfs/histograms/raw_data/histograms_indian_default.json" + ) + ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) + datamodule = importlib.import_module( ".default", "mednet.config.data.indian" ).datamodule @@ -109,17 +115,23 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - loader = datamodule.splits["train"][0][1] - first_sample = datamodule.splits["train"][0][0][0] - image_data = loader.sample(first_sample)[0].numpy()[ - 0, :, : - ] # PIL expects grayscale to not have any leading dim - img = Image.fromarray(image_data, mode="L") + for split_name in ref_histogram_splits: + datamodule_split = datamodule.splits[split_name] + + loader = datamodule_split[0][1] + + for ref_data in ref_histogram_splits[split_name]: + sample_path = ref_data[0] + ref_histogram = ref_data[1] + + test_sample = ( + sample_path, + -1, + ) # Need to specify a label even if not used. + image_data = loader.sample(test_sample)[0] - histogram = img.histogram() + img = to_pil_image(image_data) - reference_histogram_file = str(datadir / "histogram_indian.json") - with open(reference_histogram_file) as i_f: - ref_histogram = json.load(i_f)["histogram"] + histogram = img.histogram() - assert histogram == ref_histogram + assert histogram == ref_histogram diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py index 349898d9..91c7aa59 100644 --- a/tests/test_montgomery.py +++ b/tests/test_montgomery.py @@ -4,11 +4,12 @@ """Tests for Montgomery dataset.""" import importlib -import json import pytest -from PIL import Image +from torchvision.transforms.functional import to_pil_image + +from mednet.data.split import JSONDatabaseSplit def id_function(val): @@ -99,6 +100,11 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_loaded_image_quality(datadir): + reference_histogram_file = str( + datadir / "lfs/histograms/raw_data/histograms_montgomery_default.json" + ) + ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) + datamodule = importlib.import_module( ".default", "mednet.config.data.montgomery" ).datamodule @@ -106,17 +112,23 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - loader = datamodule.splits["train"][0][1] - first_sample = datamodule.splits["train"][0][0][0] - image_data = loader.sample(first_sample)[0].numpy()[ - 0, :, : - ] # PIL expects grayscale to not have any leading dim - img = Image.fromarray(image_data, mode="L") + for split_name in ref_histogram_splits: + datamodule_split = datamodule.splits[split_name] + + loader = datamodule_split[0][1] + + for ref_data in ref_histogram_splits[split_name]: + sample_path = ref_data[0] + ref_histogram = ref_data[1] + + test_sample = ( + sample_path, + -1, + ) # Need to specify a label even if not used. + image_data = loader.sample(test_sample)[0] - histogram = img.histogram() + img = to_pil_image(image_data) - reference_histogram_file = str(datadir / "histogram_montgomery.json") - with open(reference_histogram_file) as i_f: - ref_histogram = json.load(i_f)["histogram"] + histogram = img.histogram() - assert histogram == ref_histogram + assert histogram == ref_histogram diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py index dc73aa29..a2e014f3 100644 --- a/tests/test_nih_cxr14.py +++ b/tests/test_nih_cxr14.py @@ -7,6 +7,10 @@ import importlib import pytest +from torchvision.transforms.functional import to_pil_image + +from mednet.data.split import JSONDatabaseSplit + def id_function(val): if isinstance(val, dict): @@ -44,7 +48,7 @@ testdata = [ ] -@pytest.mark.skip_if_rc_var_not_set("datadir.padchest") +@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14") @pytest.mark.parametrize("name,dataset,num_labels", testdata) def test_loading(database_checkers, name: str, dataset: str, num_labels: int): datamodule = importlib.import_module( @@ -70,3 +74,39 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int): expected_image_shape=(1, 1024, 1024), ) limit -= 1 + + +@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14") +def test_loaded_image_quality(datadir): + reference_histogram_file = str( + datadir / "lfs/histograms/raw_data/histograms_nih_cxr14_default.json" + ) + ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) + + datamodule = importlib.import_module( + ".default", "mednet.config.data.nih_cxr14" + ).datamodule + + datamodule.model_transforms = [] + datamodule.setup("predict") + + for split_name in ref_histogram_splits: + datamodule_split = datamodule.splits[split_name] + + loader = datamodule_split[0][1] + + for ref_data in ref_histogram_splits[split_name]: + sample_path = ref_data[0] + ref_histogram = ref_data[1] + + test_sample = ( + sample_path, + -1, + ) # Need to specify a label even if not used. + image_data = loader.sample(test_sample)[0] + + img = to_pil_image(image_data) + + histogram = img.histogram() + + assert histogram == ref_histogram diff --git a/tests/test_padchest.py b/tests/test_padchest.py index 98eaeb2d..47494d94 100644 --- a/tests/test_padchest.py +++ b/tests/test_padchest.py @@ -7,6 +7,10 @@ import importlib import pytest +from torchvision.transforms.functional import to_pil_image + +from mednet.data.split import JSONDatabaseSplit + def id_function(val): if isinstance(val, dict): @@ -75,3 +79,39 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int): expected_num_labels=num_labels, ) limit -= 1 + + +@pytest.mark.skip_if_rc_var_not_set("datadir.padchest") +def test_loaded_image_quality(datadir): + reference_histogram_file = str( + datadir / "lfs/histograms/raw_data/histograms_padchest_idiap.json" + ) + ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) + + datamodule = importlib.import_module( + ".idiap", "mednet.config.data.padchest" + ).datamodule + + datamodule.model_transforms = [] + datamodule.setup("predict") + + for split_name in ref_histogram_splits: + datamodule_split = datamodule.splits[split_name] + + loader = datamodule_split[0][1] + + for ref_data in ref_histogram_splits[split_name]: + sample_path = ref_data[0] + ref_histogram = ref_data[1] + + test_sample = ( + sample_path, + -1, + ) # Need to specify a label even if not used. + image_data = loader.sample(test_sample)[0] + + img = to_pil_image(image_data) + + histogram = img.histogram() + + assert histogram == ref_histogram diff --git a/tests/test_shenzhen.py b/tests/test_shenzhen.py index ff0a8af2..e2283fad 100644 --- a/tests/test_shenzhen.py +++ b/tests/test_shenzhen.py @@ -4,11 +4,12 @@ """Tests for Shenzhen dataset.""" import importlib -import json import pytest -from PIL import Image +from torchvision.transforms.functional import to_pil_image + +from mednet.data.split import JSONDatabaseSplit def id_function(val): @@ -99,6 +100,11 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") def test_loaded_image_quality(datadir): + reference_histogram_file = str( + datadir / "lfs/histograms/raw_data/histograms_shenzhen_default.json" + ) + ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) + datamodule = importlib.import_module( ".default", "mednet.config.data.shenzhen" ).datamodule @@ -106,17 +112,23 @@ def test_loaded_image_quality(datadir): datamodule.model_transforms = [] datamodule.setup("predict") - loader = datamodule.splits["train"][0][1] - first_sample = datamodule.splits["train"][0][0][0] - image_data = loader.sample(first_sample)[0].numpy()[ - 0, :, : - ] # PIL expects grayscale to not have any leading dim - img = Image.fromarray(image_data, mode="L") + for split_name in ref_histogram_splits: + datamodule_split = datamodule.splits[split_name] + + loader = datamodule_split[0][1] + + for ref_data in ref_histogram_splits[split_name]: + sample_path = ref_data[0] + ref_histogram = ref_data[1] + + test_sample = ( + sample_path, + -1, + ) # Need to specify a label even if not used. + image_data = loader.sample(test_sample)[0] - histogram = img.histogram() + img = to_pil_image(image_data) - reference_histogram_file = str(datadir / "histogram_shenzhen.json") - with open(reference_histogram_file) as i_f: - ref_histogram = json.load(i_f)["histogram"] + histogram = img.histogram() - assert histogram == ref_histogram + assert histogram == ref_histogram diff --git a/tests/test_tbpoc.py b/tests/test_tbpoc.py index 58f762d5..75eaeca7 100644 --- a/tests/test_tbpoc.py +++ b/tests/test_tbpoc.py @@ -7,6 +7,10 @@ import importlib import pytest +from torchvision.transforms.functional import to_pil_image + +from mednet.data.split import JSONDatabaseSplit + def id_function(val): if isinstance(val, dict): @@ -96,3 +100,39 @@ def test_loading(database_checkers, name: str, dataset: str): expected_num_labels=1, ) limit -= 1 + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc") +def test_loaded_image_quality(datadir): + reference_histogram_file = str( + datadir / "lfs/histograms/raw_data/histograms_tbpoc_fold_0.json" + ) + ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) + + datamodule = importlib.import_module( + ".fold_0", "mednet.config.data.tbpoc" + ).datamodule + + datamodule.model_transforms = [] + datamodule.setup("predict") + + for split_name in ref_histogram_splits: + datamodule_split = datamodule.splits[split_name] + + loader = datamodule_split[0][1] + + for ref_data in ref_histogram_splits[split_name]: + sample_path = ref_data[0] + ref_histogram = ref_data[1] + + test_sample = ( + sample_path, + -1, + ) # Need to specify a label even if not used. + image_data = loader.sample(test_sample)[0] + + img = to_pil_image(image_data) + + histogram = img.histogram() + + assert histogram == ref_histogram diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py index 6c022dba..7b2a60c9 100644 --- a/tests/test_tbx11k.py +++ b/tests/test_tbx11k.py @@ -9,6 +9,10 @@ import typing import pytest import torch +from torchvision.transforms.functional import to_pil_image + +from mednet.data.split import JSONDatabaseSplit + def id_function(val): if isinstance(val, (dict, tuple)): @@ -283,3 +287,39 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]): expected_image_shape=(3, 512, 512), ) limit -= 1 + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k") +def test_loaded_image_quality(datadir): + reference_histogram_file = str( + datadir / "lfs/histograms/raw_data/histograms_tbx11k_v1_fold_0.json" + ) + ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) + + datamodule = importlib.import_module( + ".v1_fold_0", "mednet.config.data.tbx11k" + ).datamodule + + datamodule.model_transforms = [] + datamodule.setup("predict") + + for split_name in ref_histogram_splits: + datamodule_split = datamodule.splits[split_name] + + loader = datamodule_split[0][1] + + for ref_data in ref_histogram_splits[split_name]: + sample_path = ref_data[0] + ref_histogram = ref_data[1] + + test_sample = ( + sample_path, + -1, + ) # Need to specify a label even if not used. + image_data = loader.sample(test_sample)[0] + + img = to_pil_image(image_data) + + histogram = img.histogram() + + assert histogram == ref_histogram -- GitLab