Skip to content
Snippets Groups Projects
test_tranforms.py 2.42 KiB
Newer Older
André Anjos's avatar
André Anjos committed
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Tests for transforms."""

import numpy
import PIL.Image

from ptbench.data.loader import load_pil
from ptbench.data.transforms import (
    ElasticDeformation,
    RemoveBlackBorders,
    SingleAutoLevel16to8,
)


def test_remove_black_borders(datadir):
    """Check ``RemoveBlackBorders`` against a stored border-free reference.

    The transform is applied to ``raw_with_black_border.png`` and the result
    must be pixel-for-pixel equal to ``raw_without_black_border.png``.

    Parameters
    ----------
    datadir
        Path-like object pointing at the directory holding the test images
        (presumably a pytest fixture -- it is only used through the ``/``
        operator here).
    """
    bordered_path = str(datadir / "raw_with_black_border.png")
    reference_path = str(datadir / "raw_without_black_border.png")

    rbb = RemoveBlackBorders()

    # PIL opens files lazily and keeps the handle alive until the image is
    # closed; the context managers release it deterministically.  The
    # conversion to a numpy array happens inside each ``with`` so that the
    # pixel data is read while the file is still open.
    with PIL.Image.open(bordered_path) as raw_with_black_border:
        raw_rbb_removed = numpy.asarray(rbb(raw_with_black_border))

    with PIL.Image.open(reference_path) as reference_image:
        raw_without_black_border = numpy.asarray(reference_image)

    # Compare both
    numpy.testing.assert_array_equal(raw_without_black_border, raw_rbb_removed)


def test_elastic_deformation(datadir):
    """Check ``ElasticDeformation`` against a stored, pre-deformed reference.

    The transform is built with a ``numpy.random.RandomState`` seeded at 100
    and applied to ``raw_without_elastic_deformation.png``.  The result must
    be pixel-for-pixel equal to ``raw_with_elastic_deformation.png``, which
    holds the same sample already deformed with that seed.

    Parameters
    ----------
    datadir
        Path-like object pointing at the directory holding the test images
        (presumably a pytest fixture -- it is only used through the ``/``
        operator here).
    """
    source_path = str(datadir / "raw_without_elastic_deformation.png")
    reference_path = str(datadir / "raw_with_elastic_deformation.png")

    # A fixed seed makes the deformation reproducible, so it can be compared
    # against the stored reference image.
    ed = ElasticDeformation(random_state=numpy.random.RandomState(seed=100))

    # PIL opens files lazily and keeps the handle alive until the image is
    # closed; the context managers release it deterministically.  The
    # conversion to a numpy array happens inside each ``with`` so that the
    # pixel data is read while the file is still open.
    with PIL.Image.open(source_path) as raw_without_deformation:
        raw_deformed = numpy.asarray(ed(raw_without_deformation))

    with PIL.Image.open(reference_path) as reference_image:
        raw_2 = numpy.asarray(reference_image)

    # Compare both
    numpy.testing.assert_array_equal(raw_deformed, raw_2)


def test_load_pil_16bit(datadir):
    """Exercise ``SingleAutoLevel16to8`` on a 16-bit and on an 8-bit image.

    Two checks are made with the same transform instance:

    1. After levelling ``16bits.png``, fewer than half of the non-zero pixels
       may sit at the maximum value.
    2. Levelling ``raw_without_black_border.png`` must leave its pixel values
       unchanged.

    Parameters
    ----------
    datadir
        Path-like object pointing at the directory holding the test images
        (presumably a pytest fixture -- it is only used through the ``/``
        operator here).
    """
    auto_level = SingleAutoLevel16to8()

    # A ratio higher than 0.5 would suggest the image was clipped
    path_16bit = str(datadir / "16bits.png")
    leveled = numpy.array(auto_level(load_pil(path_16bit)))

    n_nonzero = numpy.count_nonzero(leveled)
    n_at_max = numpy.count_nonzero(leveled == leveled.max())

    assert n_at_max / n_nonzero < 0.5

    # An image that is already in 8 bits must come out untouched
    path_8bit = str(datadir / "raw_without_black_border.png")
    image_8bit = load_pil(path_8bit)

    pixels_before = numpy.array(image_8bit)
    pixels_after = numpy.array(auto_level(image_8bit))

    numpy.testing.assert_array_equal(pixels_before, pixels_after)