diff --git a/tests/conftest.py b/tests/conftest.py index 77fa6703bb6a4b88f7d03ae9752fffa1be6bea07..f1c8744b0d193d0f6593c173e112315476dc1159 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ import pathlib import typing import numpy +import numpy.typing import pytest import torch from mednet.data.split import JSONDatabaseSplit @@ -205,6 +206,15 @@ class DatabaseCheckers: # to_pil_image(batch[0][0]).show() # __import__("pdb").set_trace() + @staticmethod + def _make_histo(data: numpy.typing.NDArray[numpy.uint8]) -> list[int]: + from itertools import chain + + def _mk_single_channel(data: numpy.typing.NDArray[numpy.uint8]) -> list[int]: + return numpy.histogram(data, bins=256, range=(0, 256))[0].tolist() + + return list(chain(*[_mk_single_channel(k) for k in data[0, :]])) + @staticmethod def check_image_quality( datamodule, @@ -218,11 +228,13 @@ class DatabaseCheckers: for split_name, loader in datamodule.predict_dataloader().items(): for sample in loader: - ubyte_tensor = (255 * sample[0]).byte().numpy() - histogram = numpy.histogram(ubyte_tensor, bins=256, range=(0, 256))[ - 0 - ].tolist() - ref_histogram = reference[split_name][sample[1]["name"][0]] + uint8_array = (255 * sample[0]).byte().numpy() + histogram = DatabaseCheckers._make_histo(uint8_array) + + if sample[1]["name"][0] in reference[split_name]: + ref_histogram = reference[split_name].pop(sample[1]["name"][0]) + else: + continue if compare_type == "statistical": # Compute pearson coefficients between histogram and @@ -239,6 +251,13 @@ class DatabaseCheckers: f"reference = {ref_histogram}" ) + # all references must have been consumed + for split, values in reference.items(): + assert len(values) == 0, ( + f"Not all references at split `{split}` were consumed: {len(values)} " + f"are left" + ) + @staticmethod def write_image_quality_histogram( datamodule, @@ -248,14 +267,9 @@ class DatabaseCheckers: for split_name, loader in datamodule.predict_dataloader().items(): data[split_name] = [] for sample in loader: - ubyte_tensor = (255 * sample[0]).byte().numpy() + uint8_array = (255 * sample[0]).byte().numpy() data[split_name].append( - [ - sample[1]["name"][0], - numpy.histogram(ubyte_tensor, bins=256, range=(0, 256))[ - 0 - ].tolist(), - ] + [sample[1]["name"][0], DatabaseCheckers._make_histo(uint8_array)] ) with reference_histogram_file.open("w") as f: