diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py index cf1946e97b7701abf6a1b557ea3ef3dcaaa2a1eb..ad516e194570dca2b5794959f816a0e717733a9d 100644 --- a/src/ptbench/data/transforms.py +++ b/src/ptbench/data/transforms.py @@ -46,13 +46,15 @@ class ElasticDeformation: def __call__(self, img): if random.random() < self.p: - img = transforms.ToPILImage()(img) + assert img.ndim == 3 + # Input tensor is of shape C x H x W + # If the tensor only contains one channel, this conversion results in H x W. + # With 3 channels, we get H x W x C + img = transforms.ToPILImage()(img) img = numpy.asarray(img) - assert img.ndim == 2 - - shape = img.shape + shape = img.shape[:2] dx = ( gaussian_filter( @@ -81,9 +83,22 @@ class ElasticDeformation: numpy.reshape(y + dy, (-1, 1)), ] result = numpy.empty_like(img) - result[:, :] = map_coordinates( - img[:, :], indices, order=self.spline_order, mode=self.mode - ).reshape(shape) + + if img.ndim == 2: + result[:, :] = map_coordinates( + img[:, :], indices, order=self.spline_order, mode=self.mode + ).reshape(shape) + + else: + for i in range(img.shape[2]): + result[:, :, i] = map_coordinates( + img[:, :, i], + indices, + order=self.spline_order, + mode=self.mode, + ).reshape(shape) + return transforms.ToTensor()(PIL.Image.fromarray(result)) + else: return img