diff --git a/tests/test_transforms.py b/tests/test_transforms.py index a8f1d71479c7c954c7459f6b85ddbb0d00e0bef5..33e80bb9a8fe30cbacbd1241e65b7ee9694846a8 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -8,10 +8,7 @@ import PIL.Image import torch import torchvision.transforms.functional as F # noqa: N812 from mednet.libs.common.data.augmentations import ElasticDeformation -from mednet.libs.common.models.transforms import ( - crop_image_to_mask, - resize_max_side, -) +from mednet.libs.common.models.transforms import crop_image_to_mask def test_crop_mask(): @@ -29,29 +26,6 @@ def test_crop_mask(): assert torch.all(cropped_tensor.eq(tensor[slice_])) -def test_resize_max_size(): - original_size = (3, 50, 100) - original_ratio = original_size[1] / original_size[2] - - new_max_side = 120 - tensor = torch.rand(original_size) - - resized_tensor = resize_max_side(tensor, new_max_side) - resized_ratio = resized_tensor.shape[1] / resized_tensor.shape[2] - assert original_ratio == resized_ratio - - transposed_tensor = tensor.transpose(1, 2) - - resized_transposed_tensor = resize_max_side(transposed_tensor, new_max_side) - inv_ratio = 1 / ( - resized_transposed_tensor.shape[1] / resized_transposed_tensor.shape[2] - ) - assert original_ratio == inv_ratio - - assert resized_tensor.shape[1] == resized_transposed_tensor.shape[2] - assert resized_tensor.shape[2] == resized_transposed_tensor.shape[1] - - def test_elastic_deformation(datadir): # Get a raw sample without deformation data_file = str(datadir / "raw_without_elastic_deformation.png")