From b09eb8984fb64bc75daf9ede2a1190fb0463491e Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 18 Jun 2024 10:19:16 +0200
Subject: [PATCH] [segmentation.transforms] Set interpolation to nearest during
 resize

---
 src/mednet/libs/common/models/transforms.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/mednet/libs/common/models/transforms.py b/src/mednet/libs/common/models/transforms.py
index 5fa2d68f..84f3dfbe 100644
--- a/src/mednet/libs/common/models/transforms.py
+++ b/src/mednet/libs/common/models/transforms.py
@@ -75,6 +75,8 @@ def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor:
         The resized image.
     """
 
+    from torchvision.transforms import InterpolationMode
+
     if max_side <= 0:
         raise ValueError(f"The new max side ({max_side}) must be positive.")
 
@@ -86,7 +88,9 @@ def resize_max_side(tensor: torch.Tensor, max_side: int) -> torch.Tensor:
     else:
         new_size = (int(max_side * aspect_ratio), max_side)
 
-    return torchvision.transforms.Resize(new_size, antialias=True)(tensor)
+    return torchvision.transforms.Resize(
+        new_size, interpolation=InterpolationMode.NEAREST, antialias=True
+    )(tensor)
 
 
 def square_center_pad(img: torch.Tensor) -> torch.Tensor:
-- 
GitLab