# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for transforms."""

import numpy
import PIL.Image
import torch
import torchvision.transforms.functional as F  # noqa: N812

from mednet.libs.common.data.augmentations import ElasticDeformation
from mednet.libs.common.models.transforms import (
    crop_image_to_mask,
    resize_max_side,
)


def test_crop_mask():
    """Check that ``crop_image_to_mask`` crops to the non-zero mask region.

    A 3-channel 50x100 image is paired with a single-channel mask that is
    ``1`` only on rows 10-29 and columns 50-69. The cropped result must be
    exactly that 20x20 region, taken from every channel of the image.
    """
    original_tensor_size = (3, 50, 100)
    original_mask_size = (1, 50, 100)
    slice_ = (slice(None), slice(10, 30), slice(50, 70))

    tensor = torch.rand(original_tensor_size)
    mask = torch.zeros(original_mask_size)
    mask[slice_] = 1

    cropped_tensor = crop_image_to_mask(tensor, mask)

    assert cropped_tensor.shape == (3, 20, 20)
    assert torch.all(cropped_tensor.eq(tensor[slice_]))


def test_resize_max_size():
    """Check that ``resize_max_side`` preserves the aspect ratio.

    The same random image is resized both as-is (50x100) and transposed
    (100x50). Each result must keep the source aspect ratio, and the two
    results must have swapped height/width.
    """
    original_size = (3, 50, 100)
    orig_h, orig_w = original_size[1], original_size[2]
    new_max_side = 120

    tensor = torch.rand(original_size)
    resized_tensor = resize_max_side(tensor, new_max_side)

    # Aspect ratios are compared by integer cross-multiplication
    # (h0 / w0 == h1 / w1  <=>  h0 * w1 == h1 * w0). This is exact, unlike
    # float division, and in particular unlike ``1 / (h / w)``, which is
    # rounded twice and need not equal ``w / h``.
    assert orig_h * resized_tensor.shape[2] == orig_w * resized_tensor.shape[1]

    transposed_tensor = tensor.transpose(1, 2)
    resized_transposed_tensor = resize_max_side(transposed_tensor, new_max_side)

    # The transposed input has the inverse ratio: h0 / w0 == W / H.
    assert (
        orig_h * resized_transposed_tensor.shape[1]
        == orig_w * resized_transposed_tensor.shape[2]
    )

    assert resized_tensor.shape[1] == resized_transposed_tensor.shape[2]
    assert resized_tensor.shape[2] == resized_transposed_tensor.shape[1]


def test_elastic_deformation(datadir):
    """Check ``ElasticDeformation`` against a stored reference output.

    ``raw_with_elastic_deformation.png`` is the expected deformation of
    ``raw_without_elastic_deformation.png`` with NumPy's global seed set to
    100. The transform is therefore built and applied right after seeding.

    Parameters
    ----------
    datadir
        Pytest fixture giving the directory that holds the test images.
    """
    # Get a raw sample without deformation. The ``with`` block closes the
    # file; ``to_tensor`` has already read the pixels at that point.
    data_file = str(datadir / "raw_without_elastic_deformation.png")
    with PIL.Image.open(data_file) as image:
        raw_without_deformation = F.to_tensor(image)

    # Elastic-deform the raw sample under a fixed seed. The reference output
    # depends on this seed, so ElasticDeformation presumably draws from
    # NumPy's global RNG. The previous global state is restored afterwards
    # so the seed does not leak into other tests.
    rng_state = numpy.random.get_state()
    try:
        numpy.random.seed(seed=100)
        ed = ElasticDeformation()
        raw_deformed = ed(raw_without_deformation)
    finally:
        numpy.random.set_state(rng_state)

    # Get the same sample already deformed (with seed=100). ``numpy.array``
    # copies the pixels so they stay valid after the file is closed.
    data_file_2 = str(datadir / "raw_with_elastic_deformation.png")
    with PIL.Image.open(data_file_2) as image:
        expected = numpy.array(image)

    # Bring the first channel back to 8-bit (scale by 255, truncate through
    # ``astype``) to compare with the 8-bit reference. This assumes the PNG
    # is 8-bit, so ``to_tensor`` produced values in [0, 1].
    deformed_u8 = (255 * numpy.asarray(raw_deformed)).astype(numpy.uint8)[
        0,
        :,
        :,
    ]
    numpy.testing.assert_array_equal(deformed_u8, expected)