Skip to content
Snippets Groups Projects
Commit 6bff02f8 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[tests] Centralize histogram production; Fix tests to only verify existing...

[tests] Centralize histogram production; Fix tests to only verify existing references and error if not all references are checked
parent adf0efe5
No related branches found
No related tags found
1 merge request: !44 "Make code and tests flexible to the use of a pre-processed Montgomery dataset..."
Pipeline #87710 failed
...@@ -7,6 +7,7 @@ import pathlib ...@@ -7,6 +7,7 @@ import pathlib
import typing import typing
import numpy import numpy
import numpy.typing
import pytest import pytest
import torch import torch
from mednet.data.split import JSONDatabaseSplit from mednet.data.split import JSONDatabaseSplit
...@@ -205,6 +206,15 @@ class DatabaseCheckers: ...@@ -205,6 +206,15 @@ class DatabaseCheckers:
# to_pil_image(batch[0][0]).show() # to_pil_image(batch[0][0]).show()
# __import__("pdb").set_trace() # __import__("pdb").set_trace()
@staticmethod
def _make_histo(data: numpy.typing.NDArray[numpy.uint8]) -> list[int]:
    """Return the concatenated 256-bin histograms of the first entry of ``data``.

    Only ``data[0]`` is inspected; any further entries along the first axis
    are ignored.  Each sub-array found along the second axis (presumably one
    image channel each -- confirm against callers) is histogrammed
    independently with 256 unit-wide bins over ``[0, 256)``, and the
    resulting count lists are joined in order.

    Parameters
    ----------
    data
        Array with at least two dimensions holding 8-bit values.

    Returns
    -------
    list[int]
        A flat list with ``256 * data.shape[1]`` bin counts.
    """
    from itertools import chain

    def _mk_single_channel(
        channel: numpy.typing.NDArray[numpy.uint8],
    ) -> list[int]:
        # numpy.histogram returns (counts, bin_edges); only counts are kept.
        return numpy.histogram(channel, bins=256, range=(0, 256))[0].tolist()

    return list(chain.from_iterable(_mk_single_channel(k) for k in data[0, :]))
@staticmethod @staticmethod
def check_image_quality( def check_image_quality(
datamodule, datamodule,
...@@ -218,11 +228,13 @@ class DatabaseCheckers: ...@@ -218,11 +228,13 @@ class DatabaseCheckers:
for split_name, loader in datamodule.predict_dataloader().items(): for split_name, loader in datamodule.predict_dataloader().items():
for sample in loader: for sample in loader:
ubyte_tensor = (255 * sample[0]).byte().numpy() uint8_array = (255 * sample[0]).byte().numpy()
histogram = numpy.histogram(ubyte_tensor, bins=256, range=(0, 256))[ histogram = DatabaseCheckers._make_histo(uint8_array)
0
].tolist() if sample[1]["name"][0] in reference[split_name]:
ref_histogram = reference[split_name][sample[1]["name"][0]] ref_histogram = reference[split_name].pop(sample[1]["name"][0])
else:
continue
if compare_type == "statistical": if compare_type == "statistical":
# Compute pearson coefficients between histogram and # Compute pearson coefficients between histogram and
...@@ -239,6 +251,13 @@ class DatabaseCheckers: ...@@ -239,6 +251,13 @@ class DatabaseCheckers:
f"reference = {ref_histogram}" f"reference = {ref_histogram}"
) )
# all references must have been consumed
for split, values in reference.items():
assert len(values) == 0, (
f"Not all references at split `{split}` were consumed: {len(values)} "
f"are left"
)
@staticmethod @staticmethod
def write_image_quality_histogram( def write_image_quality_histogram(
datamodule, datamodule,
...@@ -248,14 +267,9 @@ class DatabaseCheckers: ...@@ -248,14 +267,9 @@ class DatabaseCheckers:
for split_name, loader in datamodule.predict_dataloader().items(): for split_name, loader in datamodule.predict_dataloader().items():
data[split_name] = [] data[split_name] = []
for sample in loader: for sample in loader:
ubyte_tensor = (255 * sample[0]).byte().numpy() uint8_array = (255 * sample[0]).byte().numpy()
data[split_name].append( data[split_name].append(
[ [sample[1]["name"][0], DatabaseCheckers._make_histo(uint8_array)]
sample[1]["name"][0],
numpy.histogram(ubyte_tensor, bins=256, range=(0, 256))[
0
].tolist(),
]
) )
with reference_histogram_file.open("w") as f: with reference_histogram_file.open("w") as f:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment