Skip to content
Snippets Groups Projects
Commit 1e040288 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[test] Fix and generalize histogram tests

parent 218a89a7
No related branches found
No related tags found
2 merge requests: !18 "Update tests", !16 "Make square centre-padding a model transform"
...@@ -12,6 +12,9 @@ import pytest ...@@ -12,6 +12,9 @@ import pytest
import tomli_w import tomli_w
import torch import torch
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
from mednet.data.typing import DatabaseSplit from mednet.data.typing import DatabaseSplit
...@@ -163,6 +166,7 @@ class DatabaseCheckers: ...@@ -163,6 +166,7 @@ class DatabaseCheckers:
split split
An instance of DatabaseSplit. An instance of DatabaseSplit.
lengths lengths
A dictionary that contains keys matching those of the split (this will A dictionary that contains keys matching those of the split (this will
be checked). The values of the dictionary should correspond to the be checked). The values of the dictionary should correspond to the
sizes of each of the datasets in the split. sizes of each of the datasets in the split.
...@@ -251,6 +255,33 @@ class DatabaseCheckers: ...@@ -251,6 +255,33 @@ class DatabaseCheckers:
# to_pil_image(batch[0][0]).show() # to_pil_image(batch[0][0]).show()
# __import__("pdb").set_trace() # __import__("pdb").set_trace()
@staticmethod
def check_image_quality(datamodule, reference_histogram_file):
    """Check images loaded by a datamodule against reference histograms.

    For every split listed in the reference file, each referenced sample is
    fetched from the datamodule, converted to a PIL image, and its histogram
    is compared for exact equality with the stored reference.

    Parameters
    ----------
    datamodule
        The datamodule under test.  Its ``splits`` and ``_datasets``
        attributes are read, so it is presumably expected to have been set
        up by the caller beforehand (to be confirmed against the callers).
    reference_histogram_file
        Path to a JSON file readable by ``JSONDatabaseSplit``.  Each split
        in it is iterated as pairs of ``(sample name, histogram)``.

    Raises
    ------
    KeyError
        If a sample named in the reference file is not found among the raw
        samples of the corresponding split.
    AssertionError
        If a computed histogram differs from its reference.
    """
    ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)

    for split_name in ref_histogram_splits:
        # Only the first entry of the split is inspected; its first element
        # holds the raw samples.
        raw_samples = datamodule.splits[split_name][0][0]

        # It is not possible to get a sample from a Dataset by name/path,
        # only by index, so map each sample name to its dataset index.
        raw_samples_indices = {
            sample[0]: idx for idx, sample in enumerate(raw_samples)
        }

        for ref_hist_path, ref_hist_data in ref_histogram_splits[split_name]:
            # Index in the dataset that returns the data for this sample name.
            dataset_sample_index = raw_samples_indices[ref_hist_path]
            image_tensor = datamodule._datasets[split_name][
                dataset_sample_index
            ][0]

            histogram = to_pil_image(image_tensor).histogram()

            # Name the offending sample: a bare comparison of two long
            # histogram lists does not say which image failed.
            assert histogram == ref_hist_data, (
                f"Histogram mismatch for sample '{ref_hist_path}' "
                f"in split '{split_name}'"
            )
@pytest.fixture @pytest.fixture
def database_checkers(): def database_checkers():
......
...@@ -7,10 +7,6 @@ import importlib ...@@ -7,10 +7,6 @@ import importlib
import pytest import pytest
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val): def id_function(val):
if isinstance(val, dict): if isinstance(val, dict):
...@@ -97,11 +93,10 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -97,11 +93,10 @@ def test_loading(database_checkers, name: str, dataset: str):
@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") @pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
def test_loaded_image_quality(datadir): def test_loaded_image_quality(database_checkers, datadir):
reference_histogram_file = str( reference_histogram_file = str(
datadir / "histograms/raw_data/histograms_hivtb_fold_0.json" datadir / "histograms/raw_data/histograms_hivtb_fold_0.json"
) )
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module( datamodule = importlib.import_module(
".fold_0", "mednet.config.data.hivtb" ".fold_0", "mednet.config.data.hivtb"
...@@ -110,23 +105,4 @@ def test_loaded_image_quality(datadir): ...@@ -110,23 +105,4 @@ def test_loaded_image_quality(datadir):
datamodule.model_transforms = [] datamodule.model_transforms = []
datamodule.setup("predict") datamodule.setup("predict")
for split_name in ref_histogram_splits: database_checkers.check_image_quality(datamodule, reference_histogram_file)
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
img = to_pil_image(image_data)
histogram = img.histogram()
assert histogram == ref_histogram
...@@ -10,10 +10,6 @@ import importlib ...@@ -10,10 +10,6 @@ import importlib
import pytest import pytest
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val): def id_function(val):
if isinstance(val, dict): if isinstance(val, dict):
...@@ -102,11 +98,10 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -102,11 +98,10 @@ def test_loading(database_checkers, name: str, dataset: str):
@pytest.mark.skip_if_rc_var_not_set("datadir.indian") @pytest.mark.skip_if_rc_var_not_set("datadir.indian")
def test_loaded_image_quality(datadir): def test_loaded_image_quality(database_checkers, datadir):
reference_histogram_file = str( reference_histogram_file = str(
datadir / "histograms/raw_data/histograms_indian_default.json" datadir / "histograms/raw_data/histograms_indian_default.json"
) )
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module( datamodule = importlib.import_module(
".default", "mednet.config.data.indian" ".default", "mednet.config.data.indian"
...@@ -115,23 +110,4 @@ def test_loaded_image_quality(datadir): ...@@ -115,23 +110,4 @@ def test_loaded_image_quality(datadir):
datamodule.model_transforms = [] datamodule.model_transforms = []
datamodule.setup("predict") datamodule.setup("predict")
for split_name in ref_histogram_splits: database_checkers.check_image_quality(datamodule, reference_histogram_file)
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
img = to_pil_image(image_data)
histogram = img.histogram()
assert histogram == ref_histogram
...@@ -7,10 +7,6 @@ import importlib ...@@ -7,10 +7,6 @@ import importlib
import pytest import pytest
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val): def id_function(val):
if isinstance(val, dict): if isinstance(val, dict):
...@@ -99,11 +95,10 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -99,11 +95,10 @@ def test_loading(database_checkers, name: str, dataset: str):
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_loaded_image_quality(datadir): def test_loaded_image_quality(database_checkers, datadir):
reference_histogram_file = str( reference_histogram_file = str(
datadir / "histograms/raw_data/histograms_montgomery_default.json" datadir / "histograms/raw_data/histograms_montgomery_default.json"
) )
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module( datamodule = importlib.import_module(
".default", "mednet.config.data.montgomery" ".default", "mednet.config.data.montgomery"
...@@ -112,23 +107,4 @@ def test_loaded_image_quality(datadir): ...@@ -112,23 +107,4 @@ def test_loaded_image_quality(datadir):
datamodule.model_transforms = [] datamodule.model_transforms = []
datamodule.setup("predict") datamodule.setup("predict")
for split_name in ref_histogram_splits: database_checkers.check_image_quality(datamodule, reference_histogram_file)
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
img = to_pil_image(image_data)
histogram = img.histogram()
assert histogram == ref_histogram
...@@ -7,10 +7,6 @@ import importlib ...@@ -7,10 +7,6 @@ import importlib
import pytest import pytest
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val): def id_function(val):
if isinstance(val, dict): if isinstance(val, dict):
...@@ -77,11 +73,10 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int): ...@@ -77,11 +73,10 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14") @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
def test_loaded_image_quality(datadir): def test_loaded_image_quality(database_checkers, datadir):
reference_histogram_file = str( reference_histogram_file = str(
datadir / "histograms/raw_data/histograms_nih_cxr14_default.json" datadir / "histograms/raw_data/histograms_nih_cxr14_default.json"
) )
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module( datamodule = importlib.import_module(
".default", "mednet.config.data.nih_cxr14" ".default", "mednet.config.data.nih_cxr14"
...@@ -90,23 +85,4 @@ def test_loaded_image_quality(datadir): ...@@ -90,23 +85,4 @@ def test_loaded_image_quality(datadir):
datamodule.model_transforms = [] datamodule.model_transforms = []
datamodule.setup("predict") datamodule.setup("predict")
for split_name in ref_histogram_splits: database_checkers.check_image_quality(datamodule, reference_histogram_file)
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
img = to_pil_image(image_data)
histogram = img.histogram()
assert histogram == ref_histogram
...@@ -7,10 +7,6 @@ import importlib ...@@ -7,10 +7,6 @@ import importlib
import pytest import pytest
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val): def id_function(val):
if isinstance(val, dict): if isinstance(val, dict):
...@@ -82,11 +78,10 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int): ...@@ -82,11 +78,10 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
@pytest.mark.skip_if_rc_var_not_set("datadir.padchest") @pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
def test_loaded_image_quality(datadir): def test_loaded_image_quality(database_checkers, datadir):
reference_histogram_file = str( reference_histogram_file = str(
datadir / "histograms/raw_data/histograms_padchest_idiap.json" datadir / "histograms/raw_data/histograms_padchest_idiap.json"
) )
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module( datamodule = importlib.import_module(
".idiap", "mednet.config.data.padchest" ".idiap", "mednet.config.data.padchest"
...@@ -95,23 +90,4 @@ def test_loaded_image_quality(datadir): ...@@ -95,23 +90,4 @@ def test_loaded_image_quality(datadir):
datamodule.model_transforms = [] datamodule.model_transforms = []
datamodule.setup("predict") datamodule.setup("predict")
for split_name in ref_histogram_splits: database_checkers.check_image_quality(datamodule, reference_histogram_file)
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
img = to_pil_image(image_data)
histogram = img.histogram()
assert histogram == ref_histogram
...@@ -7,10 +7,6 @@ import importlib ...@@ -7,10 +7,6 @@ import importlib
import pytest import pytest
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val): def id_function(val):
if isinstance(val, dict): if isinstance(val, dict):
...@@ -99,11 +95,10 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -99,11 +95,10 @@ def test_loading(database_checkers, name: str, dataset: str):
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
def test_loaded_image_quality(datadir): def test_loaded_image_quality(database_checkers, datadir):
reference_histogram_file = str( reference_histogram_file = str(
datadir / "histograms/raw_data/histograms_shenzhen_default.json" datadir / "histograms/raw_data/histograms_shenzhen_default.json"
) )
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module( datamodule = importlib.import_module(
".default", "mednet.config.data.shenzhen" ".default", "mednet.config.data.shenzhen"
...@@ -112,23 +107,4 @@ def test_loaded_image_quality(datadir): ...@@ -112,23 +107,4 @@ def test_loaded_image_quality(datadir):
datamodule.model_transforms = [] datamodule.model_transforms = []
datamodule.setup("predict") datamodule.setup("predict")
for split_name in ref_histogram_splits: database_checkers.check_image_quality(datamodule, reference_histogram_file)
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
img = to_pil_image(image_data)
histogram = img.histogram()
assert histogram == ref_histogram
...@@ -7,10 +7,6 @@ import importlib ...@@ -7,10 +7,6 @@ import importlib
import pytest import pytest
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val): def id_function(val):
if isinstance(val, dict): if isinstance(val, dict):
...@@ -103,11 +99,10 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -103,11 +99,10 @@ def test_loading(database_checkers, name: str, dataset: str):
@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc") @pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc")
def test_loaded_image_quality(datadir): def test_loaded_image_quality(database_checkers, datadir):
reference_histogram_file = str( reference_histogram_file = str(
datadir / "histograms/raw_data/histograms_tbpoc_fold_0.json" datadir / "histograms/raw_data/histograms_tbpoc_fold_0.json"
) )
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module( datamodule = importlib.import_module(
".fold_0", "mednet.config.data.tbpoc" ".fold_0", "mednet.config.data.tbpoc"
...@@ -116,23 +111,4 @@ def test_loaded_image_quality(datadir): ...@@ -116,23 +111,4 @@ def test_loaded_image_quality(datadir):
datamodule.model_transforms = [] datamodule.model_transforms = []
datamodule.setup("predict") datamodule.setup("predict")
for split_name in ref_histogram_splits: database_checkers.check_image_quality(datamodule, reference_histogram_file)
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
img = to_pil_image(image_data)
histogram = img.histogram()
assert histogram == ref_histogram
...@@ -9,10 +9,6 @@ import typing ...@@ -9,10 +9,6 @@ import typing
import pytest import pytest
import torch import torch
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
def id_function(val): def id_function(val):
if isinstance(val, (dict, tuple)): if isinstance(val, (dict, tuple)):
...@@ -305,11 +301,10 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]): ...@@ -305,11 +301,10 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]):
], ],
) )
@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k") @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
def test_loaded_image_quality(datadir, split): def test_loaded_image_quality(database_checkers, datadir, split):
reference_histogram_file = str( reference_histogram_file = str(
datadir / f"histograms/raw_data/histograms_tbx11k_{split}.json" datadir / f"histograms/raw_data/histograms_tbx11k_{split}.json"
) )
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
datamodule = importlib.import_module( datamodule = importlib.import_module(
f".{split}", "mednet.config.data.tbx11k" f".{split}", "mednet.config.data.tbx11k"
...@@ -318,23 +313,4 @@ def test_loaded_image_quality(datadir, split): ...@@ -318,23 +313,4 @@ def test_loaded_image_quality(datadir, split):
datamodule.model_transforms = [] datamodule.model_transforms = []
datamodule.setup("predict") datamodule.setup("predict")
for split_name in ref_histogram_splits: database_checkers.check_image_quality(datamodule, reference_histogram_file)
datamodule_split = datamodule.splits[split_name]
loader = datamodule_split[0][1]
for ref_data in ref_histogram_splits[split_name]:
sample_path = ref_data[0]
ref_histogram = ref_data[1]
test_sample = (
sample_path,
-1,
) # Need to specify a label even if not used.
image_data = loader.sample(test_sample)[0]
img = to_pil_image(image_data)
histogram = img.histogram()
assert histogram == ref_histogram
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment