From 23f20e94c9ffce712bf61097103f7fbae02e6dc2 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 4 Jul 2023 12:06:21 +0200
Subject: [PATCH] Preserve number of channels when augmenting images

---
 src/ptbench/models/pasa.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index e93b4e61..4fe965b4 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -197,7 +197,8 @@ class PASA(pl.LightningModule):
 
         # Forward pass on the network
         augmented_images = [self.augmentation_transforms(img) for img in images]
-        augmented_images = torch.unsqueeze(torch.cat(augmented_images, 0), 1)
+        # Combine list of augmented images back into a tensor
+        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
         outputs = self(augmented_images)
 
         training_loss = self.criterion(outputs, labels.double())
-- 
GitLab