diff --git a/src/mednet/libs/common/models/transforms.py b/src/mednet/libs/common/models/transforms.py index 5fa2d68fe8595504cde5062226df3053c7133ae0..84f3dfbe1e16503d108a11176b6929643aa8f7e4 100644 --- a/src/mednet/libs/common/models/transforms.py +++ b/src/mednet/libs/common/models/transforms.py @@ -75,6 +75,8 @@ def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor: The resized image. """ + from torchvision.transforms import InterpolationMode + if max_side <= 0: raise ValueError(f"The new max side ({max_side}) must be positive.") @@ -86,7 +88,9 @@ def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor: else: new_size = (int(max_side * aspect_ratio), max_side) - return torchvision.transforms.Resize(new_size, antialias=True)(tensor) + return torchvision.transforms.Resize( + new_size, interpolation=InterpolationMode.NEAREST, antialias=True + )(tensor) def square_center_pad(img: torch.Tensor) -> torch.Tensor: