Skip to content
Snippets Groups Projects
Commit 74d90ccb authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[test] Add tests for raw data loading image quality

parent 45f95973
No related branches found
No related tags found
2 merge requests!18Update tests,!16Make square centre-padding a model transform
Subproject commit 05344d20182ad4169fc5b9c38052d629aded30ed
Subproject commit e2e4a98d675ec61dac44c339c28d91fcb180b398
......@@ -7,6 +7,10 @@ import importlib
import pytest
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val):
if isinstance(val, dict):
......@@ -90,3 +94,39 @@ def test_loading(database_checkers, name: str, dataset: str):
expected_num_labels=1,
)
limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
def test_loaded_image_quality(datadir):
reference_histogram_file = str(
datadir / "lfs/histograms/raw_data/histograms_hivtb_fold_0.json"
)
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module(
".fold_0", "mednet.config.data.hivtb"
).datamodule
datamodule.model_transforms = []
datamodule.setup("predict")
for split_name in ref_histogram_splits:
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
img = to_pil_image(image_data)
histogram = img.histogram()
assert histogram == ref_histogram
......@@ -7,11 +7,12 @@ dataset A/dataset B) dataset.
"""
import importlib
import json
import pytest
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val):
......@@ -102,6 +103,11 @@ def test_loading(database_checkers, name: str, dataset: str):
@pytest.mark.skip_if_rc_var_not_set("datadir.indian")
def test_loaded_image_quality(datadir):
reference_histogram_file = str(
datadir / "lfs/histograms/raw_data/histograms_indian_default.json"
)
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module(
".default", "mednet.config.data.indian"
).datamodule
......@@ -109,17 +115,23 @@ def test_loaded_image_quality(datadir):
datamodule.model_transforms = []
datamodule.setup("predict")
loader = datamodule.splits["train"][0][1]
first_sample = datamodule.splits["train"][0][0][0]
image_data = loader.sample(first_sample)[0].numpy()[
0, :, :
] # PIL expects grayscale to not have any leading dim
img = Image.fromarray(image_data, mode="L")
for split_name in ref_histogram_splits:
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
histogram = img.histogram()
img = to_pil_image(image_data)
reference_histogram_file = str(datadir / "histogram_indian.json")
with open(reference_histogram_file) as i_f:
ref_histogram = json.load(i_f)["histogram"]
histogram = img.histogram()
assert histogram == ref_histogram
assert histogram == ref_histogram
......@@ -4,11 +4,12 @@
"""Tests for Montgomery dataset."""
import importlib
import json
import pytest
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val):
......@@ -99,6 +100,11 @@ def test_loading(database_checkers, name: str, dataset: str):
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_loaded_image_quality(datadir):
reference_histogram_file = str(
datadir / "lfs/histograms/raw_data/histograms_montgomery_default.json"
)
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module(
".default", "mednet.config.data.montgomery"
).datamodule
......@@ -106,17 +112,23 @@ def test_loaded_image_quality(datadir):
datamodule.model_transforms = []
datamodule.setup("predict")
loader = datamodule.splits["train"][0][1]
first_sample = datamodule.splits["train"][0][0][0]
image_data = loader.sample(first_sample)[0].numpy()[
0, :, :
] # PIL expects grayscale to not have any leading dim
img = Image.fromarray(image_data, mode="L")
for split_name in ref_histogram_splits:
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
histogram = img.histogram()
img = to_pil_image(image_data)
reference_histogram_file = str(datadir / "histogram_montgomery.json")
with open(reference_histogram_file) as i_f:
ref_histogram = json.load(i_f)["histogram"]
histogram = img.histogram()
assert histogram == ref_histogram
assert histogram == ref_histogram
......@@ -7,6 +7,10 @@ import importlib
import pytest
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val):
if isinstance(val, dict):
......@@ -44,7 +48,7 @@ testdata = [
]
@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
@pytest.mark.parametrize("name,dataset,num_labels", testdata)
def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
datamodule = importlib.import_module(
......@@ -70,3 +74,39 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
expected_image_shape=(1, 1024, 1024),
)
limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
def test_loaded_image_quality(datadir):
reference_histogram_file = str(
datadir / "lfs/histograms/raw_data/histograms_nih_cxr14_default.json"
)
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module(
".default", "mednet.config.data.nih_cxr14"
).datamodule
datamodule.model_transforms = []
datamodule.setup("predict")
for split_name in ref_histogram_splits:
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
img = to_pil_image(image_data)
histogram = img.histogram()
assert histogram == ref_histogram
......@@ -7,6 +7,10 @@ import importlib
import pytest
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val):
if isinstance(val, dict):
......@@ -75,3 +79,39 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
expected_num_labels=num_labels,
)
limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
def test_loaded_image_quality(datadir):
reference_histogram_file = str(
datadir / "lfs/histograms/raw_data/histograms_padchest_idiap.json"
)
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module(
".idiap", "mednet.config.data.padchest"
).datamodule
datamodule.model_transforms = []
datamodule.setup("predict")
for split_name in ref_histogram_splits:
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
img = to_pil_image(image_data)
histogram = img.histogram()
assert histogram == ref_histogram
......@@ -4,11 +4,12 @@
"""Tests for Shenzhen dataset."""
import importlib
import json
import pytest
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val):
......@@ -99,6 +100,11 @@ def test_loading(database_checkers, name: str, dataset: str):
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
def test_loaded_image_quality(datadir):
reference_histogram_file = str(
datadir / "lfs/histograms/raw_data/histograms_shenzhen_default.json"
)
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module(
".default", "mednet.config.data.shenzhen"
).datamodule
......@@ -106,17 +112,23 @@ def test_loaded_image_quality(datadir):
datamodule.model_transforms = []
datamodule.setup("predict")
loader = datamodule.splits["train"][0][1]
first_sample = datamodule.splits["train"][0][0][0]
image_data = loader.sample(first_sample)[0].numpy()[
0, :, :
] # PIL expects grayscale to not have any leading dim
img = Image.fromarray(image_data, mode="L")
for split_name in ref_histogram_splits:
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
histogram = img.histogram()
img = to_pil_image(image_data)
reference_histogram_file = str(datadir / "histogram_shenzhen.json")
with open(reference_histogram_file) as i_f:
ref_histogram = json.load(i_f)["histogram"]
histogram = img.histogram()
assert histogram == ref_histogram
assert histogram == ref_histogram
......@@ -7,6 +7,10 @@ import importlib
import pytest
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val):
if isinstance(val, dict):
......@@ -96,3 +100,39 @@ def test_loading(database_checkers, name: str, dataset: str):
expected_num_labels=1,
)
limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc")
def test_loaded_image_quality(datadir):
reference_histogram_file = str(
datadir / "lfs/histograms/raw_data/histograms_tbpoc_fold_0.json"
)
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module(
".fold_0", "mednet.config.data.tbpoc"
).datamodule
datamodule.model_transforms = []
datamodule.setup("predict")
for split_name in ref_histogram_splits:
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
img = to_pil_image(image_data)
histogram = img.histogram()
assert histogram == ref_histogram
......@@ -9,6 +9,10 @@ import typing
import pytest
import torch
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val):
if isinstance(val, (dict, tuple)):
......@@ -283,3 +287,39 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]):
expected_image_shape=(3, 512, 512),
)
limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
def test_loaded_image_quality(datadir):
reference_histogram_file = str(
datadir / "lfs/histograms/raw_data/histograms_tbx11k_v1_fold_0.json"
)
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module(
".v1_fold_0", "mednet.config.data.tbx11k"
).datamodule
datamodule.model_transforms = []
datamodule.setup("predict")
for split_name in ref_histogram_splits:
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
img = to_pil_image(image_data)
histogram = img.histogram()
assert histogram == ref_histogram
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment