diff --git a/tests/conftest.py b/tests/conftest.py index 183f238fc3fea2319a0aab3ab55c3cbacf861d69..b51eafa681ca9c95be145a2c5eedea0b1fe5a5b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -203,7 +203,10 @@ class DatabaseCheckers: @staticmethod def check_image_quality( - datamodule, reference_histogram_file, pearson_coeff_threshold=0.005 + datamodule, + reference_histogram_file, + compare_type="equal", + pearson_coeff_threshold=0.005, ): ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) @@ -226,21 +229,27 @@ class DatabaseCheckers: dataset_sample_index ][0] - image_tensor = numpy.multiply(image_tensor.numpy(), 255).astype( - int - ) - histogram = numpy.histogram( - image_tensor, bins=256, range=(0, 256) - )[0].tolist() - - # We cannot test if histograms are exactly equal because - # the torch.resize transform is inconsistent depending on the environment. - # assert histogram == ref_hist_data - - # Compute pearson coefficients between histogram and reference - # and check the similarity within a certain threshold - pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data) - assert 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1 + histogram = [] + for color_channel in image_tensor: + color_channel = numpy.multiply( + color_channel.numpy(), 255 + ).astype(int) + histogram.extend( + numpy.histogram( + color_channel, bins=256, range=(0, 256) + )[0].tolist() + ) + + if compare_type == "statistical": + # Compute pearson coefficients between histogram and reference + # and check the similarity within a certain threshold + pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data) + assert ( + 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1 + ) + + else: + assert histogram == ref_hist_data @pytest.fixture diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py index 2600829f51c63fc6d5dbccf9f812b9a0127b38d6..4ce6258064a64507ccdc2714b9aa16c957961a36 100644 --- a/tests/test_montgomery.py +++ b/tests/test_montgomery.py @@ -95,7 +95,7 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_loaded_image_quality(database_checkers, datadir): +def test_raw_transforms_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_montgomery_default.json" ) @@ -115,8 +115,8 @@ def test_loaded_image_quality(database_checkers, datadir): "model_name", [ "alexnet", - # "densenet", - # "pasa", + "densenet", + "pasa", ], ) def test_model_transforms_image_quality(database_checkers, datadir, model_name): @@ -142,4 +142,10 @@ def test_model_transforms_image_quality(database_checkers, datadir, model_name): datamodule.model_transforms = model.model_transforms datamodule.setup("predict") - database_checkers.check_image_quality(datamodule, reference_histogram_file) + + database_checkers.check_image_quality( + datamodule, + reference_histogram_file, + compare_type="statistical", + pearson_coeff_threshold=0.005, + )