From dd849de43e5fb91ef2533aa03a490a59e826149b Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 14 Jun 2023 14:57:00 +0200
Subject: [PATCH] Cleaner way to apply train_transforms

---
 src/ptbench/data/transforms.py | 6 +++---
 src/ptbench/models/pasa.py     | 5 ++---
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py
index 34c3a605..6d3c17d2 100644
--- a/src/ptbench/data/transforms.py
+++ b/src/ptbench/data/transforms.py
@@ -77,10 +77,10 @@ class ElasticDeformation:
         self.random_state = random_state
         self.p = p
 
-        self.tensor_transform = transforms.Compose([transforms.ToTensor()])
-
     def __call__(self, img):
         if random.random() < self.p:
+            img = transforms.ToPILImage()(img)
+
             img = numpy.asarray(img)
 
             assert img.ndim == 2
@@ -117,6 +117,6 @@ class ElasticDeformation:
             result[:, :] = map_coordinates(
                 img[:, :], indices, order=self.spline_order, mode=self.mode
             ).reshape(shape)
-            return self.tensor_transform(PIL.Image.fromarray(result))
+            return transforms.ToTensor()(PIL.Image.fromarray(result))
         else:
             return img
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index c127239d..ed8c4b30 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -133,10 +133,9 @@ class PASA(pl.LightningModule):
     def training_step(self, batch, batch_idx):
         images = batch[1]
         labels = batch[2]
+
         for img in images:
-            img = torch.unsqueeze(
-                self.train_transforms(torch.squeeze(img, 0)), 0
-            )
+            img = self.train_transforms(img)
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
-- 
GitLab