diff --git a/tests/conftest.py b/tests/conftest.py
index 77fa6703bb6a4b88f7d03ae9752fffa1be6bea07..f1c8744b0d193d0f6593c173e112315476dc1159 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,6 +7,7 @@ import pathlib
 import typing
 
 import numpy
+import numpy.typing
 import pytest
 import torch
 from mednet.data.split import JSONDatabaseSplit
@@ -205,6 +206,15 @@ class DatabaseCheckers:
         # to_pil_image(batch[0][0]).show()
         # __import__("pdb").set_trace()
 
+    @staticmethod
+    def _make_histo(data: numpy.typing.NDArray[numpy.uint8]) -> list[int]:
+        from itertools import chain
+
+        def _mk_single_channel(data: numpy.typing.NDArray[numpy.uint8]) -> list[int]:
+            return numpy.histogram(data, bins=256, range=(0, 256))[0].tolist()
+
+        return list(chain(*[_mk_single_channel(k) for k in data[0, :]]))
+
     @staticmethod
     def check_image_quality(
         datamodule,
@@ -218,11 +228,13 @@ class DatabaseCheckers:
 
         for split_name, loader in datamodule.predict_dataloader().items():
             for sample in loader:
-                ubyte_tensor = (255 * sample[0]).byte().numpy()
-                histogram = numpy.histogram(ubyte_tensor, bins=256, range=(0, 256))[
-                    0
-                ].tolist()
-                ref_histogram = reference[split_name][sample[1]["name"][0]]
+                uint8_array = (255 * sample[0]).byte().numpy()
+                histogram = DatabaseCheckers._make_histo(uint8_array)
+
+                if sample[1]["name"][0] in reference[split_name]:
+                    ref_histogram = reference[split_name].pop(sample[1]["name"][0])
+                else:
+                    continue
 
                 if compare_type == "statistical":
                     # Compute pearson coefficients between histogram and
@@ -239,6 +251,13 @@ class DatabaseCheckers:
                         f"reference = {ref_histogram}"
                     )
 
+        # all references must have been consumed
+        for split, values in reference.items():
+            assert len(values) == 0, (
+                f"Not all references at split `{split}` were consumed: {len(values)} "
+                f"are left"
+            )
+
     @staticmethod
     def write_image_quality_histogram(
         datamodule,
@@ -248,14 +267,9 @@ class DatabaseCheckers:
         for split_name, loader in datamodule.predict_dataloader().items():
             data[split_name] = []
             for sample in loader:
-                ubyte_tensor = (255 * sample[0]).byte().numpy()
+                uint8_array = (255 * sample[0]).byte().numpy()
                 data[split_name].append(
-                    [
-                        sample[1]["name"][0],
-                        numpy.histogram(ubyte_tensor, bins=256, range=(0, 256))[
-                            0
-                        ].tolist(),
-                    ]
+                    [sample[1]["name"][0], DatabaseCheckers._make_histo(uint8_array)]
                 )
 
         with reference_histogram_file.open("w") as f: