From d910546742e4933c2ade9ac297cee4ae8bfb7f71 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 5 Jul 2023 12:42:58 +0200 Subject: [PATCH] Make ElastiCDeformation work with both greayscale and rgb images --- src/ptbench/data/transforms.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py index cf1946e9..ad516e19 100644 --- a/src/ptbench/data/transforms.py +++ b/src/ptbench/data/transforms.py @@ -46,13 +46,15 @@ class ElasticDeformation: def __call__(self, img): if random.random() < self.p: - img = transforms.ToPILImage()(img) + assert img.ndim == 3 + # Input tensor is of shape C x H x W + # If the tensor only contains one channel, this conversion results in H x W. + # With 3 channels, we get H x W x C + img = transforms.ToPILImage()(img) img = numpy.asarray(img) - assert img.ndim == 2 - - shape = img.shape + shape = img.shape[:2] dx = ( gaussian_filter( @@ -81,9 +83,22 @@ class ElasticDeformation: numpy.reshape(y + dy, (-1, 1)), ] result = numpy.empty_like(img) - result[:, :] = map_coordinates( - img[:, :], indices, order=self.spline_order, mode=self.mode - ).reshape(shape) + + if img.ndim == 2: + result[:, :] = map_coordinates( + img[:, :], indices, order=self.spline_order, mode=self.mode + ).reshape(shape) + + else: + for i in range(img.shape[2]): + result[:, :, i] = map_coordinates( + img[:, :, i], + indices, + order=self.spline_order, + mode=self.mode, + ).reshape(shape) + return transforms.ToTensor()(PIL.Image.fromarray(result)) + else: return img -- GitLab