From b09eb8984fb64bc75daf9ede2a1190fb0463491e Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 18 Jun 2024 10:19:16 +0200 Subject: [PATCH] [segmentation.transforms] Set interpolation to nearest during resize --- src/mednet/libs/common/models/transforms.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/mednet/libs/common/models/transforms.py b/src/mednet/libs/common/models/transforms.py index 5fa2d68f..84f3dfbe 100644 --- a/src/mednet/libs/common/models/transforms.py +++ b/src/mednet/libs/common/models/transforms.py @@ -75,6 +75,8 @@ def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor: The resized image. """ + from torchvision.transforms import InterpolationMode + if max_side <= 0: raise ValueError(f"The new max side ({max_side}) must be positive.") @@ -86,7 +88,9 @@ def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor: else: new_size = (int(max_side * aspect_ratio), max_side) - return torchvision.transforms.Resize(new_size, antialias=True)(tensor) + return torchvision.transforms.Resize( + new_size, interpolation=InterpolationMode.NEAREST, antialias=True + )(tensor) def square_center_pad(img: torch.Tensor) -> torch.Tensor: -- GitLab