diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 4fe965b4c44f852c382c37b7c5345fc735e28050..e257a4fc54190c778fdaa6c8d36522fa713fbd76 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -196,7 +196,9 @@ class PASA(pl.LightningModule):
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        augmented_images = [self.augmentation_transforms(img) for img in images]
+        augmented_images = [
+            self.augmentation_transforms(img).to(self.device) for img in images
+        ]
         # Combine list of augmented images back into a tensor
         augmented_images = torch.cat(augmented_images, 0).view(images.shape)
         outputs = self(augmented_images)