From 4df399c38c727bf3cef9794fe3a73f995ed677b2 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 4 Jul 2023 16:42:26 +0200 Subject: [PATCH] Move augmented images to the correct device --- src/ptbench/models/pasa.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 4fe965b4..e257a4fc 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -196,7 +196,9 @@ class PASA(pl.LightningModule): labels = torch.reshape(labels, (labels.shape[0], 1)) # Forward pass on the network - augmented_images = [self.augmentation_transforms(img) for img in images] + augmented_images = [ + self.augmentation_transforms(img).to(self.device) for img in images + ] # Combine list of augmented images back into a tensor augmented_images = torch.cat(augmented_images, 0).view(images.shape) outputs = self(augmented_images) -- GitLab