diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index e93b4e61f3bd5c9a0259e7d21728362f8e26999a..4fe965b4c44f852c382c37b7c5345fc735e28050 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -197,7 +197,8 @@ class PASA(pl.LightningModule): # Forward pass on the network augmented_images = [self.augmentation_transforms(img) for img in images] - augmented_images = torch.unsqueeze(torch.cat(augmented_images, 0), 1) + # Combine list of augmented images back into a tensor + augmented_images = torch.cat(augmented_images, 0).view(images.shape) outputs = self(augmented_images) training_loss = self.criterion(outputs, labels.double())