diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py index 34c3a6050a0611d86e59a488ada342444235c7e6..6d3c17d243a4987f889bae8da0bb6a21ce73459b 100644 --- a/src/ptbench/data/transforms.py +++ b/src/ptbench/data/transforms.py @@ -77,10 +77,10 @@ class ElasticDeformation: self.random_state = random_state self.p = p - self.tensor_transform = transforms.Compose([transforms.ToTensor()]) - def __call__(self, img): if random.random() < self.p: + img = transforms.ToPILImage()(img) + img = numpy.asarray(img) assert img.ndim == 2 @@ -117,6 +117,6 @@ class ElasticDeformation: result[:, :] = map_coordinates( img[:, :], indices, order=self.spline_order, mode=self.mode ).reshape(shape) - return self.tensor_transform(PIL.Image.fromarray(result)) + return transforms.ToTensor()(PIL.Image.fromarray(result)) else: return img diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index c127239d357d556ce22e13eec3f0aa43da2e2521..ed8c4b30a9412ebd2afb78c0b21f5a0ecb26fd91 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -133,10 +133,9 @@ class PASA(pl.LightningModule): def training_step(self, batch, batch_idx): images = batch[1] labels = batch[2] + for img in images: - img = torch.unsqueeze( - self.train_transforms(torch.squeeze(img, 0)), 0 - ) + img = self.train_transforms(img) # Increase label dimension if too low # Allows single and multiclass usage