diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 4fe965b4c44f852c382c37b7c5345fc735e28050..e257a4fc54190c778fdaa6c8d36522fa713fbd76 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -196,7 +196,9 @@ class PASA(pl.LightningModule): labels = torch.reshape(labels, (labels.shape[0], 1)) # Forward pass on the network - augmented_images = [self.augmentation_transforms(img) for img in images] + augmented_images = [ + self.augmentation_transforms(img).to(self.device) for img in images + ] # Combine list of augmented images back into a tensor augmented_images = torch.cat(augmented_images, 0).view(images.shape) outputs = self(augmented_images)