Skip to content
Snippets Groups Projects
Commit 64fa5458 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Preserve number of channels when augmenting images

parent 759c7e04
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -197,7 +197,8 @@ class PASA(pl.LightningModule): ...@@ -197,7 +197,8 @@ class PASA(pl.LightningModule):
# Forward pass on the network # Forward pass on the network
augmented_images = [self.augmentation_transforms(img) for img in images] augmented_images = [self.augmentation_transforms(img) for img in images]
augmented_images = torch.unsqueeze(torch.cat(augmented_images, 0), 1) # Combine list of augmented images back into a tensor
augmented_images = torch.cat(augmented_images, 0).view(images.shape)
outputs = self(augmented_images) outputs = self(augmented_images)
training_loss = self.criterion(outputs, labels.double()) training_loss = self.criterion(outputs, labels.double())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment