diff --git a/src/mednet/libs/common/models/transforms.py b/src/mednet/libs/common/models/transforms.py
index 5fa2d68fe8595504cde5062226df3053c7133ae0..84f3dfbe1e16503d108a11176b6929643aa8f7e4 100644
--- a/src/mednet/libs/common/models/transforms.py
+++ b/src/mednet/libs/common/models/transforms.py
@@ -75,6 +75,8 @@ def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor:
         The resized image.
     """
 
+    from torchvision.transforms import InterpolationMode
+
     if max_side <= 0:
         raise ValueError(f"The new max side ({max_side}) must be positive.")
 
@@ -86,7 +88,9 @@ def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor:
     else:
         new_size = (int(max_side * aspect_ratio), max_side)
 
-    return torchvision.transforms.Resize(new_size, antialias=True)(tensor)
+    return torchvision.transforms.Resize(
+        new_size, interpolation=InterpolationMode.NEAREST, antialias=True
+    )(tensor)
 
 
 def square_center_pad(img: torch.Tensor) -> torch.Tensor: