From 751a38bc8cc2406275e4879eb606916d8d2c8834 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Fri, 16 Feb 2024 14:07:44 +0100 Subject: [PATCH] [test] Ignore edge values when comparing histograms --- src/mednet/models/pasa.py | 6 +++++- tests/conftest.py | 15 +++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py index 38ad0218..e285bf7b 100644 --- a/src/mednet/models/pasa.py +++ b/src/mednet/models/pasa.py @@ -75,7 +75,11 @@ class Pasa(pl.LightningModule): self.model_transforms = [ Grayscale(), SquareCenterPad(), - torchvision.transforms.Resize(512, antialias=True), + torchvision.transforms.Resize( + 512, + antialias=True, + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ), ] self._train_loss = train_loss diff --git a/tests/conftest.py b/tests/conftest.py index a51252ca..856bf41c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -203,7 +203,9 @@ class DatabaseCheckers: # __import__("pdb").set_trace() @staticmethod - def check_image_quality(datamodule, reference_histogram_file): + def check_image_quality( + datamodule, reference_histogram_file, histogram_edges_threshold=2 + ): ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) for split_name in ref_histogram_splits: @@ -227,7 +229,16 @@ class DatabaseCheckers: img = to_pil_image(image_tensor) histogram = img.histogram() - assert histogram == ref_hist_data + # The histograms do not exacly match due to the torch resize transform + # acting differently depending on the environment. + assert ( + histogram[ + histogram_edges_threshold:-histogram_edges_threshold + ] + == ref_hist_data[ + histogram_edges_threshold:-histogram_edges_threshold + ] + ) @pytest.fixture -- GitLab