Skip to content
Snippets Groups Projects
Commit 6bff02f8 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[tests] Centralize histogram production; Fix tests to only verify existing...

[tests] Centralize histogram production; Fix tests to only verify existing references and error if not all references are checked
parent adf0efe5
No related branches found
No related tags found
1 merge request!44Make code and tests flexible to the use of a pre-processed Montgomery dataset...
Pipeline #87710 failed
......@@ -7,6 +7,7 @@ import pathlib
import typing
import numpy
import numpy.typing
import pytest
import torch
from mednet.data.split import JSONDatabaseSplit
......@@ -205,6 +206,15 @@ class DatabaseCheckers:
# to_pil_image(batch[0][0]).show()
# __import__("pdb").set_trace()
@staticmethod
def _make_histo(data: numpy.typing.NDArray[numpy.uint8]) -> list[int]:
from itertools import chain
def _mk_single_channel(data: numpy.typing.NDArray[numpy.uint8]) -> list[int]:
return numpy.histogram(data, bins=256, range=(0, 256))[0].tolist()
return list(chain(*[_mk_single_channel(k) for k in data[0, :]]))
@staticmethod
def check_image_quality(
datamodule,
......@@ -218,11 +228,13 @@ class DatabaseCheckers:
for split_name, loader in datamodule.predict_dataloader().items():
for sample in loader:
ubyte_tensor = (255 * sample[0]).byte().numpy()
histogram = numpy.histogram(ubyte_tensor, bins=256, range=(0, 256))[
0
].tolist()
ref_histogram = reference[split_name][sample[1]["name"][0]]
uint8_array = (255 * sample[0]).byte().numpy()
histogram = DatabaseCheckers._make_histo(uint8_array)
if sample[1]["name"][0] in reference[split_name]:
ref_histogram = reference[split_name].pop(sample[1]["name"][0])
else:
continue
if compare_type == "statistical":
# Compute pearson coefficients between histogram and
......@@ -239,6 +251,13 @@ class DatabaseCheckers:
f"reference = {ref_histogram}"
)
# all references must have been consumed
for split, values in reference.items():
assert len(values) == 0, (
f"Not all references at split `{split}` were consumed: {len(values)} "
f"are left"
)
@staticmethod
def write_image_quality_histogram(
datamodule,
......@@ -248,14 +267,9 @@ class DatabaseCheckers:
for split_name, loader in datamodule.predict_dataloader().items():
data[split_name] = []
for sample in loader:
ubyte_tensor = (255 * sample[0]).byte().numpy()
uint8_array = (255 * sample[0]).byte().numpy()
data[split_name].append(
[
sample[1]["name"][0],
numpy.histogram(ubyte_tensor, bins=256, range=(0, 256))[
0
].tolist(),
]
[sample[1]["name"][0], DatabaseCheckers._make_histo(uint8_array)]
)
with reference_histogram_file.open("w") as f:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment