diff --git a/tests/test_indian.py b/tests/test_indian.py index 8dd07159e8d2aa659eb03bd81a743ee097403225..a35cfdb318ee14f299b9dc75932c0c3a715c3d40 100644 --- a/tests/test_indian.py +++ b/tests/test_indian.py @@ -7,9 +7,12 @@ dataset A/dataset B) dataset. """ import importlib +import json import pytest +from PIL import Image + def id_function(val): if isinstance(val, dict): @@ -95,3 +98,28 @@ def test_loading(database_checkers, name: str, dataset: str): expected_num_labels=1, ) limit -= 1 + + +@pytest.mark.skip_if_rc_var_not_set("datadir.indian") +def test_loaded_image_quality(datadir): + datamodule = importlib.import_module( + ".default", "mednet.config.data.indian" + ).datamodule + + datamodule.model_transforms = [] + datamodule.setup("predict") + + loader = datamodule.splits["train"][0][1] + first_sample = datamodule.splits["train"][0][0][0] + image_data = loader.sample(first_sample)[0].numpy()[ + 0, :, : + ] # PIL expects grayscale to not have any leading dim + img = Image.fromarray(image_data, mode="L") + + histogram = img.histogram() + + reference_histogram_file = str(datadir / "histogram_indian.json") + with open(reference_histogram_file) as i_f: + ref_histogram = json.load(i_f)["histogram"] + + assert histogram == ref_histogram diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py index 91f8cc1004c34043cb1b7735d55c5d66dd92c604..349898d91500833c1141e8ddfb0471f8af4e896b 100644 --- a/tests/test_montgomery.py +++ b/tests/test_montgomery.py @@ -4,9 +4,12 @@ """Tests for Montgomery dataset.""" import importlib +import json import pytest +from PIL import Image + def id_function(val): if isinstance(val, dict): @@ -92,3 +95,28 @@ def test_loading(database_checkers, name: str, dataset: str): expected_num_labels=1, ) limit -= 1 + + +@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") +def test_loaded_image_quality(datadir): + datamodule = importlib.import_module( + ".default", "mednet.config.data.montgomery" + ).datamodule + + datamodule.model_transforms = [] + datamodule.setup("predict") + + loader = datamodule.splits["train"][0][1] + first_sample = datamodule.splits["train"][0][0][0] + image_data = loader.sample(first_sample)[0].numpy()[ + 0, :, : + ] # PIL expects grayscale to not have any leading dim + img = Image.fromarray(image_data, mode="L") + + histogram = img.histogram() + + reference_histogram_file = str(datadir / "histogram_montgomery.json") + with open(reference_histogram_file) as i_f: + ref_histogram = json.load(i_f)["histogram"] + + assert histogram == ref_histogram diff --git a/tests/test_shenzhen.py b/tests/test_shenzhen.py index aeb9da854ff5b2e7725bf3d54cf3469b96ea5201..ff0a8af2a94d673079c55f095ed2340c41c9ed9e 100644 --- a/tests/test_shenzhen.py +++ b/tests/test_shenzhen.py @@ -4,9 +4,12 @@ """Tests for Shenzhen dataset.""" import importlib +import json import pytest +from PIL import Image + def id_function(val): if isinstance(val, dict): @@ -92,3 +95,28 @@ def test_loading(database_checkers, name: str, dataset: str): expected_num_labels=1, ) limit -= 1 + + +@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") +def test_loaded_image_quality(datadir): + datamodule = importlib.import_module( + ".default", "mednet.config.data.shenzhen" + ).datamodule + + datamodule.model_transforms = [] + datamodule.setup("predict") + + loader = datamodule.splits["train"][0][1] + first_sample = datamodule.splits["train"][0][0][0] + image_data = loader.sample(first_sample)[0].numpy()[ + 0, :, : + ] # PIL expects grayscale to not have any leading dim + img = Image.fromarray(image_data, mode="L") + + histogram = img.histogram() + + reference_histogram_file = str(datadir / "histogram_shenzhen.json") + with open(reference_histogram_file) as i_f: + ref_histogram = json.load(i_f)["histogram"] + + assert histogram == ref_histogram