# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for transforms."""

import numpy
import PIL.Image
import torch
import torchvision.transforms.functional as F  # noqa: N812
from mednet.libs.common.data.augmentations import ElasticDeformation
from mednet.libs.common.models.transforms import (
    crop_image_to_mask,
    resize_max_side,
)


def test_crop_mask():
    """Check that ``crop_image_to_mask`` returns exactly the masked region.

    A 20x20 block of ones is painted into an otherwise-zero mask. The cropped
    output must have that spatial size and hold the same values as the
    corresponding slice of the input image.
    """
    image_shape = (3, 50, 100)
    mask_shape = (1, 50, 100)
    # All channels; rows 10..29 and columns 50..69, i.e. a 20x20 window.
    region = (slice(None), slice(10, 30), slice(50, 70))

    image = torch.rand(image_shape)
    mask = torch.zeros(mask_shape)
    mask[region] = 1

    cropped = crop_image_to_mask(image, mask)

    assert cropped.shape == (3, 20, 20)
    expected = image[region]
    assert torch.all(cropped == expected)


def test_resize_max_size():
    """Check that ``resize_max_side`` preserves aspect ratio in both orientations.

    The same random tensor is resized once as-is (wider than tall) and once
    with its two spatial axes swapped (taller than wide). Both results must
    keep the source aspect ratio, and their spatial sizes must mirror each
    other.
    """
    channels, height, width = 3, 50, 100
    aspect = height / width
    max_side = 120
    source = torch.rand((channels, height, width))

    # Wider-than-tall input.
    landscape = resize_max_side(source, max_side)
    assert aspect == landscape.shape[1] / landscape.shape[2]

    # Same data with height and width swapped: taller-than-wide input.
    portrait = resize_max_side(source.transpose(1, 2), max_side)
    assert aspect == 1 / (portrait.shape[1] / portrait.shape[2])

    # Swapping the spatial axes before resizing swaps them afterwards too.
    assert (landscape.shape[1], landscape.shape[2]) == (
        portrait.shape[2],
        portrait.shape[1],
    )


def test_elastic_deformation(datadir):
    """Compare ``ElasticDeformation`` output against a stored reference image.

    The global NumPy RNG is seeded with 100 before the transform is built and
    applied. The reference image ``raw_with_elastic_deformation.png`` was
    produced with that same seed.

    Parameters
    ----------
    datadir
        Directory object supporting ``/`` path joining that contains both test
        images (presumably a pytest fixture; confirm in conftest).
    """
    # Get a raw sample without deformation. ``to_tensor`` copies the pixel
    # data, so the image file can be closed immediately afterwards.
    data_file = str(datadir / "raw_without_elastic_deformation.png")
    with PIL.Image.open(data_file) as img:
        raw_without_deformation = F.to_tensor(img)

    # Elastic deforms the raw. Seed first: the reference output depends on
    # the RNG state (presumably ElasticDeformation draws from numpy.random).
    numpy.random.seed(seed=100)
    ed = ElasticDeformation()
    raw_deformed = ed(raw_without_deformation)

    # Get the same sample already deformed (with seed=100). ``numpy.array``
    # copies the pixels, so the array stays valid after the file is closed.
    data_file_2 = str(datadir / "raw_with_elastic_deformation.png")
    with PIL.Image.open(data_file_2) as img:
        raw_2 = numpy.array(img)

    # Compare both: undo the [0, 1] scaling applied by ``to_tensor`` and keep
    # the first channel of the channel-first result (assumes the reference
    # PNG is single-channel; TODO confirm).
    raw_deformed = (255 * numpy.asarray(raw_deformed)).astype(numpy.uint8)
    raw_deformed = raw_deformed[0, :, :]

    numpy.testing.assert_array_equal(raw_deformed, raw_2)