diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py
index 34c3a6050a0611d86e59a488ada342444235c7e6..6d3c17d243a4987f889bae8da0bb6a21ce73459b 100644
--- a/src/ptbench/data/transforms.py
+++ b/src/ptbench/data/transforms.py
@@ -77,10 +77,10 @@ class ElasticDeformation:
         self.random_state = random_state
         self.p = p
 
-        self.tensor_transform = transforms.Compose([transforms.ToTensor()])
-
     def __call__(self, img):
         if random.random() < self.p:
+            img = transforms.ToPILImage()(img)
+
             img = numpy.asarray(img)
 
             assert img.ndim == 2
@@ -117,6 +117,6 @@ class ElasticDeformation:
             result[:, :] = map_coordinates(
                 img[:, :], indices, order=self.spline_order, mode=self.mode
             ).reshape(shape)
-            return self.tensor_transform(PIL.Image.fromarray(result))
+            return transforms.ToTensor()(PIL.Image.fromarray(result))
         else:
             return img
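A minimal usage sketch (not part of the diff) of what this change enables: ElasticDeformation now accepts a torch.Tensor, converts it to a PIL image internally, and hands back a tensor, so it can be applied directly after ToTensor(). The call below only uses p, which is visible in the attributes above, and assumes the remaining constructor arguments keep their defaults.

    import torch

    from ptbench.data.transforms import ElasticDeformation

    elastic = ElasticDeformation(p=1.0)  # p=1.0 forces the deformation branch
    img = torch.rand(1, 512, 512)        # stand-in for a single-channel grayscale image tensor
    out = elastic(img)                   # comes back as a tensor of the same spatial size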
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index c127239d357d556ce22e13eec3f0aa43da2e2521..ed8c4b30a9412ebd2afb78c0b21f5a0ecb26fd91 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -133,10 +133,9 @@ class PASA(pl.LightningModule):
     def training_step(self, batch, batch_idx):
         images = batch[1]
         labels = batch[2]
+
         for img in images:
-            img = torch.unsqueeze(
-                self.train_transforms(torch.squeeze(img, 0)), 0
-            )
+            img = self.train_transforms(img)
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
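For context, a sketch (outside the commit) of how per-image train transforms can be applied so the transformed tensors are actually kept, assuming images is an iterable of C x H x W tensors; the train_transforms pipeline and the torch.stack call here are illustrative, not part of the module.

    import torch
    from torchvision import transforms

    from ptbench.data.transforms import ElasticDeformation

    # illustrative pipeline; the model's real train_transforms may differ
    train_transforms = transforms.Compose([ElasticDeformation(p=1.0)])

    images = torch.rand(4, 1, 512, 512)  # stand-in for a batch of grayscale images

    # apply the transforms image by image and gather the results back into a
    # batch; reassigning the loop variable alone would leave `images` untouched
    images = torch.stack([train_transforms(img) for img in images])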