From d910546742e4933c2ade9ac297cee4ae8bfb7f71 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 5 Jul 2023 12:42:58 +0200
Subject: [PATCH] Make ElastiCDeformation work with both greayscale and rgb
 images

---
 src/ptbench/data/transforms.py | 29 ++++++++++++++++++++++-------
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py
index cf1946e9..ad516e19 100644
--- a/src/ptbench/data/transforms.py
+++ b/src/ptbench/data/transforms.py
@@ -46,13 +46,15 @@ class ElasticDeformation:
 
     def __call__(self, img):
         if random.random() < self.p:
-            img = transforms.ToPILImage()(img)
+            assert img.ndim == 3
 
+            # Input tensor is of shape C x H x W
+            # If the tensor only contains one channel, this conversion results in H x W.
+            # With 3 channels, we get H x W x C
+            img = transforms.ToPILImage()(img)
             img = numpy.asarray(img)
 
-            assert img.ndim == 2
-
-            shape = img.shape
+            shape = img.shape[:2]
 
             dx = (
                 gaussian_filter(
@@ -81,9 +83,22 @@ class ElasticDeformation:
                 numpy.reshape(y + dy, (-1, 1)),
             ]
             result = numpy.empty_like(img)
-            result[:, :] = map_coordinates(
-                img[:, :], indices, order=self.spline_order, mode=self.mode
-            ).reshape(shape)
+
+            if img.ndim == 2:
+                result[:, :] = map_coordinates(
+                    img[:, :], indices, order=self.spline_order, mode=self.mode
+                ).reshape(shape)
+
+            else:
+                for i in range(img.shape[2]):
+                    result[:, :, i] = map_coordinates(
+                        img[:, :, i],
+                        indices,
+                        order=self.spline_order,
+                        mode=self.mode,
+                    ).reshape(shape)
+
             return transforms.ToTensor()(PIL.Image.fromarray(result))
+
         else:
             return img
-- 
GitLab