Skip to content
Snippets Groups Projects
Commit 4df399c3 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Move augmented images to the correct device

parent 23f20e94
No related branches found
No related tags found
1 merge request!7Reviewed DataModule design+docs+types
Pipeline #75558 failed
......@@ -196,7 +196,9 @@ class PASA(pl.LightningModule):
labels = torch.reshape(labels, (labels.shape[0], 1))
# Forward pass on the network
augmented_images = [self.augmentation_transforms(img) for img in images]
augmented_images = [
self.augmentation_transforms(img).to(self.device) for img in images
]
# Combine list of augmented images back into a tensor
augmented_images = torch.cat(augmented_images, 0).view(images.shape)
outputs = self(augmented_images)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment