diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py
index cf1946e97b7701abf6a1b557ea3ef3dcaaa2a1eb..ad516e194570dca2b5794959f816a0e717733a9d 100644
--- a/src/ptbench/data/transforms.py
+++ b/src/ptbench/data/transforms.py
@@ -46,13 +46,15 @@ class ElasticDeformation:
 
     def __call__(self, img):
         if random.random() < self.p:
-            img = transforms.ToPILImage()(img)
+            assert img.ndim == 3
 
+            # Input tensor is of shape C x H x W
+            # If the tensor only contains one channel, this conversion results in H x W.
+            # With 3 channels, we get H x W x C
+            img = transforms.ToPILImage()(img)
             img = numpy.asarray(img)
 
-            assert img.ndim == 2
-
-            shape = img.shape
+            shape = img.shape[:2]
 
             dx = (
                 gaussian_filter(
@@ -81,9 +83,22 @@ class ElasticDeformation:
                 numpy.reshape(y + dy, (-1, 1)),
             ]
             result = numpy.empty_like(img)
-            result[:, :] = map_coordinates(
-                img[:, :], indices, order=self.spline_order, mode=self.mode
-            ).reshape(shape)
+
+            if img.ndim == 2:
+                result[:, :] = map_coordinates(
+                    img[:, :], indices, order=self.spline_order, mode=self.mode
+                ).reshape(shape)
+
+            else:
+                for i in range(img.shape[2]):
+                    result[:, :, i] = map_coordinates(
+                        img[:, :, i],
+                        indices,
+                        order=self.spline_order,
+                        mode=self.mode,
+                    ).reshape(shape)
+
             return transforms.ToTensor()(PIL.Image.fromarray(result))
+
         else:
             return img