From 4df399c38c727bf3cef9794fe3a73f995ed677b2 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 4 Jul 2023 16:42:26 +0200
Subject: [PATCH] Move augmented images to the correct device

---
 src/ptbench/models/pasa.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 4fe965b4..e257a4fc 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -196,7 +196,9 @@ class PASA(pl.LightningModule):
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        augmented_images = [self.augmentation_transforms(img) for img in images]
+        augmented_images = [
+            self.augmentation_transforms(img).to(self.device) for img in images
+        ]
         # Combine list of augmented images back into a tensor
         augmented_images = torch.cat(augmented_images, 0).view(images.shape)
         outputs = self(augmented_images)
-- 
GitLab