Skip to content
Snippets Groups Projects
Commit d9105467 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Make ElastiCDeformation work with both greayscale and rgb images

parent 87285a8e
No related branches found
No related tags found
1 merge request!7Reviewed DataModule design+docs+types
...@@ -46,13 +46,15 @@ class ElasticDeformation: ...@@ -46,13 +46,15 @@ class ElasticDeformation:
def __call__(self, img): def __call__(self, img):
if random.random() < self.p: if random.random() < self.p:
img = transforms.ToPILImage()(img) assert img.ndim == 3
# Input tensor is of shape C x H x W
# If the tensor only contains one channel, this conversion results in H x W.
# With 3 channels, we get H x W x C
img = transforms.ToPILImage()(img)
img = numpy.asarray(img) img = numpy.asarray(img)
assert img.ndim == 2 shape = img.shape[:2]
shape = img.shape
dx = ( dx = (
gaussian_filter( gaussian_filter(
...@@ -81,9 +83,22 @@ class ElasticDeformation: ...@@ -81,9 +83,22 @@ class ElasticDeformation:
numpy.reshape(y + dy, (-1, 1)), numpy.reshape(y + dy, (-1, 1)),
] ]
result = numpy.empty_like(img) result = numpy.empty_like(img)
result[:, :] = map_coordinates(
img[:, :], indices, order=self.spline_order, mode=self.mode if img.ndim == 2:
).reshape(shape) result[:, :] = map_coordinates(
img[:, :], indices, order=self.spline_order, mode=self.mode
).reshape(shape)
else:
for i in range(img.shape[2]):
result[:, :, i] = map_coordinates(
img[:, :, i],
indices,
order=self.spline_order,
mode=self.mode,
).reshape(shape)
return transforms.ToTensor()(PIL.Image.fromarray(result)) return transforms.ToTensor()(PIL.Image.fromarray(result))
else: else:
return img return img
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment